573ef07f9b8dde4b0349bc2acd7fcf9a11089d65
[dpdk.git] / drivers / net / virtio / virtio_user / vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2016 Intel Corporation
3  */
4
5 #include <sys/socket.h>
6 #include <sys/types.h>
7 #include <sys/stat.h>
8 #include <unistd.h>
9 #include <fcntl.h>
10 #include <sys/un.h>
11 #include <string.h>
12 #include <errno.h>
13
14 #include "vhost.h"
15 #include "virtio_user_dev.h"
16
17 /* The version of the vhost-user protocol we support. It is written into the
 * flags word of every request this file sends and is expected (together with
 * the reply bit) in the flags word of every reply it receives.
 */
18 #define VHOST_USER_VERSION    0x1
19
/* Upper bound on the number of memory regions described by one memory table;
 * the same bound sizes the file descriptor arrays used in this file.
 */
20 #define VHOST_MEMORY_MAX_NREGIONS 8
/* Payload of a SET_MEM_TABLE request: a region count followed by a
 * fixed-capacity array of region descriptors. Only the first nregions
 * entries are filled in and accounted for in the message size.
 */
21 struct vhost_memory {
22         uint32_t nregions; /* number of valid entries in regions[] */
23         uint32_t padding;  /* always written as 0 by this file */
24         struct vhost_memory_region regions[VHOST_MEMORY_MAX_NREGIONS];
25 };
26
/* One vhost-user message: a fixed header (request, flags, size) followed by
 * a request-specific payload. The struct is declared packed, so members are
 * laid out back to back without compiler-inserted padding.
 */
27 struct vhost_user_msg {
28         enum vhost_user_request request;
29
/* Masks for the flags word. Requests are sent with flags equal to
 * VHOST_USER_VERSION; replies are accepted only when flags equal
 * VHOST_USER_REPLY_MASK | VHOST_USER_VERSION.
 * NOTE(review): VHOST_USER_VERSION_MASK is not referenced in the visible code.
 */
30 #define VHOST_USER_VERSION_MASK     0x3
31 #define VHOST_USER_REPLY_MASK       (0x1 << 2)
32         uint32_t flags;
33         uint32_t size; /* the following payload size */
34         union {
/* Layout of payload.u64 for the VRING_KICK/CALL/ERR requests: the vring index
 * is masked with VHOST_USER_VRING_IDX_MASK, and VHOST_USER_VRING_NOFD_MASK is
 * set when no file descriptor accompanies the message.
 */
35 #define VHOST_USER_VRING_IDX_MASK   0xff
36 #define VHOST_USER_VRING_NOFD_MASK  (0x1 << 8)
37                 uint64_t u64;
38                 struct vhost_vring_state state;
39                 struct vhost_vring_addr addr;
40                 struct vhost_memory memory;
41         } payload;
/* NOTE(review): this member is not accessed by the visible code; descriptors
 * are handed to vhost_user_write() through a separate array.
 */
42         int fds[VHOST_MEMORY_MAX_NREGIONS];
43 } __attribute((packed));
44
/* Number of bytes that precede the payload in struct vhost_user_msg, i.e. the
 * size of the fixed header (request + flags + size).
 */
45 #define VHOST_USER_HDR_SIZE offsetof(struct vhost_user_msg, payload.u64)
/* Number of bytes in struct vhost_user_msg after the header (the payload
 * union plus the trailing fds[] member).
 * NOTE(review): not referenced in the visible code.
 */
46 #define VHOST_USER_PAYLOAD_SIZE \
47         (sizeof(struct vhost_user_msg) - VHOST_USER_HDR_SIZE)
48
49 static int
50 vhost_user_write(int fd, void *buf, int len, int *fds, int fd_num)
51 {
52         int r;
53         struct msghdr msgh;
54         struct iovec iov;
55         size_t fd_size = fd_num * sizeof(int);
56         char control[CMSG_SPACE(fd_size)];
57         struct cmsghdr *cmsg;
58
59         memset(&msgh, 0, sizeof(msgh));
60         memset(control, 0, sizeof(control));
61
62         iov.iov_base = (uint8_t *)buf;
63         iov.iov_len = len;
64
65         msgh.msg_iov = &iov;
66         msgh.msg_iovlen = 1;
67         msgh.msg_control = control;
68         msgh.msg_controllen = sizeof(control);
69
70         cmsg = CMSG_FIRSTHDR(&msgh);
71         cmsg->cmsg_len = CMSG_LEN(fd_size);
72         cmsg->cmsg_level = SOL_SOCKET;
73         cmsg->cmsg_type = SCM_RIGHTS;
74         memcpy(CMSG_DATA(cmsg), fds, fd_size);
75
76         do {
77                 r = sendmsg(fd, &msgh, 0);
78         } while (r < 0 && errno == EINTR);
79
80         return r;
81 }
82
83 static int
84 vhost_user_read(int fd, struct vhost_user_msg *msg)
85 {
86         uint32_t valid_flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION;
87         int ret, sz_hdr = VHOST_USER_HDR_SIZE, sz_payload;
88
89         ret = recv(fd, (void *)msg, sz_hdr, 0);
90         if (ret < sz_hdr) {
91                 PMD_DRV_LOG(ERR, "Failed to recv msg hdr: %d instead of %d.",
92                             ret, sz_hdr);
93                 goto fail;
94         }
95
96         /* validate msg flags */
97         if (msg->flags != (valid_flags)) {
98                 PMD_DRV_LOG(ERR, "Failed to recv msg: flags %x instead of %x.",
99                             msg->flags, valid_flags);
100                 goto fail;
101         }
102
103         sz_payload = msg->size;
104
105         if ((size_t)sz_payload > sizeof(msg->payload))
106                 goto fail;
107
108         if (sz_payload) {
109                 ret = recv(fd, (void *)((char *)msg + sz_hdr), sz_payload, 0);
110                 if (ret < sz_payload) {
111                         PMD_DRV_LOG(ERR,
112                                 "Failed to recv msg payload: %d instead of %d.",
113                                 ret, msg->size);
114                         goto fail;
115                 }
116         }
117
118         return 0;
119
120 fail:
121         return -1;
122 }
123
/* Description of one hugepage backing file mapped into this process, as
 * filled in by get_hugepage_file_info().
 */
124 struct hugepage_file_info {
125         uint64_t addr;            /**< start virtual address of the first mapping found for the file */
126         size_t   size;            /**< the file size from stat(); length of that first mapping if stat() fails */
127         char     path[PATH_MAX];  /**< path to backing file, as listed in /proc/self/maps */
128 };
129
130 /* Two possible options:
131  * 1. Match HUGEPAGE_INFO_FMT to find the file storing struct hugepage_file
132  * array. This is simple but cannot be used in secondary process because
133  * secondary process will close and munmap that file.
134  * 2. Match HUGEFILE_FMT to find hugepage files directly.
135  *
136  * We choose option 2.
137  */
138 static int
139 get_hugepage_file_info(struct hugepage_file_info huges[], int max)
140 {
141         int idx, k, exist;
142         FILE *f;
143         char buf[BUFSIZ], *tmp, *tail;
144         char *str_underline, *str_start;
145         int huge_index;
146         uint64_t v_start, v_end;
147         struct stat stats;
148
149         f = fopen("/proc/self/maps", "r");
150         if (!f) {
151                 PMD_DRV_LOG(ERR, "cannot open /proc/self/maps");
152                 return -1;
153         }
154
155         idx = 0;
156         while (fgets(buf, sizeof(buf), f) != NULL) {
157                 if (sscanf(buf, "%" PRIx64 "-%" PRIx64, &v_start, &v_end) < 2) {
158                         PMD_DRV_LOG(ERR, "Failed to parse address");
159                         goto error;
160                 }
161
162                 tmp = strchr(buf, ' ') + 1; /** skip address */
163                 tmp = strchr(tmp, ' ') + 1; /** skip perm */
164                 tmp = strchr(tmp, ' ') + 1; /** skip offset */
165                 tmp = strchr(tmp, ' ') + 1; /** skip dev */
166                 tmp = strchr(tmp, ' ') + 1; /** skip inode */
167                 while (*tmp == ' ')         /** skip spaces */
168                         tmp++;
169                 tail = strrchr(tmp, '\n');  /** remove newline if exists */
170                 if (tail)
171                         *tail = '\0';
172
173                 /* Match HUGEFILE_FMT, aka "%s/%smap_%d",
174                  * which is defined in eal_filesystem.h
175                  */
176                 str_underline = strrchr(tmp, '_');
177                 if (!str_underline)
178                         continue;
179
180                 str_start = str_underline - strlen("map");
181                 if (str_start < tmp)
182                         continue;
183
184                 if (sscanf(str_start, "map_%d", &huge_index) != 1)
185                         continue;
186
187                 /* skip duplicated file which is mapped to different regions */
188                 for (k = 0, exist = -1; k < idx; ++k) {
189                         if (!strcmp(huges[k].path, tmp)) {
190                                 exist = k;
191                                 break;
192                         }
193                 }
194                 if (exist >= 0)
195                         continue;
196
197                 if (idx >= max) {
198                         PMD_DRV_LOG(ERR, "Exceed maximum of %d", max);
199                         goto error;
200                 }
201
202                 huges[idx].addr = v_start;
203                 huges[idx].size = v_end - v_start; /* To be corrected later */
204                 snprintf(huges[idx].path, PATH_MAX, "%s", tmp);
205                 idx++;
206         }
207
208         /* correct the size for files who have many regions */
209         for (k = 0; k < idx; ++k) {
210                 if (stat(huges[k].path, &stats) < 0) {
211                         PMD_DRV_LOG(ERR, "Failed to stat %s, %s\n",
212                                     huges[k].path, strerror(errno));
213                         continue;
214                 }
215                 huges[k].size = stats.st_size;
216                 PMD_DRV_LOG(INFO, "file %s, size %zx\n",
217                             huges[k].path, huges[k].size);
218         }
219
220         fclose(f);
221         return idx;
222
223 error:
224         fclose(f);
225         return -1;
226 }
227
228 static int
229 prepare_vhost_memory_user(struct vhost_user_msg *msg, int fds[])
230 {
231         int i, num;
232         struct hugepage_file_info huges[VHOST_MEMORY_MAX_NREGIONS];
233         struct vhost_memory_region *mr;
234
235         num = get_hugepage_file_info(huges, VHOST_MEMORY_MAX_NREGIONS);
236         if (num < 0) {
237                 PMD_INIT_LOG(ERR, "Failed to prepare memory for vhost-user");
238                 return -1;
239         }
240
241         for (i = 0; i < num; ++i) {
242                 mr = &msg->payload.memory.regions[i];
243                 mr->guest_phys_addr = huges[i].addr; /* use vaddr! */
244                 mr->userspace_addr = huges[i].addr;
245                 mr->memory_size = huges[i].size;
246                 mr->mmap_offset = 0;
247                 fds[i] = open(huges[i].path, O_RDWR);
248         }
249
250         msg->payload.memory.nregions = num;
251         msg->payload.memory.padding = 0;
252
253         return 0;
254 }
255
/* This instance never carries message data: the code in this file only
 * applies sizeof() to its members when computing payload sizes.
 */
256 static struct vhost_user_msg m;
257
258 const char * const vhost_msg_strings[] = {
259         [VHOST_USER_SET_OWNER] = "VHOST_SET_OWNER",
260         [VHOST_USER_RESET_OWNER] = "VHOST_RESET_OWNER",
261         [VHOST_USER_SET_FEATURES] = "VHOST_SET_FEATURES",
262         [VHOST_USER_GET_FEATURES] = "VHOST_GET_FEATURES",
263         [VHOST_USER_SET_VRING_CALL] = "VHOST_SET_VRING_CALL",
264         [VHOST_USER_SET_VRING_NUM] = "VHOST_SET_VRING_NUM",
265         [VHOST_USER_SET_VRING_BASE] = "VHOST_SET_VRING_BASE",
266         [VHOST_USER_GET_VRING_BASE] = "VHOST_GET_VRING_BASE",
267         [VHOST_USER_SET_VRING_ADDR] = "VHOST_SET_VRING_ADDR",
268         [VHOST_USER_SET_VRING_KICK] = "VHOST_SET_VRING_KICK",
269         [VHOST_USER_SET_MEM_TABLE] = "VHOST_SET_MEM_TABLE",
270         [VHOST_USER_SET_VRING_ENABLE] = "VHOST_SET_VRING_ENABLE",
271 };
272
273 static int
274 vhost_user_sock(struct virtio_user_dev *dev,
275                 enum vhost_user_request req,
276                 void *arg)
277 {
278         struct vhost_user_msg msg;
279         struct vhost_vring_file *file = 0;
280         int need_reply = 0;
281         int fds[VHOST_MEMORY_MAX_NREGIONS];
282         int fd_num = 0;
283         int i, len;
284         int vhostfd = dev->vhostfd;
285
286         RTE_SET_USED(m);
287
288         PMD_DRV_LOG(INFO, "%s", vhost_msg_strings[req]);
289
290         msg.request = req;
291         msg.flags = VHOST_USER_VERSION;
292         msg.size = 0;
293
294         switch (req) {
295         case VHOST_USER_GET_FEATURES:
296                 need_reply = 1;
297                 break;
298
299         case VHOST_USER_SET_FEATURES:
300         case VHOST_USER_SET_LOG_BASE:
301                 msg.payload.u64 = *((__u64 *)arg);
302                 msg.size = sizeof(m.payload.u64);
303                 break;
304
305         case VHOST_USER_SET_OWNER:
306         case VHOST_USER_RESET_OWNER:
307                 break;
308
309         case VHOST_USER_SET_MEM_TABLE:
310                 if (prepare_vhost_memory_user(&msg, fds) < 0)
311                         return -1;
312                 fd_num = msg.payload.memory.nregions;
313                 msg.size = sizeof(m.payload.memory.nregions);
314                 msg.size += sizeof(m.payload.memory.padding);
315                 msg.size += fd_num * sizeof(struct vhost_memory_region);
316                 break;
317
318         case VHOST_USER_SET_LOG_FD:
319                 fds[fd_num++] = *((int *)arg);
320                 break;
321
322         case VHOST_USER_SET_VRING_NUM:
323         case VHOST_USER_SET_VRING_BASE:
324         case VHOST_USER_SET_VRING_ENABLE:
325                 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state));
326                 msg.size = sizeof(m.payload.state);
327                 break;
328
329         case VHOST_USER_GET_VRING_BASE:
330                 memcpy(&msg.payload.state, arg, sizeof(msg.payload.state));
331                 msg.size = sizeof(m.payload.state);
332                 need_reply = 1;
333                 break;
334
335         case VHOST_USER_SET_VRING_ADDR:
336                 memcpy(&msg.payload.addr, arg, sizeof(msg.payload.addr));
337                 msg.size = sizeof(m.payload.addr);
338                 break;
339
340         case VHOST_USER_SET_VRING_KICK:
341         case VHOST_USER_SET_VRING_CALL:
342         case VHOST_USER_SET_VRING_ERR:
343                 file = arg;
344                 msg.payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK;
345                 msg.size = sizeof(m.payload.u64);
346                 if (file->fd > 0)
347                         fds[fd_num++] = file->fd;
348                 else
349                         msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK;
350                 break;
351
352         default:
353                 PMD_DRV_LOG(ERR, "trying to send unhandled msg type");
354                 return -1;
355         }
356
357         len = VHOST_USER_HDR_SIZE + msg.size;
358         if (vhost_user_write(vhostfd, &msg, len, fds, fd_num) < 0) {
359                 PMD_DRV_LOG(ERR, "%s failed: %s",
360                             vhost_msg_strings[req], strerror(errno));
361                 return -1;
362         }
363
364         if (req == VHOST_USER_SET_MEM_TABLE)
365                 for (i = 0; i < fd_num; ++i)
366                         close(fds[i]);
367
368         if (need_reply) {
369                 if (vhost_user_read(vhostfd, &msg) < 0) {
370                         PMD_DRV_LOG(ERR, "Received msg failed: %s",
371                                     strerror(errno));
372                         return -1;
373                 }
374
375                 if (req != msg.request) {
376                         PMD_DRV_LOG(ERR, "Received unexpected msg type");
377                         return -1;
378                 }
379
380                 switch (req) {
381                 case VHOST_USER_GET_FEATURES:
382                         if (msg.size != sizeof(m.payload.u64)) {
383                                 PMD_DRV_LOG(ERR, "Received bad msg size");
384                                 return -1;
385                         }
386                         *((__u64 *)arg) = msg.payload.u64;
387                         break;
388                 case VHOST_USER_GET_VRING_BASE:
389                         if (msg.size != sizeof(m.payload.state)) {
390                                 PMD_DRV_LOG(ERR, "Received bad msg size");
391                                 return -1;
392                         }
393                         memcpy(arg, &msg.payload.state,
394                                sizeof(struct vhost_vring_state));
395                         break;
396                 default:
397                         PMD_DRV_LOG(ERR, "Received unexpected msg type");
398                         return -1;
399                 }
400         }
401
402         return 0;
403 }
404
405 #define MAX_VIRTIO_USER_BACKLOG 1
406 static int
407 virtio_user_start_server(struct virtio_user_dev *dev, struct sockaddr_un *un)
408 {
409         int ret;
410         int flag;
411         int fd = dev->listenfd;
412
413         ret = bind(fd, (struct sockaddr *)un, sizeof(*un));
414         if (ret < 0) {
415                 PMD_DRV_LOG(ERR, "failed to bind to %s: %s; remove it and try again\n",
416                             dev->path, strerror(errno));
417                 return -1;
418         }
419         ret = listen(fd, MAX_VIRTIO_USER_BACKLOG);
420         if (ret < 0)
421                 return -1;
422
423         flag = fcntl(fd, F_GETFL);
424         fcntl(fd, F_SETFL, flag | O_NONBLOCK);
425
426         return 0;
427 }
428
429 /**
430  * Set up environment to talk with a vhost user backend.
431  *
432  * @return
433  *   - (-1) if fail;
434  *   - (0) if succeed.
435  */
436 static int
437 vhost_user_setup(struct virtio_user_dev *dev)
438 {
439         int fd;
440         int flag;
441         struct sockaddr_un un;
442
443         fd = socket(AF_UNIX, SOCK_STREAM, 0);
444         if (fd < 0) {
445                 PMD_DRV_LOG(ERR, "socket() error, %s", strerror(errno));
446                 return -1;
447         }
448
449         flag = fcntl(fd, F_GETFD);
450         if (fcntl(fd, F_SETFD, flag | FD_CLOEXEC) < 0)
451                 PMD_DRV_LOG(WARNING, "fcntl failed, %s", strerror(errno));
452
453         memset(&un, 0, sizeof(un));
454         un.sun_family = AF_UNIX;
455         snprintf(un.sun_path, sizeof(un.sun_path), "%s", dev->path);
456
457         if (dev->is_server) {
458                 dev->listenfd = fd;
459                 if (virtio_user_start_server(dev, &un) < 0) {
460                         PMD_DRV_LOG(ERR, "virtio-user startup fails in server mode");
461                         close(fd);
462                         return -1;
463                 }
464                 dev->vhostfd = -1;
465         } else {
466                 if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) {
467                         PMD_DRV_LOG(ERR, "connect error, %s", strerror(errno));
468                         close(fd);
469                         return -1;
470                 }
471                 dev->vhostfd = fd;
472         }
473
474         return 0;
475 }
476
477 static int
478 vhost_user_enable_queue_pair(struct virtio_user_dev *dev,
479                              uint16_t pair_idx,
480                              int enable)
481 {
482         int i;
483
484         for (i = 0; i < 2; ++i) {
485                 struct vhost_vring_state state = {
486                         .index = pair_idx * 2 + i,
487                         .num   = enable,
488                 };
489
490                 if (vhost_user_sock(dev, VHOST_USER_SET_VRING_ENABLE, &state))
491                         return -1;
492         }
493
494         return 0;
495 }
496
497 struct virtio_user_backend_ops ops_user = {
498         .setup = vhost_user_setup,
499         .send_request = vhost_user_sock,
500         .enable_qp = vhost_user_enable_queue_pair
501 };