vhost: fix potential null pointer dereference
[dpdk.git] lib/librte_vhost/vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped will then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
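/* Return the block size of the file behind 'fd' (the hugepage size for
 * hugetlbfs-backed guest memory), or (uint64_t)-1 if fstat() fails.
 */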
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * Currently a no-op: taking ownership requires nothing beyond
127  * returning success.
128  */
129 static int
130 vhost_user_set_owner(void)
131 {
132         return 0;
133 }
134
135 static int
136 vhost_user_reset_owner(struct virtio_net *dev)
137 {
138         vhost_destroy_device_notify(dev);
139
140         cleanup_device(dev, 0);
141         reset_device(dev);
142         return 0;
143 }
144
145 /*
146  * The master requests the set of features we support.
147  */
148 static uint64_t
149 vhost_user_get_features(struct virtio_net *dev)
150 {
151         uint64_t features = 0;
152
153         rte_vhost_driver_get_features(dev->ifname, &features);
154         return features;
155 }
156
157 /*
158  * The master requests the number of queues we support.
159  */
160 static uint32_t
161 vhost_user_get_queue_num(struct virtio_net *dev)
162 {
163         uint32_t queue_num = 0;
164
165         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
166         return queue_num;
167 }
168
169 /*
170  * We receive the set of features negotiated between us and the virtio device.
171  */
172 static int
173 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
174 {
175         uint64_t vhost_features = 0;
176         struct rte_vdpa_device *vdpa_dev;
177         int did = -1;
178
179         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
180         if (features & ~vhost_features) {
181                 RTE_LOG(ERR, VHOST_CONFIG,
182                         "(%d) received invalid negotiated features.\n",
183                         dev->vid);
184                 return -1;
185         }
186
187         if (dev->flags & VIRTIO_DEV_RUNNING) {
188                 if (dev->features == features)
189                         return 0;
190
191                 /*
192                  * Error out if master tries to change features while device is
193                  * in running state. The exception being VHOST_F_LOG_ALL, which
194                  * is enabled when the live-migration starts.
195                  */
196                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
197                         RTE_LOG(ERR, VHOST_CONFIG,
198                                 "(%d) features changed while device is running.\n",
199                                 dev->vid);
200                         return -1;
201                 }
202
203                 if (dev->notify_ops->features_changed)
204                         dev->notify_ops->features_changed(dev->vid, features);
205         }
206
207         dev->features = features;
208         if (dev->features &
209                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
210                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
211         } else {
212                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
213         }
214         VHOST_LOG_DEBUG(VHOST_CONFIG,
215                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
216                 dev->vid,
217                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
218                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
219
220         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
221             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
222                 /*
223                  * Remove all but first queue pair if MQ hasn't been
224                  * negotiated. This is safe because the device is not
225                  * running at this stage.
226                  */
227                 while (dev->nr_vring > 2) {
228                         struct vhost_virtqueue *vq;
229
230                         vq = dev->virtqueue[--dev->nr_vring];
231                         if (!vq)
232                                 continue;
233
234                         dev->virtqueue[dev->nr_vring] = NULL;
235                         cleanup_vq(vq, 1);
236                         free_vq(vq);
237                 }
238         }
239
240         did = dev->vdpa_dev_id;
241         vdpa_dev = rte_vdpa_get_device(did);
242         if (vdpa_dev && vdpa_dev->ops->set_features)
243                 vdpa_dev->ops->set_features(dev->vid);
244
245         return 0;
246 }
247
248 /*
249  * The virtio device sends us the size of the descriptor ring.
250  */
251 static int
252 vhost_user_set_vring_num(struct virtio_net *dev,
253                          VhostUserMsg *msg)
254 {
255         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
256
257         vq->size = msg->payload.state.num;
258
259         /* VIRTIO 1.0, 2.4 Virtqueues says:
260          *
261          *   Queue Size value is always a power of 2. The maximum Queue Size
262          *   value is 32768.
263          */
264         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
265                 RTE_LOG(ERR, VHOST_CONFIG,
266                         "invalid virtqueue size %u\n", vq->size);
267                 return -1;
268         }
269
270         if (dev->dequeue_zero_copy) {
271                 vq->nr_zmbuf = 0;
272                 vq->last_zmbuf_idx = 0;
273                 vq->zmbuf_size = vq->size;
274                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
275                                          sizeof(struct zcopy_mbuf), 0);
276                 if (vq->zmbufs == NULL) {
277                         RTE_LOG(WARNING, VHOST_CONFIG,
278                                 "failed to allocate mem for zero copy; "
279                                 "zero copy is force disabled\n");
280                         dev->dequeue_zero_copy = 0;
281                 }
282                 TAILQ_INIT(&vq->zmbuf_list);
283         }
284
285         vq->shadow_used_ring = rte_malloc(NULL,
286                                 vq->size * sizeof(struct vring_used_elem),
287                                 RTE_CACHE_LINE_SIZE);
288         if (!vq->shadow_used_ring) {
289                 RTE_LOG(ERR, VHOST_CONFIG,
290                         "failed to allocate memory for shadow used ring.\n");
291                 return -1;
292         }
293
294         vq->batch_copy_elems = rte_malloc(NULL,
295                                 vq->size * sizeof(struct batch_copy_elem),
296                                 RTE_CACHE_LINE_SIZE);
297         if (!vq->batch_copy_elems) {
298                 RTE_LOG(ERR, VHOST_CONFIG,
299                         "failed to allocate memory for batching copy.\n");
300                 return -1;
301         }
302
303         return 0;
304 }
305
306 /*
307  * Reallocate the virtio_net and vhost_virtqueue structures so they reside on
308  * the same NUMA node as the memory backing the vring descriptors.
309  */
310 #ifdef RTE_LIBRTE_VHOST_NUMA
311 static struct virtio_net*
312 numa_realloc(struct virtio_net *dev, int index)
313 {
314         int oldnode, newnode;
315         struct virtio_net *old_dev;
316         struct vhost_virtqueue *old_vq, *vq;
317         struct zcopy_mbuf *new_zmbuf;
318         struct vring_used_elem *new_shadow_used_ring;
319         struct batch_copy_elem *new_batch_copy_elems;
320         int ret;
321
322         old_dev = dev;
323         vq = old_vq = dev->virtqueue[index];
324
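        /* Query which NUMA node the guest's descriptor ring memory lives on. */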
325         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
326                             MPOL_F_NODE | MPOL_F_ADDR);
327
328         /* check if we need to reallocate vq */
329         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
330                              MPOL_F_NODE | MPOL_F_ADDR);
331         if (ret) {
332                 RTE_LOG(ERR, VHOST_CONFIG,
333                         "Unable to get vq numa information.\n");
334                 return dev;
335         }
336         if (oldnode != newnode) {
337                 RTE_LOG(INFO, VHOST_CONFIG,
338                         "reallocate vq from %d to %d node\n", oldnode, newnode);
339                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
340                 if (!vq)
341                         return dev;
342
343                 memcpy(vq, old_vq, sizeof(*vq));
344                 TAILQ_INIT(&vq->zmbuf_list);
345
346                 new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
347                         sizeof(struct zcopy_mbuf), 0, newnode);
348                 if (new_zmbuf) {
349                         rte_free(vq->zmbufs);
350                         vq->zmbufs = new_zmbuf;
351                 }
352
353                 new_shadow_used_ring = rte_malloc_socket(NULL,
354                         vq->size * sizeof(struct vring_used_elem),
355                         RTE_CACHE_LINE_SIZE,
356                         newnode);
357                 if (new_shadow_used_ring) {
358                         rte_free(vq->shadow_used_ring);
359                         vq->shadow_used_ring = new_shadow_used_ring;
360                 }
361
362                 new_batch_copy_elems = rte_malloc_socket(NULL,
363                         vq->size * sizeof(struct batch_copy_elem),
364                         RTE_CACHE_LINE_SIZE,
365                         newnode);
366                 if (new_batch_copy_elems) {
367                         rte_free(vq->batch_copy_elems);
368                         vq->batch_copy_elems = new_batch_copy_elems;
369                 }
370
371                 rte_free(old_vq);
372         }
373
374         /* check if we need to reallocate dev */
375         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
376                             MPOL_F_NODE | MPOL_F_ADDR);
377         if (ret) {
378                 RTE_LOG(ERR, VHOST_CONFIG,
379                         "Unable to get dev numa information.\n");
380                 goto out;
381         }
382         if (oldnode != newnode) {
383                 RTE_LOG(INFO, VHOST_CONFIG,
384                         "reallocate dev from %d to %d node\n",
385                         oldnode, newnode);
386                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
387                 if (!dev) {
388                         dev = old_dev;
389                         goto out;
390                 }
391
392                 memcpy(dev, old_dev, sizeof(*dev));
393                 rte_free(old_dev);
394         }
395
396 out:
397         dev->virtqueue[index] = vq;
398         vhost_devices[dev->vid] = dev;
399
400         if (old_vq != vq)
401                 vhost_user_iotlb_init(dev, index);
402
403         return dev;
404 }
405 #else
406 static struct virtio_net*
407 numa_realloc(struct virtio_net *dev, int index __rte_unused)
408 {
409         return dev;
410 }
411 #endif
412
413 /* Converts QEMU virtual address to Vhost virtual address. */
414 static uint64_t
415 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
416 {
417         struct rte_vhost_mem_region *r;
418         uint32_t i;
419
420         /* Find the region where the address lives. */
421         for (i = 0; i < dev->mem->nregions; i++) {
422                 r = &dev->mem->regions[i];
423
424                 if (qva >= r->guest_user_addr &&
425                     qva <  r->guest_user_addr + r->size) {
426
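                        /* Clamp *len so the returned range does not cross
                         * the end of this region.
                         */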
427                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
428                                 *len = r->guest_user_addr + r->size - qva;
429
430                         return qva - r->guest_user_addr +
431                                r->host_user_addr;
432                 }
433         }
434         *len = 0;
435
436         return 0;
437 }
438
439
440 /*
441  * Converts ring address to Vhost virtual address.
442  * If IOMMU is enabled, the ring address is a guest IO virtual address,
443  * else it is a QEMU virtual address.
444  */
445 static uint64_t
446 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
447                 uint64_t ra, uint64_t *size)
448 {
449         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
450                 uint64_t vva;
451
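                /* On an IOTLB cache miss, request the translation from the
                 * master and return 0; the ring is translated again once the
                 * IOTLB update arrives.
                 */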
452                 vva = vhost_user_iotlb_cache_find(vq, ra,
453                                         size, VHOST_ACCESS_RW);
454                 if (!vva)
455                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
456
457                 return vva;
458         }
459
460         return qva_to_vva(dev, ra, size);
461 }
462
463 static struct virtio_net *
464 translate_ring_addresses(struct virtio_net *dev, int vq_index)
465 {
466         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
467         struct vhost_vring_addr *addr = &vq->ring_addrs;
468         uint64_t len;
469
470         /* The addresses are converted from QEMU virtual to Vhost virtual. */
471         if (vq->desc && vq->avail && vq->used)
472                 return dev;
473
474         len = sizeof(struct vring_desc) * vq->size;
475         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
476                         vq, addr->desc_user_addr, &len);
477         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
478                 RTE_LOG(DEBUG, VHOST_CONFIG,
479                         "(%d) failed to map desc ring.\n",
480                         dev->vid);
481                 return dev;
482         }
483
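        /* numa_realloc() may move both the device and the virtqueue to
         * another NUMA node, so re-fetch the pointers before translating
         * the remaining rings.
         */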
484         dev = numa_realloc(dev, vq_index);
485         vq = dev->virtqueue[vq_index];
486         addr = &vq->ring_addrs;
487
488         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
489         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
490                         vq, addr->avail_user_addr, &len);
491         if (vq->avail == 0 ||
492                         len != sizeof(struct vring_avail) +
493                         sizeof(uint16_t) * vq->size) {
494                 RTE_LOG(DEBUG, VHOST_CONFIG,
495                         "(%d) failed to map avail ring.\n",
496                         dev->vid);
497                 return dev;
498         }
499
500         len = sizeof(struct vring_used) +
501                 sizeof(struct vring_used_elem) * vq->size;
502         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
503                         vq, addr->used_user_addr, &len);
504         if (vq->used == 0 || len != sizeof(struct vring_used) +
505                         sizeof(struct vring_used_elem) * vq->size) {
506                 RTE_LOG(DEBUG, VHOST_CONFIG,
507                         "(%d) failed to map used ring.\n",
508                         dev->vid);
509                 return dev;
510         }
511
512         if (vq->last_used_idx != vq->used->idx) {
513                 RTE_LOG(WARNING, VHOST_CONFIG,
514                         "last_used_idx (%u) and vq->used->idx (%u) mismatch; "
515                         "some packets may be resent for Tx and dropped for Rx\n",
516                         vq->last_used_idx, vq->used->idx);
517                 vq->last_used_idx  = vq->used->idx;
518                 vq->last_avail_idx = vq->used->idx;
519         }
520
521         vq->log_guest_addr = addr->log_guest_addr;
522
523         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
524                         dev->vid, vq->desc);
525         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
526                         dev->vid, vq->avail);
527         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
528                         dev->vid, vq->used);
529         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
530                         dev->vid, vq->log_guest_addr);
531
532         return dev;
533 }
534
535 /*
536  * The virtio device sends us the desc, used and avail ring addresses.
537  * This function then converts these to our address space.
538  */
539 static int
540 vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
541 {
542         struct vhost_virtqueue *vq;
543         struct vhost_vring_addr *addr = &msg->payload.addr;
544         struct virtio_net *dev = *pdev;
545
546         if (dev->mem == NULL)
547                 return -1;
548
549         /* addr->index refers to the queue index. The txq is 1, rxq is 0. */
550         vq = dev->virtqueue[msg->payload.addr.index];
551
552         /*
553          * Ring addresses should not be interpreted as long as the ring has
554          * not been started and enabled.
555          */
556         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
557
558         vring_invalidate(dev, vq);
559
560         if (vq->enabled && (dev->features &
561                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
562                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
563                 if (!dev)
564                         return -1;
565
566                 *pdev = dev;
567         }
568
569         return 0;
570 }
571
572 /*
573  * The virtio device sends us the available ring last used index.
574  */
575 static int
576 vhost_user_set_vring_base(struct virtio_net *dev,
577                           VhostUserMsg *msg)
578 {
579         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
580                         msg->payload.state.num;
581         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
582                         msg->payload.state.num;
583
584         return 0;
585 }
586
587 static int
588 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
589                    uint64_t host_phys_addr, uint64_t size)
590 {
591         struct guest_page *page, *last_page;
592
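        /* Grow the guest page array by doubling its capacity when it is full. */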
593         if (dev->nr_guest_pages == dev->max_guest_pages) {
594                 dev->max_guest_pages *= 2;
595                 dev->guest_pages = realloc(dev->guest_pages,
596                                         dev->max_guest_pages * sizeof(*page));
597                 if (!dev->guest_pages) {
598                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
599                         return -1;
600                 }
601         }
602
603         if (dev->nr_guest_pages > 0) {
604                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
605                 /* merge if the two pages are contiguous */
606                 if (host_phys_addr == last_page->host_phys_addr +
607                                       last_page->size) {
608                         last_page->size += size;
609                         return 0;
610                 }
611         }
612
613         page = &dev->guest_pages[dev->nr_guest_pages++];
614         page->guest_phys_addr = guest_phys_addr;
615         page->host_phys_addr  = host_phys_addr;
616         page->size = size;
617
618         return 0;
619 }
620
621 static int
622 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
623                 uint64_t page_size)
624 {
625         uint64_t reg_size = reg->size;
626         uint64_t host_user_addr  = reg->host_user_addr;
627         uint64_t guest_phys_addr = reg->guest_phys_addr;
628         uint64_t host_phys_addr;
629         uint64_t size;
630
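        /* Record the host physical address of every guest page in the region,
         * merging physically contiguous pages, so that guest-physical to
         * host-physical translation works for zero-copy dequeue.
         */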
631         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
632         size = page_size - (guest_phys_addr & (page_size - 1));
633         size = RTE_MIN(size, reg_size);
634
635         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
636                 return -1;
637
638         host_user_addr  += size;
639         guest_phys_addr += size;
640         reg_size -= size;
641
642         while (reg_size > 0) {
643                 size = RTE_MIN(reg_size, page_size);
644                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
645                                                   host_user_addr);
646                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
647                                 size) < 0)
648                         return -1;
649
650                 host_user_addr  += size;
651                 guest_phys_addr += size;
652                 reg_size -= size;
653         }
654
655         return 0;
656 }
657
658 #ifdef RTE_LIBRTE_VHOST_DEBUG
659 /* TODO: enable it only in debug mode? */
660 static void
661 dump_guest_pages(struct virtio_net *dev)
662 {
663         uint32_t i;
664         struct guest_page *page;
665
666         for (i = 0; i < dev->nr_guest_pages; i++) {
667                 page = &dev->guest_pages[i];
668
669                 RTE_LOG(INFO, VHOST_CONFIG,
670                         "guest physical page region %u\n"
671                         "\t guest_phys_addr: %" PRIx64 "\n"
672                         "\t host_phys_addr : %" PRIx64 "\n"
673                         "\t size           : %" PRIx64 "\n",
674                         i,
675                         page->guest_phys_addr,
676                         page->host_phys_addr,
677                         page->size);
678         }
679 }
680 #else
681 #define dump_guest_pages(dev)
682 #endif
683
684 static bool
685 vhost_memory_changed(struct VhostUserMemory *new,
686                      struct rte_vhost_memory *old)
687 {
688         uint32_t i;
689
690         if (new->nregions != old->nregions)
691                 return true;
692
693         for (i = 0; i < new->nregions; ++i) {
694                 VhostUserMemoryRegion *new_r = &new->regions[i];
695                 struct rte_vhost_mem_region *old_r = &old->regions[i];
696
697                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
698                         return true;
699                 if (new_r->memory_size != old_r->size)
700                         return true;
701                 if (new_r->userspace_addr != old_r->guest_user_addr)
702                         return true;
703         }
704
705         return false;
706 }
707
708 static int
709 vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
710 {
711         struct virtio_net *dev = *pdev;
712         struct VhostUserMemory memory = pmsg->payload.memory;
713         struct rte_vhost_mem_region *reg;
714         void *mmap_addr;
715         uint64_t mmap_size;
716         uint64_t mmap_offset;
717         uint64_t alignment;
718         uint32_t i;
719         int populate;
720         int fd;
721
722         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
723                 RTE_LOG(ERR, VHOST_CONFIG,
724                         "too many memory regions (%u)\n", memory.nregions);
725                 return -1;
726         }
727
728         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
729                 RTE_LOG(INFO, VHOST_CONFIG,
730                         "(%d) memory regions not changed\n", dev->vid);
731
732                 for (i = 0; i < memory.nregions; i++)
733                         close(pmsg->fds[i]);
734
735                 return 0;
736         }
737
738         if (dev->mem) {
739                 free_mem_region(dev);
740                 rte_free(dev->mem);
741                 dev->mem = NULL;
742         }
743
744         dev->nr_guest_pages = 0;
745         if (!dev->guest_pages) {
746                 dev->max_guest_pages = 8;
747                 dev->guest_pages = malloc(dev->max_guest_pages *
748                                                 sizeof(struct guest_page));
749                 if (dev->guest_pages == NULL) {
750                         RTE_LOG(ERR, VHOST_CONFIG,
751                                 "(%d) failed to allocate memory "
752                                 "for dev->guest_pages\n",
753                                 dev->vid);
754                         return -1;
755                 }
756         }
757
758         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
759                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
760         if (dev->mem == NULL) {
761                 RTE_LOG(ERR, VHOST_CONFIG,
762                         "(%d) failed to allocate memory for dev->mem\n",
763                         dev->vid);
764                 return -1;
765         }
766         dev->mem->nregions = memory.nregions;
767
768         for (i = 0; i < memory.nregions; i++) {
769                 fd  = pmsg->fds[i];
770                 reg = &dev->mem->regions[i];
771
772                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
773                 reg->guest_user_addr = memory.regions[i].userspace_addr;
774                 reg->size            = memory.regions[i].memory_size;
775                 reg->fd              = fd;
776
777                 mmap_offset = memory.regions[i].mmap_offset;
778
779                 /* Check for memory_size + mmap_offset overflow */
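                /* In unsigned 64-bit arithmetic, -reg->size equals
                 * 2^64 - reg->size, so any offset at or above that value
                 * would make mmap_offset + reg->size wrap around.
                 */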
780                 if (mmap_offset >= -reg->size) {
781                         RTE_LOG(ERR, VHOST_CONFIG,
782                                 "mmap_offset (%#"PRIx64") and memory_size "
783                                 "(%#"PRIx64") overflow\n",
784                                 mmap_offset, reg->size);
785                         goto err_mmap;
786                 }
787
788                 mmap_size = reg->size + mmap_offset;
789
790                 /* On older long-term kernels such as 2.6.32 and 3.2.72,
791                  * mmap() without MAP_ANONYMOUS must be called with a
792                  * length argument aligned to the hugepage size, or it
793                  * will fail with EINVAL.
794                  *
795                  * To avoid that failure, make sure the length passed
796                  * here is kept aligned.
797                  */
798                 alignment = get_blk_size(fd);
799                 if (alignment == (uint64_t)-1) {
800                         RTE_LOG(ERR, VHOST_CONFIG,
801                                 "couldn't get hugepage size through fstat\n");
802                         goto err_mmap;
803                 }
804                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
805
806                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
807                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
808                                  MAP_SHARED | populate, fd, 0);
809
810                 if (mmap_addr == MAP_FAILED) {
811                         RTE_LOG(ERR, VHOST_CONFIG,
812                                 "mmap region %u failed.\n", i);
813                         goto err_mmap;
814                 }
815
816                 reg->mmap_addr = mmap_addr;
817                 reg->mmap_size = mmap_size;
818                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
819                                       mmap_offset;
820
821                 if (dev->dequeue_zero_copy)
822                         if (add_guest_pages(dev, reg, alignment) < 0) {
823                                 RTE_LOG(ERR, VHOST_CONFIG,
824                                         "adding guest pages to region %u failed.\n",
825                                         i);
826                                 goto err_mmap;
827                         }
828
829                 RTE_LOG(INFO, VHOST_CONFIG,
830                         "guest memory region %u, size: 0x%" PRIx64 "\n"
831                         "\t guest physical addr: 0x%" PRIx64 "\n"
832                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
833                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
834                         "\t mmap addr : 0x%" PRIx64 "\n"
835                         "\t mmap size : 0x%" PRIx64 "\n"
836                         "\t mmap align: 0x%" PRIx64 "\n"
837                         "\t mmap off  : 0x%" PRIx64 "\n",
838                         i, reg->size,
839                         reg->guest_phys_addr,
840                         reg->guest_user_addr,
841                         reg->host_user_addr,
842                         (uint64_t)(uintptr_t)mmap_addr,
843                         mmap_size,
844                         alignment,
845                         mmap_offset);
846         }
847
848         for (i = 0; i < dev->nr_vring; i++) {
849                 struct vhost_virtqueue *vq = dev->virtqueue[i];
850
851                 if (vq->desc || vq->avail || vq->used) {
852                         /*
853                          * If the memory table got updated, the ring addresses
854                          * need to be translated again as virtual addresses have
855                          * changed.
856                          */
857                         vring_invalidate(dev, vq);
858
859                         dev = translate_ring_addresses(dev, i);
860                         if (!dev)
861                                 return -1;
862
863                         *pdev = dev;
864                 }
865         }
866
867         dump_guest_pages(dev);
868
869         return 0;
870
871 err_mmap:
872         free_mem_region(dev);
873         rte_free(dev->mem);
874         dev->mem = NULL;
875         return -1;
876 }
877
878 static int
879 vq_is_ready(struct vhost_virtqueue *vq)
880 {
881         return vq && vq->desc && vq->avail && vq->used &&
882                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
883                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
884 }
885
886 static int
887 virtio_is_ready(struct virtio_net *dev)
888 {
889         struct vhost_virtqueue *vq;
890         uint32_t i;
891
892         if (dev->nr_vring == 0)
893                 return 0;
894
895         for (i = 0; i < dev->nr_vring; i++) {
896                 vq = dev->virtqueue[i];
897
898                 if (!vq_is_ready(vq))
899                         return 0;
900         }
901
902         RTE_LOG(INFO, VHOST_CONFIG,
903                 "virtio is now ready for processing.\n");
904         return 1;
905 }
906
907 static void
908 vhost_user_set_vring_call(struct virtio_net *dev, struct VhostUserMsg *pmsg)
909 {
910         struct vhost_vring_file file;
911         struct vhost_virtqueue *vq;
912
913         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
914         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
915                 file.fd = VIRTIO_INVALID_EVENTFD;
916         else
917                 file.fd = pmsg->fds[0];
918         RTE_LOG(INFO, VHOST_CONFIG,
919                 "vring call idx:%d file:%d\n", file.index, file.fd);
920
921         vq = dev->virtqueue[file.index];
922         if (vq->callfd >= 0)
923                 close(vq->callfd);
924
925         vq->callfd = file.fd;
926 }
927
928 static void
929 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
930 {
931         struct vhost_vring_file file;
932         struct vhost_virtqueue *vq;
933         struct virtio_net *dev = *pdev;
934
935         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
936         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
937                 file.fd = VIRTIO_INVALID_EVENTFD;
938         else
939                 file.fd = pmsg->fds[0];
940         RTE_LOG(INFO, VHOST_CONFIG,
941                 "vring kick idx:%d file:%d\n", file.index, file.fd);
942
943         /* Interpret ring addresses only when ring is started. */
944         dev = translate_ring_addresses(dev, file.index);
945         if (!dev)
946                 return;
947
948         *pdev = dev;
949
950         vq = dev->virtqueue[file.index];
951
952         /*
953          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
954          * the ring starts already enabled. Otherwise, it is enabled via
955          * the SET_VRING_ENABLE message.
956          */
957         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
958                 vq->enabled = 1;
959
960         if (vq->kickfd >= 0)
961                 close(vq->kickfd);
962         vq->kickfd = file.fd;
963 }
964
965 static void
966 free_zmbufs(struct vhost_virtqueue *vq)
967 {
968         struct zcopy_mbuf *zmbuf, *next;
969
970         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
971              zmbuf != NULL; zmbuf = next) {
972                 next = TAILQ_NEXT(zmbuf, next);
973
974                 rte_pktmbuf_free(zmbuf->mbuf);
975                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
976         }
977
978         rte_free(vq->zmbufs);
979 }
980
981 /*
982  * When virtio is stopped, QEMU sends us the GET_VRING_BASE message.
983  */
984 static int
985 vhost_user_get_vring_base(struct virtio_net *dev,
986                           VhostUserMsg *msg)
987 {
988         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
989
990         /* We have to stop the queue (virtio) if it is running. */
991         vhost_destroy_device_notify(dev);
992
993         dev->flags &= ~VIRTIO_DEV_READY;
994         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
995
996         /* Here we are safe to get the last avail index */
997         msg->payload.state.num = vq->last_avail_idx;
998
999         RTE_LOG(INFO, VHOST_CONFIG,
1000                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1001                 msg->payload.state.num);
1002         /*
1003          * In the current QEMU vhost-user implementation, this message is
1004          * sent only from vhost_vring_stop.
1005          * TODO: clean up the vring; it is not usable from this point on.
1006          */
1007         if (vq->kickfd >= 0)
1008                 close(vq->kickfd);
1009
1010         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1011
1012         if (vq->callfd >= 0)
1013                 close(vq->callfd);
1014
1015         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1016
1017         if (dev->dequeue_zero_copy)
1018                 free_zmbufs(vq);
1019         rte_free(vq->shadow_used_ring);
1020         vq->shadow_used_ring = NULL;
1021
1022         rte_free(vq->batch_copy_elems);
1023         vq->batch_copy_elems = NULL;
1024
1025         return 0;
1026 }
1027
1028 /*
1029  * When the virtio queues are ready to be used, QEMU sends us the message
1030  * to enable the virtio queue pair.
1031  */
1032 static int
1033 vhost_user_set_vring_enable(struct virtio_net *dev,
1034                             VhostUserMsg *msg)
1035 {
1036         int enable = (int)msg->payload.state.num;
1037         int index = (int)msg->payload.state.index;
1038         struct rte_vdpa_device *vdpa_dev;
1039         int did = -1;
1040
1041         RTE_LOG(INFO, VHOST_CONFIG,
1042                 "set queue enable: %d to qp idx: %d\n",
1043                 enable, index);
1044
1045         did = dev->vdpa_dev_id;
1046         vdpa_dev = rte_vdpa_get_device(did);
1047         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1048                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1049
1050         if (dev->notify_ops->vring_state_changed)
1051                 dev->notify_ops->vring_state_changed(dev->vid,
1052                                 index, enable);
1053
1054         dev->virtqueue[index]->enabled = enable;
1055
1056         return 0;
1057 }
1058
1059 static void
1060 vhost_user_get_protocol_features(struct virtio_net *dev,
1061                                  struct VhostUserMsg *msg)
1062 {
1063         uint64_t features, protocol_features;
1064
1065         rte_vhost_driver_get_features(dev->ifname, &features);
1066         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1067
1068         /*
1069          * The REPLY_ACK protocol feature is for now only mandatory for the
1070          * IOMMU feature. If IOMMU is explicitly disabled by the application,
1071          * also disable REPLY_ACK to cope with older buggy QEMU versions
1072          * (from v2.7.0 to v2.9.0).
1073          */
1074         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1075                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1076
1077         msg->payload.u64 = protocol_features;
1078         msg->size = sizeof(msg->payload.u64);
1079 }
1080
1081 static void
1082 vhost_user_set_protocol_features(struct virtio_net *dev,
1083                                  uint64_t protocol_features)
1084 {
1085         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES)
1086                 return;
1087
1088         dev->protocol_features = protocol_features;
1089 }
1090
1091 static int
1092 vhost_user_set_log_base(struct virtio_net *dev, struct VhostUserMsg *msg)
1093 {
1094         int fd = msg->fds[0];
1095         uint64_t size, off;
1096         void *addr;
1097
1098         if (fd < 0) {
1099                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1100                 return -1;
1101         }
1102
1103         if (msg->size != sizeof(VhostUserLog)) {
1104                 RTE_LOG(ERR, VHOST_CONFIG,
1105                         "invalid log base msg size: %"PRId32" != %d\n",
1106                         msg->size, (int)sizeof(VhostUserLog));
1107                 return -1;
1108         }
1109
1110         size = msg->payload.log.mmap_size;
1111         off  = msg->payload.log.mmap_offset;
1112
1113         /* Don't allow mmap_offset to point outside the mmap region */
1114         if (off > size) {
1115                 RTE_LOG(ERR, VHOST_CONFIG,
1116                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1117                         off, size);
1118                 return -1;
1119         }
1120
1121         RTE_LOG(INFO, VHOST_CONFIG,
1122                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1123                 size, off);
1124
1125         /*
1126          * mmap from offset 0 to work around a hugepage mmap bug: mmap
1127          * fails when the offset is not page-size aligned.
1128          */
1129         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1130         close(fd);
1131         if (addr == MAP_FAILED) {
1132                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1133                 return -1;
1134         }
1135
1136         /*
1137          * Free any previously mapped log memory in case
1138          * VHOST_USER_SET_LOG_BASE is received more than once.
1139          */
1140         if (dev->log_addr) {
1141                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1142         }
1143         dev->log_addr = (uint64_t)(uintptr_t)addr;
1144         dev->log_base = dev->log_addr + off;
1145         dev->log_size = size;
1146
1147         return 0;
1148 }
1149
1150 /*
1151  * A RARP packet is constructed and broadcast to notify switches of the
1152  * new location of the migrated VM, so that packets from outside are not
1153  * lost after migration.
1154  *
1155  * However, we don't actually "send" a RARP packet here; instead, we set
1156  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1157  */
1158 static int
1159 vhost_user_send_rarp(struct virtio_net *dev, struct VhostUserMsg *msg)
1160 {
1161         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1162         struct rte_vdpa_device *vdpa_dev;
1163         int did = -1;
1164
1165         RTE_LOG(DEBUG, VHOST_CONFIG,
1166                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1167                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1168         memcpy(dev->mac.addr_bytes, mac, 6);
1169
1170         /*
1171          * Set the flag to inject a RARP broadcast packet at
1172          * rte_vhost_dequeue_burst().
1173          *
1174          * rte_smp_wmb() is for making sure the mac is copied
1175          * before the flag is set.
1176          */
1177         rte_smp_wmb();
1178         rte_atomic16_set(&dev->broadcast_rarp, 1);
1179         did = dev->vdpa_dev_id;
1180         vdpa_dev = rte_vdpa_get_device(did);
1181         if (vdpa_dev && vdpa_dev->ops->migration_done)
1182                 vdpa_dev->ops->migration_done(dev->vid);
1183
1184         return 0;
1185 }
1186
1187 static int
1188 vhost_user_net_set_mtu(struct virtio_net *dev, struct VhostUserMsg *msg)
1189 {
1190         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1191                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1192                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1193                                 msg->payload.u64);
1194
1195                 return -1;
1196         }
1197
1198         dev->mtu = msg->payload.u64;
1199
1200         return 0;
1201 }
1202
1203 static int
1204 vhost_user_set_req_fd(struct virtio_net *dev, struct VhostUserMsg *msg)
1205 {
1206         int fd = msg->fds[0];
1207
1208         if (fd < 0) {
1209                 RTE_LOG(ERR, VHOST_CONFIG,
1210                                 "Invalid file descriptor for slave channel (%d)\n",
1211                                 fd);
1212                 return -1;
1213         }
1214
1215         dev->slave_req_fd = fd;
1216
1217         return 0;
1218 }
1219
1220 static int
1221 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1222 {
1223         struct vhost_vring_addr *ra;
1224         uint64_t start, end;
1225
1226         start = imsg->iova;
1227         end = start + imsg->size;
1228
1229         ra = &vq->ring_addrs;
1230         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1231                 return 1;
1232         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1233                 return 1;
1234         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1235                 return 1;
1236
1237         return 0;
1238 }
1239
1240 static int
1241 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1242                                 struct vhost_iotlb_msg *imsg)
1243 {
1244         uint64_t istart, iend, vstart, vend;
1245
1246         istart = imsg->iova;
1247         iend = istart + imsg->size - 1;
1248
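        /* All ranges below are inclusive; a ring overlaps the invalidated
         * range when neither range ends before the other begins.
         */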
1249         vstart = (uintptr_t)vq->desc;
1250         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1251         if (vstart <= iend && istart <= vend)
1252                 return 1;
1253
1254         vstart = (uintptr_t)vq->avail;
1255         vend = vstart + sizeof(struct vring_avail);
1256         vend += sizeof(uint16_t) * vq->size - 1;
1257         if (vstart <= iend && istart <= vend)
1258                 return 1;
1259
1260         vstart = (uintptr_t)vq->used;
1261         vend = vstart + sizeof(struct vring_used);
1262         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1263         if (vstart <= iend && istart <= vend)
1264                 return 1;
1265
1266         return 0;
1267 }
1268
1269 static int
1270 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1271 {
1272         struct virtio_net *dev = *pdev;
1273         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1274         uint16_t i;
1275         uint64_t vva, len;
1276
1277         switch (imsg->type) {
1278         case VHOST_IOTLB_UPDATE:
1279                 len = imsg->size;
1280                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1281                 if (!vva)
1282                         return -1;
1283
1284                 for (i = 0; i < dev->nr_vring; i++) {
1285                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1286
1287                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1288                                         len, imsg->perm);
1289
1290                         if (is_vring_iotlb_update(vq, imsg))
1291                                 *pdev = dev = translate_ring_addresses(dev, i);
1292                 }
1293                 break;
1294         case VHOST_IOTLB_INVALIDATE:
1295                 for (i = 0; i < dev->nr_vring; i++) {
1296                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1297
1298                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1299                                         imsg->size);
1300
1301                         if (is_vring_iotlb_invalidate(vq, imsg))
1302                                 vring_invalidate(dev, vq);
1303                 }
1304                 break;
1305         default:
1306                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1307                                 imsg->type);
1308                 return -1;
1309         }
1310
1311         return 0;
1312 }
1313
1314 /* Return the number of bytes read on success, or a negative value on failure. */
1315 static int
1316 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1317 {
1318         int ret;
1319
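        /* Read the fixed-size header (and any passed file descriptors) first;
         * the variable-size payload, if present, is read in a second step.
         */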
1320         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1321                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1322         if (ret <= 0)
1323                 return ret;
1324
1325         if (msg && msg->size) {
1326                 if (msg->size > sizeof(msg->payload)) {
1327                         RTE_LOG(ERR, VHOST_CONFIG,
1328                                 "invalid msg size: %d\n", msg->size);
1329                         return -1;
1330                 }
1331                 ret = read(sockfd, &msg->payload, msg->size);
1332                 if (ret <= 0)
1333                         return ret;
1334                 if (ret != (int)msg->size) {
1335                         RTE_LOG(ERR, VHOST_CONFIG,
1336                                 "read control message failed\n");
1337                         return -1;
1338                 }
1339         }
1340
1341         return ret;
1342 }
1343
1344 static int
1345 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1346 {
1347         if (!msg)
1348                 return 0;
1349
1350         return send_fd_message(sockfd, (char *)msg,
1351                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1352 }
1353
1354 static int
1355 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1356 {
1357         if (!msg)
1358                 return 0;
1359
1360         msg->flags &= ~VHOST_USER_VERSION_MASK;
1361         msg->flags &= ~VHOST_USER_NEED_REPLY;
1362         msg->flags |= VHOST_USER_VERSION;
1363         msg->flags |= VHOST_USER_REPLY_MASK;
1364
1365         return send_vhost_message(sockfd, msg, NULL, 0);
1366 }
1367
1368 static int
1369 send_vhost_slave_message(struct virtio_net *dev, struct VhostUserMsg *msg,
1370                          int *fds, int fd_num)
1371 {
1372         int ret;
1373
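        /* Serialize slave requests that expect a reply: take the lock before
         * sending, and release it right away if the send fails.
         */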
1374         if (msg->flags & VHOST_USER_NEED_REPLY)
1375                 rte_spinlock_lock(&dev->slave_req_lock);
1376
1377         ret = send_vhost_message(dev->slave_req_fd, msg, fds, fd_num);
1378         if (ret < 0 && (msg->flags & VHOST_USER_NEED_REPLY))
1379                 rte_spinlock_unlock(&dev->slave_req_lock);
1380
1381         return ret;
1382 }
1383
1384 /*
1385  * Allocate a queue pair if it hasn't been allocated yet
1386  */
1387 static int
1388 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
1389 {
1390         uint16_t vring_idx;
1391
1392         switch (msg->request.master) {
1393         case VHOST_USER_SET_VRING_KICK:
1394         case VHOST_USER_SET_VRING_CALL:
1395         case VHOST_USER_SET_VRING_ERR:
1396                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1397                 break;
1398         case VHOST_USER_SET_VRING_NUM:
1399         case VHOST_USER_SET_VRING_BASE:
1400         case VHOST_USER_SET_VRING_ENABLE:
1401                 vring_idx = msg->payload.state.index;
1402                 break;
1403         case VHOST_USER_SET_VRING_ADDR:
1404                 vring_idx = msg->payload.addr.index;
1405                 break;
1406         default:
1407                 return 0;
1408         }
1409
1410         if (vring_idx >= VHOST_MAX_VRING) {
1411                 RTE_LOG(ERR, VHOST_CONFIG,
1412                         "invalid vring index: %u\n", vring_idx);
1413                 return -1;
1414         }
1415
1416         if (dev->virtqueue[vring_idx])
1417                 return 0;
1418
1419         return alloc_vring_queue(dev, vring_idx);
1420 }
1421
1422 static void
1423 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1424 {
1425         unsigned int i = 0;
1426         unsigned int vq_num = 0;
1427
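        /* Virtqueue slots may be sparsely populated, so iterate until
         * nr_vring allocated queues have actually been locked.
         */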
1428         while (vq_num < dev->nr_vring) {
1429                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1430
1431                 if (vq) {
1432                         rte_spinlock_lock(&vq->access_lock);
1433                         vq_num++;
1434                 }
1435                 i++;
1436         }
1437 }
1438
1439 static void
1440 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1441 {
1442         unsigned int i = 0;
1443         unsigned int vq_num = 0;
1444
1445         while (vq_num < dev->nr_vring) {
1446                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1447
1448                 if (vq) {
1449                         rte_spinlock_unlock(&vq->access_lock);
1450                         vq_num++;
1451                 }
1452                 i++;
1453         }
1454 }
1455
1456 int
1457 vhost_user_msg_handler(int vid, int fd)
1458 {
1459         struct virtio_net *dev;
1460         struct VhostUserMsg msg;
1461         struct rte_vdpa_device *vdpa_dev;
1462         int did = -1;
1463         int ret;
1464         int unlock_required = 0;
1465         uint32_t skip_master = 0;
1466
1467         dev = get_device(vid);
1468         if (dev == NULL)
1469                 return -1;
1470
1471         if (!dev->notify_ops) {
1472                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1473                 if (!dev->notify_ops) {
1474                         RTE_LOG(ERR, VHOST_CONFIG,
1475                                 "failed to get callback ops for driver %s\n",
1476                                 dev->ifname);
1477                         return -1;
1478                 }
1479         }
1480
1481         ret = read_vhost_message(fd, &msg);
1482         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1483                 if (ret < 0)
1484                         RTE_LOG(ERR, VHOST_CONFIG,
1485                                 "vhost read message failed\n");
1486                 else if (ret == 0)
1487                         RTE_LOG(INFO, VHOST_CONFIG,
1488                                 "vhost peer closed\n");
1489                 else
1490                         RTE_LOG(ERR, VHOST_CONFIG,
1491                                 "vhost read incorrect message\n");
1492
1493                 return -1;
1494         }
1495
1496         ret = 0;
1497         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1498                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1499                         vhost_message_str[msg.request.master]);
1500         else
1501                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1502                         vhost_message_str[msg.request.master]);
1503
1504         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1505         if (ret < 0) {
1506                 RTE_LOG(ERR, VHOST_CONFIG,
1507                         "failed to alloc queue\n");
1508                 return -1;
1509         }
1510
1511         /*
1512          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1513          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1514          * and the device is destroyed. destroy_device waits for the queues
1515          * to become inactive, so it is safe. Otherwise, taking the
1516          * access_lock would cause a deadlock.
1517          */
1518         switch (msg.request.master) {
1519         case VHOST_USER_SET_FEATURES:
1520         case VHOST_USER_SET_PROTOCOL_FEATURES:
1521         case VHOST_USER_SET_OWNER:
1522         case VHOST_USER_SET_MEM_TABLE:
1523         case VHOST_USER_SET_LOG_BASE:
1524         case VHOST_USER_SET_LOG_FD:
1525         case VHOST_USER_SET_VRING_NUM:
1526         case VHOST_USER_SET_VRING_ADDR:
1527         case VHOST_USER_SET_VRING_BASE:
1528         case VHOST_USER_SET_VRING_KICK:
1529         case VHOST_USER_SET_VRING_CALL:
1530         case VHOST_USER_SET_VRING_ERR:
1531         case VHOST_USER_SET_VRING_ENABLE:
1532         case VHOST_USER_SEND_RARP:
1533         case VHOST_USER_NET_SET_MTU:
1534         case VHOST_USER_SET_SLAVE_REQ_FD:
1535                 vhost_user_lock_all_queue_pairs(dev);
1536                 unlock_required = 1;
1537                 break;
1538         default:
1539                 break;
1540
1541         }
1542
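        /*
         * External backends (e.g. vhost-user crypto) may pre-process the
         * request; they can reply on their own and ask us to skip the
         * built-in handler below.
         */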
1543         if (dev->extern_ops.pre_msg_handle) {
1544                 uint32_t need_reply;
1545
1546                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1547                                 (void *)&msg, &need_reply, &skip_master);
1548                 if (ret < 0)
1549                         goto skip_to_reply;
1550
1551                 if (need_reply)
1552                         send_vhost_reply(fd, &msg);
1553
1554                 if (skip_master)
1555                         goto skip_to_post_handle;
1556         }
1557
1558         switch (msg.request.master) {
1559         case VHOST_USER_GET_FEATURES:
1560                 msg.payload.u64 = vhost_user_get_features(dev);
1561                 msg.size = sizeof(msg.payload.u64);
1562                 send_vhost_reply(fd, &msg);
1563                 break;
1564         case VHOST_USER_SET_FEATURES:
1565                 ret = vhost_user_set_features(dev, msg.payload.u64);
1566                 if (ret)
1567                         return -1;
1568                 break;
1569
1570         case VHOST_USER_GET_PROTOCOL_FEATURES:
1571                 vhost_user_get_protocol_features(dev, &msg);
1572                 send_vhost_reply(fd, &msg);
1573                 break;
1574         case VHOST_USER_SET_PROTOCOL_FEATURES:
1575                 vhost_user_set_protocol_features(dev, msg.payload.u64);
1576                 break;
1577
1578         case VHOST_USER_SET_OWNER:
1579                 vhost_user_set_owner();
1580                 break;
1581         case VHOST_USER_RESET_OWNER:
1582                 vhost_user_reset_owner(dev);
1583                 break;
1584
1585         case VHOST_USER_SET_MEM_TABLE:
1586                 ret = vhost_user_set_mem_table(&dev, &msg);
1587                 break;
1588
1589         case VHOST_USER_SET_LOG_BASE:
1590                 vhost_user_set_log_base(dev, &msg);
1591
1592                 /* it needs a reply */
1593                 msg.size = sizeof(msg.payload.u64);
1594                 send_vhost_reply(fd, &msg);
1595                 break;
1596         case VHOST_USER_SET_LOG_FD:
1597                 close(msg.fds[0]);
1598                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1599                 break;
1600
1601         case VHOST_USER_SET_VRING_NUM:
1602                 vhost_user_set_vring_num(dev, &msg);
1603                 break;
1604         case VHOST_USER_SET_VRING_ADDR:
1605                 vhost_user_set_vring_addr(&dev, &msg);
1606                 break;
1607         case VHOST_USER_SET_VRING_BASE:
1608                 vhost_user_set_vring_base(dev, &msg);
1609                 break;
1610
1611         case VHOST_USER_GET_VRING_BASE:
1612                 vhost_user_get_vring_base(dev, &msg);
1613                 msg.size = sizeof(msg.payload.state);
1614                 send_vhost_reply(fd, &msg);
1615                 break;
1616
1617         case VHOST_USER_SET_VRING_KICK:
1618                 vhost_user_set_vring_kick(&dev, &msg);
1619                 break;
1620         case VHOST_USER_SET_VRING_CALL:
1621                 vhost_user_set_vring_call(dev, &msg);
1622                 break;
1623
1624         case VHOST_USER_SET_VRING_ERR:
1625                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1626                         close(msg.fds[0]);
1627                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1628                 break;
1629
1630         case VHOST_USER_GET_QUEUE_NUM:
1631                 msg.payload.u64 = (uint64_t)vhost_user_get_queue_num(dev);
1632                 msg.size = sizeof(msg.payload.u64);
1633                 send_vhost_reply(fd, &msg);
1634                 break;
1635
1636         case VHOST_USER_SET_VRING_ENABLE:
1637                 vhost_user_set_vring_enable(dev, &msg);
1638                 break;
1639         case VHOST_USER_SEND_RARP:
1640                 vhost_user_send_rarp(dev, &msg);
1641                 break;
1642
1643         case VHOST_USER_NET_SET_MTU:
1644                 ret = vhost_user_net_set_mtu(dev, &msg);
1645                 break;
1646
1647         case VHOST_USER_SET_SLAVE_REQ_FD:
1648                 ret = vhost_user_set_req_fd(dev, &msg);
1649                 break;
1650
1651         case VHOST_USER_IOTLB_MSG:
1652                 ret = vhost_user_iotlb_msg(&dev, &msg);
1653                 break;
1654
1655         default:
1656                 ret = -1;
1657                 break;
1658         }
1659
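        /*
         * Give external backends a chance to post-process the request and
         * send their own reply if needed.
         */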
1660 skip_to_post_handle:
1661         if (dev->extern_ops.post_msg_handle) {
1662                 uint32_t need_reply;
1663
1664                 ret = (*dev->extern_ops.post_msg_handle)(
1665                                 dev->vid, (void *)&msg, &need_reply);
1666                 if (ret < 0)
1667                         goto skip_to_reply;
1668
1669                 if (need_reply)
1670                         send_vhost_reply(fd, &msg);
1671         }
1672
1673 skip_to_reply:
1674         if (unlock_required)
1675                 vhost_user_unlock_all_queue_pairs(dev);
1676
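        /*
         * The master asked for an explicit ack: reply with 0 on success,
         * 1 on failure.
         */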
1677         if (msg.flags & VHOST_USER_NEED_REPLY) {
1678                 msg.payload.u64 = !!ret;
1679                 msg.size = sizeof(msg.payload.u64);
1680                 send_vhost_reply(fd, &msg);
1681         }
1682
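        /*
         * Once all the rings are configured and ready, mark the device as
         * ready and notify the application through the new_device() callback
         * so it can start using it.
         */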
1683         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1684                 dev->flags |= VIRTIO_DEV_READY;
1685
1686                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1687                         if (dev->dequeue_zero_copy) {
1688                                 RTE_LOG(INFO, VHOST_CONFIG,
1689                                                 "dequeue zero copy is enabled\n");
1690                         }
1691
1692                         if (dev->notify_ops->new_device(dev->vid) == 0)
1693                                 dev->flags |= VIRTIO_DEV_RUNNING;
1694                 }
1695         }
1696
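        /*
         * If a vDPA device is attached, configure it once the rings are
         * ready and try to set up host notifiers so that guest doorbells
         * reach the device directly instead of going through a software
         * relay.
         */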
1697         did = dev->vdpa_dev_id;
1698         vdpa_dev = rte_vdpa_get_device(did);
1699         if (vdpa_dev && virtio_is_ready(dev) &&
1700                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1701                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1702                 if (vdpa_dev->ops->dev_conf)
1703                         vdpa_dev->ops->dev_conf(dev->vid);
1704                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1705                 if (vhost_user_host_notifier_ctrl(dev->vid, true) != 0) {
1706                         RTE_LOG(INFO, VHOST_CONFIG,
1707                                 "(%d) software relay is used for vDPA, performance may be low.\n",
1708                                 dev->vid);
1709                 }
1710         }
1711
1712         return 0;
1713 }
1714
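/*
 * Wait for the master's reply to a slave-channel request sent with
 * VHOST_USER_NEED_REPLY and check that it matches the request type.
 * Releases the slave_req_lock held across the request/reply exchange.
 */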
1715 static int process_slave_message_reply(struct virtio_net *dev,
1716                                        const VhostUserMsg *msg)
1717 {
1718         VhostUserMsg msg_reply;
1719         int ret;
1720
1721         if ((msg->flags & VHOST_USER_NEED_REPLY) == 0)
1722                 return 0;
1723
1724         if (read_vhost_message(dev->slave_req_fd, &msg_reply) < 0) {
1725                 ret = -1;
1726                 goto out;
1727         }
1728
1729         if (msg_reply.request.slave != msg->request.slave) {
1730                 RTE_LOG(ERR, VHOST_CONFIG,
1731                         "Received unexpected msg type (%u), expected %u\n",
1732                         msg_reply.request.slave, msg->request.slave);
1733                 ret = -1;
1734                 goto out;
1735         }
1736
1737         ret = msg_reply.payload.u64 ? -1 : 0;
1738
1739 out:
1740         rte_spinlock_unlock(&dev->slave_req_lock);
1741         return ret;
1742 }
1743
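/*
 * Send a VHOST_IOTLB_MISS request on the slave channel to ask the master
 * for the translation of the given iova with the given permissions.
 */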
1744 int
1745 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1746 {
1747         int ret;
1748         struct VhostUserMsg msg = {
1749                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1750                 .flags = VHOST_USER_VERSION,
1751                 .size = sizeof(msg.payload.iotlb),
1752                 .payload.iotlb = {
1753                         .iova = iova,
1754                         .perm = perm,
1755                         .type = VHOST_IOTLB_MISS,
1756                 },
1757         };
1758
1759         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1760         if (ret < 0) {
1761                 RTE_LOG(ERR, VHOST_CONFIG,
1762                                 "Failed to send IOTLB miss message (%d)\n",
1763                                 ret);
1764                 return ret;
1765         }
1766
1767         return 0;
1768 }
1769
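/*
 * Pass the host notifier area (fd, offset, size) of a vring to the master
 * over the slave channel so it can be mapped into the VM; a negative fd
 * asks the master to unmap it.
 */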
1770 static int vhost_user_slave_set_vring_host_notifier(struct virtio_net *dev,
1771                                                     int index, int fd,
1772                                                     uint64_t offset,
1773                                                     uint64_t size)
1774 {
1775         int *fdp = NULL;
1776         size_t fd_num = 0;
1777         int ret;
1778         struct VhostUserMsg msg = {
1779                 .request.slave = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1780                 .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY,
1781                 .size = sizeof(msg.payload.area),
1782                 .payload.area = {
1783                         .u64 = index & VHOST_USER_VRING_IDX_MASK,
1784                         .size = size,
1785                         .offset = offset,
1786                 },
1787         };
1788
1789         if (fd < 0)
1790                 msg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1791         else {
1792                 fdp = &fd;
1793                 fd_num = 1;
1794         }
1795
1796         ret = send_vhost_slave_message(dev, &msg, fdp, fd_num);
1797         if (ret < 0) {
1798                 RTE_LOG(ERR, VHOST_CONFIG,
1799                         "Failed to set host notifier (%d)\n", ret);
1800                 return ret;
1801         }
1802
1803         return process_slave_message_reply(dev, &msg);
1804 }
1805
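/*
 * Enable or disable host notifiers for all vrings of a vDPA device. This
 * requires the slave channel, fd passing and host notifier protocol
 * features to be negotiated; if enabling fails midway, all notifiers are
 * rolled back to disabled.
 */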
1806 int vhost_user_host_notifier_ctrl(int vid, bool enable)
1807 {
1808         struct virtio_net *dev;
1809         struct rte_vdpa_device *vdpa_dev;
1810         int vfio_device_fd, did, ret = 0;
1811         uint64_t offset, size;
1812         unsigned int i;
1813
1814         dev = get_device(vid);
1815         if (!dev)
1816                 return -ENODEV;
1817
1818         did = dev->vdpa_dev_id;
1819         if (did < 0)
1820                 return -EINVAL;
1821
1822         if (!(dev->features & (1ULL << VIRTIO_F_VERSION_1)) ||
1823             !(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)) ||
1824             !(dev->protocol_features &
1825                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) ||
1826             !(dev->protocol_features &
1827                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) ||
1828             !(dev->protocol_features &
1829                         (1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER)))
1830                 return -ENOTSUP;
1831
1832         vdpa_dev = rte_vdpa_get_device(did);
1833         if (!vdpa_dev)
1834                 return -ENODEV;
1835
1836         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_vfio_device_fd, -ENOTSUP);
1837         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_notify_area, -ENOTSUP);
1838
1839         vfio_device_fd = vdpa_dev->ops->get_vfio_device_fd(vid);
1840         if (vfio_device_fd < 0)
1841                 return -ENOTSUP;
1842
1843         if (enable) {
1844                 for (i = 0; i < dev->nr_vring; i++) {
1845                         if (vdpa_dev->ops->get_notify_area(vid, i, &offset,
1846                                         &size) < 0) {
1847                                 ret = -ENOTSUP;
1848                                 goto disable;
1849                         }
1850
1851                         if (vhost_user_slave_set_vring_host_notifier(dev, i,
1852                                         vfio_device_fd, offset, size) < 0) {
1853                                 ret = -EFAULT;
1854                                 goto disable;
1855                         }
1856                 }
1857         } else {
1858 disable:
1859                 for (i = 0; i < dev->nr_vring; i++) {
1860                         vhost_user_slave_set_vring_host_notifier(dev, i, -1,
1861                                         0, 0);
1862                 }
1863         }
1864
1865         return ret;
1866 }