vhost-user: support protocol features
[dpdk.git] / lib / librte_vhost / vhost_user / vhost-net-user.c
1 /*-
2  *   BSD LICENSE
3  *
4  *   Copyright(c) 2010-2014 Intel Corporation. All rights reserved.
5  *   All rights reserved.
6  *
7  *   Redistribution and use in source and binary forms, with or without
8  *   modification, are permitted provided that the following conditions
9  *   are met:
10  *
11  *     * Redistributions of source code must retain the above copyright
12  *       notice, this list of conditions and the following disclaimer.
13  *     * Redistributions in binary form must reproduce the above copyright
14  *       notice, this list of conditions and the following disclaimer in
15  *       the documentation and/or other materials provided with the
16  *       distribution.
17  *     * Neither the name of Intel Corporation nor the names of its
18  *       contributors may be used to endorse or promote products derived
19  *       from this software without specific prior written permission.
20  *
21  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <stdint.h>
35 #include <stdio.h>
36 #include <limits.h>
37 #include <stdlib.h>
38 #include <unistd.h>
39 #include <string.h>
40 #include <sys/types.h>
41 #include <sys/socket.h>
42 #include <sys/un.h>
43 #include <errno.h>
44 #include <pthread.h>
45
46 #include <rte_log.h>
47 #include <rte_virtio_net.h>
48
49 #include "fd_man.h"
50 #include "vhost-net-user.h"
51 #include "vhost-net.h"
52 #include "virtio-net-user.h"
53
/* Backlog value handed to listen() for every vhost-user listen socket. */
#define MAX_VIRTIO_BACKLOG 128

/* fdset callbacks; both are defined further down in this file. */
static void vserver_new_vq_conn(int fd, void *dat, int *remove);
static void vserver_message_handler(int connfd, void *dat, int *remove);

/*
 * Device operation table. It is filled in from get_virtio_net_callbacks()
 * the first time rte_vhost_driver_register() runs.
 */
const struct vhost_net_device_ops *ops;
59
/*
 * Context attached to each accepted connection fd when it is added to the
 * fdset; vserver_message_handler() receives it back as its data argument.
 */
struct connfd_ctx {
        /* server whose listen fd accepted this connection */
        struct vhost_server *vserver;
        /* device handle returned by ops->new_device() */
        uint32_t fh;
};
64
65 #define MAX_VHOST_SERVER 1024
66 struct _vhost_server {
67         struct vhost_server *server[MAX_VHOST_SERVER];
68         struct fdset fdset;
69         int vserver_cnt;
70         pthread_mutex_t server_mutex;
71 };
72
73 static struct _vhost_server g_vhost_server = {
74         .fdset = {
75                 .fd = { [0 ... MAX_FDS - 1] = {-1, NULL, NULL, NULL, 0} },
76                 .fd_mutex = PTHREAD_MUTEX_INITIALIZER,
77                 .num = 0
78         },
79         .vserver_cnt = 0,
80         .server_mutex = PTHREAD_MUTEX_INITIALIZER,
81 };
82
83 static const char *vhost_message_str[VHOST_USER_MAX] = {
84         [VHOST_USER_NONE] = "VHOST_USER_NONE",
85         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
86         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
87         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
88         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
89         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
90         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
91         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
92         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
93         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
94         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
95         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
96         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
97         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
98         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
99         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
100         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
101 };
102
103 /**
104  * Create a unix domain socket, bind to path and listen for connection.
105  * @return
106  *  socket fd or -1 on failure
107  */
108 static int
109 uds_socket(const char *path)
110 {
111         struct sockaddr_un un;
112         int sockfd;
113         int ret;
114
115         if (path == NULL)
116                 return -1;
117
118         sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
119         if (sockfd < 0)
120                 return -1;
121         RTE_LOG(INFO, VHOST_CONFIG, "socket created, fd:%d\n", sockfd);
122
123         memset(&un, 0, sizeof(un));
124         un.sun_family = AF_UNIX;
125         snprintf(un.sun_path, sizeof(un.sun_path), "%s", path);
126         ret = bind(sockfd, (struct sockaddr *)&un, sizeof(un));
127         if (ret == -1) {
128                 RTE_LOG(ERR, VHOST_CONFIG, "fail to bind fd:%d, remove file:%s and try again.\n",
129                         sockfd, path);
130                 goto err;
131         }
132         RTE_LOG(INFO, VHOST_CONFIG, "bind to %s\n", path);
133
134         ret = listen(sockfd, MAX_VIRTIO_BACKLOG);
135         if (ret == -1)
136                 goto err;
137
138         return sockfd;
139
140 err:
141         close(sockfd);
142         return -1;
143 }
144
145 /* return bytes# of read on success or negative val on failure. */
146 static int
147 read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
148 {
149         struct iovec iov;
150         struct msghdr msgh;
151         size_t fdsize = fd_num * sizeof(int);
152         char control[CMSG_SPACE(fdsize)];
153         struct cmsghdr *cmsg;
154         int ret;
155
156         memset(&msgh, 0, sizeof(msgh));
157         iov.iov_base = buf;
158         iov.iov_len  = buflen;
159
160         msgh.msg_iov = &iov;
161         msgh.msg_iovlen = 1;
162         msgh.msg_control = control;
163         msgh.msg_controllen = sizeof(control);
164
165         ret = recvmsg(sockfd, &msgh, 0);
166         if (ret <= 0) {
167                 RTE_LOG(ERR, VHOST_CONFIG, "recvmsg failed\n");
168                 return ret;
169         }
170
171         if (msgh.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
172                 RTE_LOG(ERR, VHOST_CONFIG, "truncted msg\n");
173                 return -1;
174         }
175
176         for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != NULL;
177                 cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
178                 if ((cmsg->cmsg_level == SOL_SOCKET) &&
179                         (cmsg->cmsg_type == SCM_RIGHTS)) {
180                         memcpy(fds, CMSG_DATA(cmsg), fdsize);
181                         break;
182                 }
183         }
184
185         return ret;
186 }
187
188 /* return bytes# of read on success or negative val on failure. */
189 static int
190 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
191 {
192         int ret;
193
194         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
195                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
196         if (ret <= 0)
197                 return ret;
198
199         if (msg && msg->size) {
200                 if (msg->size > sizeof(msg->payload)) {
201                         RTE_LOG(ERR, VHOST_CONFIG,
202                                 "invalid msg size: %d\n", msg->size);
203                         return -1;
204                 }
205                 ret = read(sockfd, &msg->payload, msg->size);
206                 if (ret <= 0)
207                         return ret;
208                 if (ret != (int)msg->size) {
209                         RTE_LOG(ERR, VHOST_CONFIG,
210                                 "read control message failed\n");
211                         return -1;
212                 }
213         }
214
215         return ret;
216 }
217
218 static int
219 send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
220 {
221
222         struct iovec iov;
223         struct msghdr msgh;
224         size_t fdsize = fd_num * sizeof(int);
225         char control[CMSG_SPACE(fdsize)];
226         struct cmsghdr *cmsg;
227         int ret;
228
229         memset(&msgh, 0, sizeof(msgh));
230         iov.iov_base = buf;
231         iov.iov_len = buflen;
232
233         msgh.msg_iov = &iov;
234         msgh.msg_iovlen = 1;
235
236         if (fds && fd_num > 0) {
237                 msgh.msg_control = control;
238                 msgh.msg_controllen = sizeof(control);
239                 cmsg = CMSG_FIRSTHDR(&msgh);
240                 cmsg->cmsg_len = CMSG_LEN(fdsize);
241                 cmsg->cmsg_level = SOL_SOCKET;
242                 cmsg->cmsg_type = SCM_RIGHTS;
243                 memcpy(CMSG_DATA(cmsg), fds, fdsize);
244         } else {
245                 msgh.msg_control = NULL;
246                 msgh.msg_controllen = 0;
247         }
248
249         do {
250                 ret = sendmsg(sockfd, &msgh, 0);
251         } while (ret < 0 && errno == EINTR);
252
253         if (ret < 0) {
254                 RTE_LOG(ERR, VHOST_CONFIG,  "sendmsg error\n");
255                 return ret;
256         }
257
258         return ret;
259 }
260
261 static int
262 send_vhost_message(int sockfd, struct VhostUserMsg *msg)
263 {
264         int ret;
265
266         if (!msg)
267                 return 0;
268
269         msg->flags &= ~VHOST_USER_VERSION_MASK;
270         msg->flags |= VHOST_USER_VERSION;
271         msg->flags |= VHOST_USER_REPLY_MASK;
272
273         ret = send_fd_message(sockfd, (char *)msg,
274                 VHOST_USER_HDR_SIZE + msg->size, NULL, 0);
275
276         return ret;
277 }
278
279 /* call back when there is new virtio connection.  */
280 static void
281 vserver_new_vq_conn(int fd, void *dat, __rte_unused int *remove)
282 {
283         struct vhost_server *vserver = (struct vhost_server *)dat;
284         int conn_fd;
285         struct connfd_ctx *ctx;
286         int fh;
287         struct vhost_device_ctx vdev_ctx = { (pid_t)0, 0 };
288         unsigned int size;
289
290         conn_fd = accept(fd, NULL, NULL);
291         RTE_LOG(INFO, VHOST_CONFIG,
292                 "new virtio connection is %d\n", conn_fd);
293         if (conn_fd < 0)
294                 return;
295
296         ctx = calloc(1, sizeof(*ctx));
297         if (ctx == NULL) {
298                 close(conn_fd);
299                 return;
300         }
301
302         fh = ops->new_device(vdev_ctx);
303         if (fh == -1) {
304                 free(ctx);
305                 close(conn_fd);
306                 return;
307         }
308
309         vdev_ctx.fh = fh;
310         size = strnlen(vserver->path, PATH_MAX);
311         ops->set_ifname(vdev_ctx, vserver->path,
312                 size);
313
314         RTE_LOG(INFO, VHOST_CONFIG, "new device, handle is %d\n", fh);
315
316         ctx->vserver = vserver;
317         ctx->fh = fh;
318         fdset_add(&g_vhost_server.fdset,
319                 conn_fd, vserver_message_handler, NULL, ctx);
320 }
321
322 /* callback when there is message on the connfd */
323 static void
324 vserver_message_handler(int connfd, void *dat, int *remove)
325 {
326         struct vhost_device_ctx ctx;
327         struct connfd_ctx *cfd_ctx = (struct connfd_ctx *)dat;
328         struct VhostUserMsg msg;
329         uint64_t features;
330         int ret;
331
332         ctx.fh = cfd_ctx->fh;
333         ret = read_vhost_message(connfd, &msg);
334         if (ret <= 0 || msg.request > VHOST_USER_MAX) {
335                 if (ret < 0)
336                         RTE_LOG(ERR, VHOST_CONFIG,
337                                 "vhost read message failed\n");
338                 else if (ret == 0)
339                         RTE_LOG(INFO, VHOST_CONFIG,
340                                 "vhost peer closed\n");
341                 else
342                         RTE_LOG(ERR, VHOST_CONFIG,
343                                 "vhost read incorrect message\n");
344
345                 close(connfd);
346                 *remove = 1;
347                 free(cfd_ctx);
348                 user_destroy_device(ctx);
349                 ops->destroy_device(ctx);
350
351                 return;
352         }
353
354         RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
355                 vhost_message_str[msg.request]);
356         switch (msg.request) {
357         case VHOST_USER_GET_FEATURES:
358                 ret = ops->get_features(ctx, &features);
359                 msg.payload.u64 = features;
360                 msg.size = sizeof(msg.payload.u64);
361                 send_vhost_message(connfd, &msg);
362                 break;
363         case VHOST_USER_SET_FEATURES:
364                 features = msg.payload.u64;
365                 ops->set_features(ctx, &features);
366                 break;
367
368         case VHOST_USER_GET_PROTOCOL_FEATURES:
369                 msg.payload.u64 = VHOST_USER_PROTOCOL_FEATURES;
370                 msg.size = sizeof(msg.payload.u64);
371                 send_vhost_message(connfd, &msg);
372                 break;
373         case VHOST_USER_SET_PROTOCOL_FEATURES:
374                 user_set_protocol_features(ctx, msg.payload.u64);
375                 break;
376
377         case VHOST_USER_SET_OWNER:
378                 ops->set_owner(ctx);
379                 break;
380         case VHOST_USER_RESET_OWNER:
381                 ops->reset_owner(ctx);
382                 break;
383
384         case VHOST_USER_SET_MEM_TABLE:
385                 user_set_mem_table(ctx, &msg);
386                 break;
387
388         case VHOST_USER_SET_LOG_BASE:
389                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
390         case VHOST_USER_SET_LOG_FD:
391                 close(msg.fds[0]);
392                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
393                 break;
394
395         case VHOST_USER_SET_VRING_NUM:
396                 ops->set_vring_num(ctx, &msg.payload.state);
397                 break;
398         case VHOST_USER_SET_VRING_ADDR:
399                 ops->set_vring_addr(ctx, &msg.payload.addr);
400                 break;
401         case VHOST_USER_SET_VRING_BASE:
402                 ops->set_vring_base(ctx, &msg.payload.state);
403                 break;
404
405         case VHOST_USER_GET_VRING_BASE:
406                 ret = user_get_vring_base(ctx, &msg.payload.state);
407                 msg.size = sizeof(msg.payload.state);
408                 send_vhost_message(connfd, &msg);
409                 break;
410
411         case VHOST_USER_SET_VRING_KICK:
412                 user_set_vring_kick(ctx, &msg);
413                 break;
414         case VHOST_USER_SET_VRING_CALL:
415                 user_set_vring_call(ctx, &msg);
416                 break;
417
418         case VHOST_USER_SET_VRING_ERR:
419                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
420                         close(msg.fds[0]);
421                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
422                 break;
423
424         default:
425                 break;
426
427         }
428 }
429
430 /**
431  * Creates and initialise the vhost server.
432  */
433 int
434 rte_vhost_driver_register(const char *path)
435 {
436         struct vhost_server *vserver;
437
438         pthread_mutex_lock(&g_vhost_server.server_mutex);
439         if (ops == NULL)
440                 ops = get_virtio_net_callbacks();
441
442         if (g_vhost_server.vserver_cnt == MAX_VHOST_SERVER) {
443                 RTE_LOG(ERR, VHOST_CONFIG,
444                         "error: the number of servers reaches maximum\n");
445                 pthread_mutex_unlock(&g_vhost_server.server_mutex);
446                 return -1;
447         }
448
449         vserver = calloc(sizeof(struct vhost_server), 1);
450         if (vserver == NULL) {
451                 pthread_mutex_unlock(&g_vhost_server.server_mutex);
452                 return -1;
453         }
454
455         vserver->listenfd = uds_socket(path);
456         if (vserver->listenfd < 0) {
457                 free(vserver);
458                 pthread_mutex_unlock(&g_vhost_server.server_mutex);
459                 return -1;
460         }
461
462         vserver->path = strdup(path);
463
464         fdset_add(&g_vhost_server.fdset, vserver->listenfd,
465                 vserver_new_vq_conn, NULL, vserver);
466
467         g_vhost_server.server[g_vhost_server.vserver_cnt++] = vserver;
468         pthread_mutex_unlock(&g_vhost_server.server_mutex);
469
470         return 0;
471 }
472
473
474 /**
475  * Unregister the specified vhost server
476  */
477 int
478 rte_vhost_driver_unregister(const char *path)
479 {
480         int i;
481         int count;
482
483         pthread_mutex_lock(&g_vhost_server.server_mutex);
484
485         for (i = 0; i < g_vhost_server.vserver_cnt; i++) {
486                 if (!strcmp(g_vhost_server.server[i]->path, path)) {
487                         fdset_del(&g_vhost_server.fdset,
488                                 g_vhost_server.server[i]->listenfd);
489
490                         close(g_vhost_server.server[i]->listenfd);
491                         free(g_vhost_server.server[i]->path);
492                         free(g_vhost_server.server[i]);
493
494                         unlink(path);
495
496                         count = --g_vhost_server.vserver_cnt;
497                         g_vhost_server.server[i] = g_vhost_server.server[count];
498                         g_vhost_server.server[count] = NULL;
499                         pthread_mutex_unlock(&g_vhost_server.server_mutex);
500
501                         return 0;
502                 }
503         }
504         pthread_mutex_unlock(&g_vhost_server.server_mutex);
505
506         return -1;
507 }
508
509 int
510 rte_vhost_driver_session_start(void)
511 {
512         fdset_event_dispatch(&g_vhost_server.fdset);
513         return 0;
514 }