vhost: rename structs for enabling client mode
[dpdk.git] / lib / librte_vhost / vhost_user / vhost-net-user.c
1 /*-
2  *   BSD LICENSE
3  *
4  *   Copyright(c) 2010-2014 Intel Corporation. All rights reserved.
5  *   All rights reserved.
6  *
7  *   Redistribution and use in source and binary forms, with or without
8  *   modification, are permitted provided that the following conditions
9  *   are met:
10  *
11  *     * Redistributions of source code must retain the above copyright
12  *       notice, this list of conditions and the following disclaimer.
13  *     * Redistributions in binary form must reproduce the above copyright
14  *       notice, this list of conditions and the following disclaimer in
15  *       the documentation and/or other materials provided with the
16  *       distribution.
17  *     * Neither the name of Intel Corporation nor the names of its
18  *       contributors may be used to endorse or promote products derived
19  *       from this software without specific prior written permission.
20  *
21  *   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22  *   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23  *   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24  *   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25  *   OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26  *   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27  *   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28  *   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29  *   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30  *   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31  *   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32  */
33
34 #include <stdint.h>
35 #include <stdio.h>
36 #include <limits.h>
37 #include <stdlib.h>
38 #include <unistd.h>
39 #include <string.h>
40 #include <sys/types.h>
41 #include <sys/socket.h>
42 #include <sys/un.h>
43 #include <errno.h>
44 #include <pthread.h>
45
46 #include <rte_log.h>
47 #include <rte_virtio_net.h>
48
49 #include "fd_man.h"
50 #include "vhost-net-user.h"
51 #include "vhost-net.h"
52 #include "virtio-net-user.h"
53
/*
 * State kept for one registered vhost-user socket.  Each call to
 * rte_vhost_driver_register() allocates one of these.
 */
struct vhost_user_socket {
        char *path;             /* socket path, heap copy owned by this struct */
        int listenfd;           /* listening descriptor obtained from uds_socket() */
};
62
/*
 * Per-connection state; allocated when a peer connects and passed to
 * vhost_user_msg_handler() as its callback data.
 */
struct vhost_user_connection {
        struct vhost_user_socket *vsocket;      /* socket the peer connected through */
        int vid;                                /* id returned by vhost_new_device() */
};
67
68 #define MAX_VHOST_SOCKET 1024
69 struct vhost_user {
70         struct vhost_user_socket *vsockets[MAX_VHOST_SOCKET];
71         struct fdset fdset;
72         int vsocket_cnt;
73         pthread_mutex_t mutex;
74 };
75
/* Backlog value handed to listen() by uds_socket(). */
#define MAX_VIRTIO_BACKLOG 128

/* fdset callbacks; both are defined further down in this file. */
static void vhost_user_new_connection(int fd, void *dat, int *remove);
static void vhost_user_msg_handler(int connfd, void *dat, int *remove);
80
81 static struct vhost_user vhost_user = {
82         .fdset = {
83                 .fd = { [0 ... MAX_FDS - 1] = {-1, NULL, NULL, NULL, 0} },
84                 .fd_mutex = PTHREAD_MUTEX_INITIALIZER,
85                 .num = 0
86         },
87         .vsocket_cnt = 0,
88         .mutex = PTHREAD_MUTEX_INITIALIZER,
89 };
90
91 static const char *vhost_message_str[VHOST_USER_MAX] = {
92         [VHOST_USER_NONE] = "VHOST_USER_NONE",
93         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
94         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
95         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
96         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
97         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
98         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
99         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
100         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
101         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
102         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
103         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
104         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
105         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
106         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
107         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
108         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
109         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
110         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
111         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
112 };
113
114 /**
115  * Create a unix domain socket, bind to path and listen for connection.
116  * @return
117  *  socket fd or -1 on failure
118  */
119 static int
120 uds_socket(const char *path)
121 {
122         struct sockaddr_un un;
123         int sockfd;
124         int ret;
125
126         if (path == NULL)
127                 return -1;
128
129         sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
130         if (sockfd < 0)
131                 return -1;
132         RTE_LOG(INFO, VHOST_CONFIG, "socket created, fd:%d\n", sockfd);
133
134         memset(&un, 0, sizeof(un));
135         un.sun_family = AF_UNIX;
136         snprintf(un.sun_path, sizeof(un.sun_path), "%s", path);
137         ret = bind(sockfd, (struct sockaddr *)&un, sizeof(un));
138         if (ret == -1) {
139                 RTE_LOG(ERR, VHOST_CONFIG, "fail to bind fd:%d, remove file:%s and try again.\n",
140                         sockfd, path);
141                 goto err;
142         }
143         RTE_LOG(INFO, VHOST_CONFIG, "bind to %s\n", path);
144
145         ret = listen(sockfd, MAX_VIRTIO_BACKLOG);
146         if (ret == -1)
147                 goto err;
148
149         return sockfd;
150
151 err:
152         close(sockfd);
153         return -1;
154 }
155
156 /* return bytes# of read on success or negative val on failure. */
157 static int
158 read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
159 {
160         struct iovec iov;
161         struct msghdr msgh;
162         size_t fdsize = fd_num * sizeof(int);
163         char control[CMSG_SPACE(fdsize)];
164         struct cmsghdr *cmsg;
165         int ret;
166
167         memset(&msgh, 0, sizeof(msgh));
168         iov.iov_base = buf;
169         iov.iov_len  = buflen;
170
171         msgh.msg_iov = &iov;
172         msgh.msg_iovlen = 1;
173         msgh.msg_control = control;
174         msgh.msg_controllen = sizeof(control);
175
176         ret = recvmsg(sockfd, &msgh, 0);
177         if (ret <= 0) {
178                 RTE_LOG(ERR, VHOST_CONFIG, "recvmsg failed\n");
179                 return ret;
180         }
181
182         if (msgh.msg_flags & (MSG_TRUNC | MSG_CTRUNC)) {
183                 RTE_LOG(ERR, VHOST_CONFIG, "truncted msg\n");
184                 return -1;
185         }
186
187         for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != NULL;
188                 cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
189                 if ((cmsg->cmsg_level == SOL_SOCKET) &&
190                         (cmsg->cmsg_type == SCM_RIGHTS)) {
191                         memcpy(fds, CMSG_DATA(cmsg), fdsize);
192                         break;
193                 }
194         }
195
196         return ret;
197 }
198
199 /* return bytes# of read on success or negative val on failure. */
200 static int
201 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
202 {
203         int ret;
204
205         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
206                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
207         if (ret <= 0)
208                 return ret;
209
210         if (msg && msg->size) {
211                 if (msg->size > sizeof(msg->payload)) {
212                         RTE_LOG(ERR, VHOST_CONFIG,
213                                 "invalid msg size: %d\n", msg->size);
214                         return -1;
215                 }
216                 ret = read(sockfd, &msg->payload, msg->size);
217                 if (ret <= 0)
218                         return ret;
219                 if (ret != (int)msg->size) {
220                         RTE_LOG(ERR, VHOST_CONFIG,
221                                 "read control message failed\n");
222                         return -1;
223                 }
224         }
225
226         return ret;
227 }
228
229 static int
230 send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
231 {
232
233         struct iovec iov;
234         struct msghdr msgh;
235         size_t fdsize = fd_num * sizeof(int);
236         char control[CMSG_SPACE(fdsize)];
237         struct cmsghdr *cmsg;
238         int ret;
239
240         memset(&msgh, 0, sizeof(msgh));
241         iov.iov_base = buf;
242         iov.iov_len = buflen;
243
244         msgh.msg_iov = &iov;
245         msgh.msg_iovlen = 1;
246
247         if (fds && fd_num > 0) {
248                 msgh.msg_control = control;
249                 msgh.msg_controllen = sizeof(control);
250                 cmsg = CMSG_FIRSTHDR(&msgh);
251                 cmsg->cmsg_len = CMSG_LEN(fdsize);
252                 cmsg->cmsg_level = SOL_SOCKET;
253                 cmsg->cmsg_type = SCM_RIGHTS;
254                 memcpy(CMSG_DATA(cmsg), fds, fdsize);
255         } else {
256                 msgh.msg_control = NULL;
257                 msgh.msg_controllen = 0;
258         }
259
260         do {
261                 ret = sendmsg(sockfd, &msgh, 0);
262         } while (ret < 0 && errno == EINTR);
263
264         if (ret < 0) {
265                 RTE_LOG(ERR, VHOST_CONFIG,  "sendmsg error\n");
266                 return ret;
267         }
268
269         return ret;
270 }
271
272 static int
273 send_vhost_message(int sockfd, struct VhostUserMsg *msg)
274 {
275         int ret;
276
277         if (!msg)
278                 return 0;
279
280         msg->flags &= ~VHOST_USER_VERSION_MASK;
281         msg->flags |= VHOST_USER_VERSION;
282         msg->flags |= VHOST_USER_REPLY_MASK;
283
284         ret = send_fd_message(sockfd, (char *)msg,
285                 VHOST_USER_HDR_SIZE + msg->size, NULL, 0);
286
287         return ret;
288 }
289
290 /* call back when there is new vhost-user connection.  */
291 static void
292 vhost_user_new_connection(int fd, void *dat, int *remove __rte_unused)
293 {
294         struct vhost_user_socket *vsocket = dat;
295         int conn_fd;
296         struct vhost_user_connection *conn;
297         int vid;
298         unsigned int size;
299
300         conn_fd = accept(fd, NULL, NULL);
301         RTE_LOG(INFO, VHOST_CONFIG,
302                 "new virtio connection is %d\n", conn_fd);
303         if (conn_fd < 0)
304                 return;
305
306         conn = calloc(1, sizeof(*conn));
307         if (conn == NULL) {
308                 close(conn_fd);
309                 return;
310         }
311
312         vid = vhost_new_device();
313         if (vid == -1) {
314                 free(conn);
315                 close(conn_fd);
316                 return;
317         }
318
319         size = strnlen(vsocket->path, PATH_MAX);
320         vhost_set_ifname(vid, vsocket->path, size);
321
322         RTE_LOG(INFO, VHOST_CONFIG, "new device, handle is %d\n", vid);
323
324         conn->vsocket = vsocket;
325         conn->vid = vid;
326         fdset_add(&vhost_user.fdset,
327                 conn_fd, vhost_user_msg_handler, NULL, conn);
328 }
329
330 /* callback when there is message on the connfd */
331 static void
332 vhost_user_msg_handler(int connfd, void *dat, int *remove)
333 {
334         int vid;
335         struct vhost_user_connection *conn = dat;
336         struct VhostUserMsg msg;
337         uint64_t features;
338         int ret;
339
340         vid = conn->vid;
341         ret = read_vhost_message(connfd, &msg);
342         if (ret <= 0 || msg.request >= VHOST_USER_MAX) {
343                 if (ret < 0)
344                         RTE_LOG(ERR, VHOST_CONFIG,
345                                 "vhost read message failed\n");
346                 else if (ret == 0)
347                         RTE_LOG(INFO, VHOST_CONFIG,
348                                 "vhost peer closed\n");
349                 else
350                         RTE_LOG(ERR, VHOST_CONFIG,
351                                 "vhost read incorrect message\n");
352
353                 close(connfd);
354                 *remove = 1;
355                 free(conn);
356                 vhost_destroy_device(vid);
357
358                 return;
359         }
360
361         RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
362                 vhost_message_str[msg.request]);
363         switch (msg.request) {
364         case VHOST_USER_GET_FEATURES:
365                 ret = vhost_get_features(vid, &features);
366                 msg.payload.u64 = features;
367                 msg.size = sizeof(msg.payload.u64);
368                 send_vhost_message(connfd, &msg);
369                 break;
370         case VHOST_USER_SET_FEATURES:
371                 features = msg.payload.u64;
372                 vhost_set_features(vid, &features);
373                 break;
374
375         case VHOST_USER_GET_PROTOCOL_FEATURES:
376                 msg.payload.u64 = VHOST_USER_PROTOCOL_FEATURES;
377                 msg.size = sizeof(msg.payload.u64);
378                 send_vhost_message(connfd, &msg);
379                 break;
380         case VHOST_USER_SET_PROTOCOL_FEATURES:
381                 user_set_protocol_features(vid, msg.payload.u64);
382                 break;
383
384         case VHOST_USER_SET_OWNER:
385                 vhost_set_owner(vid);
386                 break;
387         case VHOST_USER_RESET_OWNER:
388                 vhost_reset_owner(vid);
389                 break;
390
391         case VHOST_USER_SET_MEM_TABLE:
392                 user_set_mem_table(vid, &msg);
393                 break;
394
395         case VHOST_USER_SET_LOG_BASE:
396                 user_set_log_base(vid, &msg);
397
398                 /* it needs a reply */
399                 msg.size = sizeof(msg.payload.u64);
400                 send_vhost_message(connfd, &msg);
401                 break;
402         case VHOST_USER_SET_LOG_FD:
403                 close(msg.fds[0]);
404                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
405                 break;
406
407         case VHOST_USER_SET_VRING_NUM:
408                 vhost_set_vring_num(vid, &msg.payload.state);
409                 break;
410         case VHOST_USER_SET_VRING_ADDR:
411                 vhost_set_vring_addr(vid, &msg.payload.addr);
412                 break;
413         case VHOST_USER_SET_VRING_BASE:
414                 vhost_set_vring_base(vid, &msg.payload.state);
415                 break;
416
417         case VHOST_USER_GET_VRING_BASE:
418                 ret = user_get_vring_base(vid, &msg.payload.state);
419                 msg.size = sizeof(msg.payload.state);
420                 send_vhost_message(connfd, &msg);
421                 break;
422
423         case VHOST_USER_SET_VRING_KICK:
424                 user_set_vring_kick(vid, &msg);
425                 break;
426         case VHOST_USER_SET_VRING_CALL:
427                 user_set_vring_call(vid, &msg);
428                 break;
429
430         case VHOST_USER_SET_VRING_ERR:
431                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
432                         close(msg.fds[0]);
433                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
434                 break;
435
436         case VHOST_USER_GET_QUEUE_NUM:
437                 msg.payload.u64 = VHOST_MAX_QUEUE_PAIRS;
438                 msg.size = sizeof(msg.payload.u64);
439                 send_vhost_message(connfd, &msg);
440                 break;
441
442         case VHOST_USER_SET_VRING_ENABLE:
443                 user_set_vring_enable(vid, &msg.payload.state);
444                 break;
445         case VHOST_USER_SEND_RARP:
446                 user_send_rarp(vid, &msg);
447                 break;
448
449         default:
450                 break;
451
452         }
453 }
454
455 /**
456  * Creates and initialise the vhost server.
457  */
458 int
459 rte_vhost_driver_register(const char *path)
460 {
461         struct vhost_user_socket *vsocket;
462
463         pthread_mutex_lock(&vhost_user.mutex);
464
465         if (vhost_user.vsocket_cnt == MAX_VHOST_SOCKET) {
466                 RTE_LOG(ERR, VHOST_CONFIG,
467                         "error: the number of servers reaches maximum\n");
468                 pthread_mutex_unlock(&vhost_user.mutex);
469                 return -1;
470         }
471
472         vsocket = calloc(sizeof(struct vhost_user_socket), 1);
473         if (vsocket == NULL) {
474                 pthread_mutex_unlock(&vhost_user.mutex);
475                 return -1;
476         }
477
478         vsocket->listenfd = uds_socket(path);
479         if (vsocket->listenfd < 0) {
480                 free(vsocket);
481                 pthread_mutex_unlock(&vhost_user.mutex);
482                 return -1;
483         }
484
485         vsocket->path = strdup(path);
486
487         fdset_add(&vhost_user.fdset, vsocket->listenfd,
488                 vhost_user_new_connection, NULL, vsocket);
489
490         vhost_user.vsockets[vhost_user.vsocket_cnt++] = vsocket;
491         pthread_mutex_unlock(&vhost_user.mutex);
492
493         return 0;
494 }
495
496
497 /**
498  * Unregister the specified vhost server
499  */
500 int
501 rte_vhost_driver_unregister(const char *path)
502 {
503         int i;
504         int count;
505
506         pthread_mutex_lock(&vhost_user.mutex);
507
508         for (i = 0; i < vhost_user.vsocket_cnt; i++) {
509                 if (!strcmp(vhost_user.vsockets[i]->path, path)) {
510                         fdset_del(&vhost_user.fdset,
511                                 vhost_user.vsockets[i]->listenfd);
512
513                         close(vhost_user.vsockets[i]->listenfd);
514                         free(vhost_user.vsockets[i]->path);
515                         free(vhost_user.vsockets[i]);
516
517                         unlink(path);
518
519                         count = --vhost_user.vsocket_cnt;
520                         vhost_user.vsockets[i] = vhost_user.vsockets[count];
521                         vhost_user.vsockets[count] = NULL;
522                         pthread_mutex_unlock(&vhost_user.mutex);
523
524                         return 0;
525                 }
526         }
527         pthread_mutex_unlock(&vhost_user.mutex);
528
529         return -1;
530 }
531
532 int
533 rte_vhost_driver_session_start(void)
534 {
535         fdset_event_dispatch(&vhost_user.fdset);
536         return 0;
537 }