vhost/crypto: add vhost-user message handlers
[dpdk.git] / lib / librte_vhost / vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped will then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70 };
71
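/*
 * Return the block size reported by fstat() for the given fd, or
 * (uint64_t)-1 on failure. For hugetlbfs-backed guest memory this is
 * typically the hugepage size, which the mmap alignment logic in
 * vhost_user_set_mem_table() relies on.
 */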
72 static uint64_t
73 get_blk_size(int fd)
74 {
75         struct stat stat;
76         int ret;
77
78         ret = fstat(fd, &stat);
79         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
80 }
81
82 static void
83 free_mem_region(struct virtio_net *dev)
84 {
85         uint32_t i;
86         struct rte_vhost_mem_region *reg;
87
88         if (!dev || !dev->mem)
89                 return;
90
91         for (i = 0; i < dev->mem->nregions; i++) {
92                 reg = &dev->mem->regions[i];
93                 if (reg->host_user_addr) {
94                         munmap(reg->mmap_addr, reg->mmap_size);
95                         close(reg->fd);
96                 }
97         }
98 }
99
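/*
 * Release all per-device resources set up through vhost-user messages:
 * mapped memory regions, the guest page array, the dirty log mapping
 * and the slave request fd.
 */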
100 void
101 vhost_backend_cleanup(struct virtio_net *dev)
102 {
103         if (dev->mem) {
104                 free_mem_region(dev);
105                 rte_free(dev->mem);
106                 dev->mem = NULL;
107         }
108
109         free(dev->guest_pages);
110         dev->guest_pages = NULL;
111
112         if (dev->log_addr) {
113                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
114                 dev->log_addr = 0;
115         }
116
117         if (dev->slave_req_fd >= 0) {
118                 close(dev->slave_req_fd);
119                 dev->slave_req_fd = -1;
120         }
121 }
122
123 /*
124  * This function just returns success at the moment; there is no
125  * per-device state to set up when ownership is taken.
126  */
127 static int
128 vhost_user_set_owner(void)
129 {
130         return 0;
131 }
132
133 static int
134 vhost_user_reset_owner(struct virtio_net *dev)
135 {
136         struct rte_vdpa_device *vdpa_dev;
137         int did = -1;
138
139         if (dev->flags & VIRTIO_DEV_RUNNING) {
140                 did = dev->vdpa_dev_id;
141                 vdpa_dev = rte_vdpa_get_device(did);
142                 if (vdpa_dev && vdpa_dev->ops->dev_close)
143                         vdpa_dev->ops->dev_close(dev->vid);
144                 dev->flags &= ~VIRTIO_DEV_RUNNING;
145                 dev->notify_ops->destroy_device(dev->vid);
146         }
147
148         cleanup_device(dev, 0);
149         reset_device(dev);
150         return 0;
151 }
152
153 /*
154  * The master requests the set of features that we support.
155  */
156 static uint64_t
157 vhost_user_get_features(struct virtio_net *dev)
158 {
159         uint64_t features = 0;
160
161         rte_vhost_driver_get_features(dev->ifname, &features);
162         return features;
163 }
164
165 /*
166  * The master requests the number of queues that we support.
167  */
168 static uint32_t
169 vhost_user_get_queue_num(struct virtio_net *dev)
170 {
171         uint32_t queue_num = 0;
172
173         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
174         return queue_num;
175 }
176
177 /*
178  * We receive the negotiated features supported by us and the virtio device.
179  */
180 static int
181 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
182 {
183         uint64_t vhost_features = 0;
184         struct rte_vdpa_device *vdpa_dev;
185         int did = -1;
186
187         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
188         if (features & ~vhost_features) {
189                 RTE_LOG(ERR, VHOST_CONFIG,
190                         "(%d) received invalid negotiated features.\n",
191                         dev->vid);
192                 return -1;
193         }
194
195         if (dev->flags & VIRTIO_DEV_RUNNING) {
196                 if (dev->features == features)
197                         return 0;
198
199                 /*
200                  * Error out if master tries to change features while device is
201                  * in running state. The exception being VHOST_F_LOG_ALL, which
202                  * is enabled when the live-migration starts.
203                  */
204                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
205                         RTE_LOG(ERR, VHOST_CONFIG,
206                                 "(%d) features changed while device is running.\n",
207                                 dev->vid);
208                         return -1;
209                 }
210
211                 if (dev->notify_ops->features_changed)
212                         dev->notify_ops->features_changed(dev->vid, features);
213         }
214
215         did = dev->vdpa_dev_id;
216         vdpa_dev = rte_vdpa_get_device(did);
217         if (vdpa_dev && vdpa_dev->ops->set_features)
218                 vdpa_dev->ops->set_features(dev->vid);
219
220         dev->features = features;
221         if (dev->features &
222                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
223                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
224         } else {
225                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
226         }
227         VHOST_LOG_DEBUG(VHOST_CONFIG,
228                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
229                 dev->vid,
230                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
231                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
232
233         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
234             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
235                 /*
236                  * Remove all but first queue pair if MQ hasn't been
237                  * negotiated. This is safe because the device is not
238                  * running at this stage.
239                  */
240                 while (dev->nr_vring > 2) {
241                         struct vhost_virtqueue *vq;
242
243                         vq = dev->virtqueue[--dev->nr_vring];
244                         if (!vq)
245                                 continue;
246
247                         dev->virtqueue[dev->nr_vring] = NULL;
248                         cleanup_vq(vq, 1);
249                         free_vq(vq);
250                 }
251         }
252
253         return 0;
254 }
255
256 /*
257  * The virtio device sends us the size of the descriptor ring.
258  */
259 static int
260 vhost_user_set_vring_num(struct virtio_net *dev,
261                          VhostUserMsg *msg)
262 {
263         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
264
265         vq->size = msg->payload.state.num;
266
267         /* VIRTIO 1.0, 2.4 Virtqueues says:
268          *
269          *   Queue Size value is always a power of 2. The maximum Queue Size
270          *   value is 32768.
271          */
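        /* size & (size - 1) is non-zero exactly when size is not a power
         * of two, e.g. 256 & 255 == 0 while 260 & 259 == 4.
         */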
272         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
273                 RTE_LOG(ERR, VHOST_CONFIG,
274                         "invalid virtqueue size %u\n", vq->size);
275                 return -1;
276         }
277
278         if (dev->dequeue_zero_copy) {
279                 vq->nr_zmbuf = 0;
280                 vq->last_zmbuf_idx = 0;
281                 vq->zmbuf_size = vq->size;
282                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
283                                          sizeof(struct zcopy_mbuf), 0);
284                 if (vq->zmbufs == NULL) {
285                         RTE_LOG(WARNING, VHOST_CONFIG,
286                                 "failed to allocate mem for zero copy; "
287                                 "zero copy is force disabled\n");
288                         dev->dequeue_zero_copy = 0;
289                 }
290                 TAILQ_INIT(&vq->zmbuf_list);
291         }
292
293         vq->shadow_used_ring = rte_malloc(NULL,
294                                 vq->size * sizeof(struct vring_used_elem),
295                                 RTE_CACHE_LINE_SIZE);
296         if (!vq->shadow_used_ring) {
297                 RTE_LOG(ERR, VHOST_CONFIG,
298                         "failed to allocate memory for shadow used ring.\n");
299                 return -1;
300         }
301
302         vq->batch_copy_elems = rte_malloc(NULL,
303                                 vq->size * sizeof(struct batch_copy_elem),
304                                 RTE_CACHE_LINE_SIZE);
305         if (!vq->batch_copy_elems) {
306                 RTE_LOG(ERR, VHOST_CONFIG,
307                         "failed to allocate memory for batching copy.\n");
308                 return -1;
309         }
310
311         return 0;
312 }
313
314 /*
315  * Reallocate the virtio_net and vhost_virtqueue data structures so they reside
316  * on the same NUMA node as the memory backing the vring descriptors.
317  */
318 #ifdef RTE_LIBRTE_VHOST_NUMA
319 static struct virtio_net*
320 numa_realloc(struct virtio_net *dev, int index)
321 {
322         int oldnode, newnode;
323         struct virtio_net *old_dev;
324         struct vhost_virtqueue *old_vq, *vq;
325         struct zcopy_mbuf *new_zmbuf;
326         struct vring_used_elem *new_shadow_used_ring;
327         struct batch_copy_elem *new_batch_copy_elems;
328         int ret;
329
330         old_dev = dev;
331         vq = old_vq = dev->virtqueue[index];
332
333         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
334                             MPOL_F_NODE | MPOL_F_ADDR);
335
336         /* check if we need to reallocate vq */
337         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
338                              MPOL_F_NODE | MPOL_F_ADDR);
339         if (ret) {
340                 RTE_LOG(ERR, VHOST_CONFIG,
341                         "Unable to get vq numa information.\n");
342                 return dev;
343         }
344         if (oldnode != newnode) {
345                 RTE_LOG(INFO, VHOST_CONFIG,
346                         "reallocate vq from %d to %d node\n", oldnode, newnode);
347                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
348                 if (!vq)
349                         return dev;
350
351                 memcpy(vq, old_vq, sizeof(*vq));
352                 TAILQ_INIT(&vq->zmbuf_list);
353
354                 new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
355                         sizeof(struct zcopy_mbuf), 0, newnode);
356                 if (new_zmbuf) {
357                         rte_free(vq->zmbufs);
358                         vq->zmbufs = new_zmbuf;
359                 }
360
361                 new_shadow_used_ring = rte_malloc_socket(NULL,
362                         vq->size * sizeof(struct vring_used_elem),
363                         RTE_CACHE_LINE_SIZE,
364                         newnode);
365                 if (new_shadow_used_ring) {
366                         rte_free(vq->shadow_used_ring);
367                         vq->shadow_used_ring = new_shadow_used_ring;
368                 }
369
370                 new_batch_copy_elems = rte_malloc_socket(NULL,
371                         vq->size * sizeof(struct batch_copy_elem),
372                         RTE_CACHE_LINE_SIZE,
373                         newnode);
374                 if (new_batch_copy_elems) {
375                         rte_free(vq->batch_copy_elems);
376                         vq->batch_copy_elems = new_batch_copy_elems;
377                 }
378
379                 rte_free(old_vq);
380         }
381
382         /* check if we need to reallocate dev */
383         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
384                             MPOL_F_NODE | MPOL_F_ADDR);
385         if (ret) {
386                 RTE_LOG(ERR, VHOST_CONFIG,
387                         "Unable to get dev numa information.\n");
388                 goto out;
389         }
390         if (oldnode != newnode) {
391                 RTE_LOG(INFO, VHOST_CONFIG,
392                         "reallocate dev from %d to %d node\n",
393                         oldnode, newnode);
394                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
395                 if (!dev) {
396                         dev = old_dev;
397                         goto out;
398                 }
399
400                 memcpy(dev, old_dev, sizeof(*dev));
401                 rte_free(old_dev);
402         }
403
404 out:
405         dev->virtqueue[index] = vq;
406         vhost_devices[dev->vid] = dev;
407
408         if (old_vq != vq)
409                 vhost_user_iotlb_init(dev, index);
410
411         return dev;
412 }
413 #else
414 static struct virtio_net*
415 numa_realloc(struct virtio_net *dev, int index __rte_unused)
416 {
417         return dev;
418 }
419 #endif
420
421 /* Converts QEMU virtual address to Vhost virtual address. */
422 static uint64_t
423 qva_to_vva(struct virtio_net *dev, uint64_t qva)
424 {
425         struct rte_vhost_mem_region *reg;
426         uint32_t i;
427
428         /* Find the region where the address lives. */
429         for (i = 0; i < dev->mem->nregions; i++) {
430                 reg = &dev->mem->regions[i];
431
432                 if (qva >= reg->guest_user_addr &&
433                     qva <  reg->guest_user_addr + reg->size) {
434                         return qva - reg->guest_user_addr +
435                                reg->host_user_addr;
436                 }
437         }
438
439         return 0;
440 }
441
442
443 /*
444  * Converts ring address to Vhost virtual address.
445  * If IOMMU is enabled, the ring address is a guest IO virtual address,
446  * else it is a QEMU virtual address.
447  */
448 static uint64_t
449 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
450                 uint64_t ra, uint64_t size)
451 {
452         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
453                 uint64_t vva;
454
455                 vva = vhost_user_iotlb_cache_find(vq, ra,
456                                         &size, VHOST_ACCESS_RW);
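                /* On a cache miss, request an IOTLB update from the master;
                 * 0 is returned and the translation is retried once the
                 * update arrives.
                 */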
457                 if (!vva)
458                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
459
460                 return vva;
461         }
462
463         return qva_to_vva(dev, ra);
464 }
465
466 static struct virtio_net *
467 translate_ring_addresses(struct virtio_net *dev, int vq_index)
468 {
469         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
470         struct vhost_vring_addr *addr = &vq->ring_addrs;
471
472         /* The addresses are converted from QEMU virtual to Vhost virtual. */
473         if (vq->desc && vq->avail && vq->used)
474                 return dev;
475
476         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
477                         vq, addr->desc_user_addr, sizeof(struct vring_desc));
478         if (vq->desc == 0) {
479                 RTE_LOG(DEBUG, VHOST_CONFIG,
480                         "(%d) failed to find desc ring address.\n",
481                         dev->vid);
482                 return dev;
483         }
484
485         dev = numa_realloc(dev, vq_index);
486         vq = dev->virtqueue[vq_index];
487         addr = &vq->ring_addrs;
488
489         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
490                         vq, addr->avail_user_addr, sizeof(struct vring_avail));
491         if (vq->avail == 0) {
492                 RTE_LOG(DEBUG, VHOST_CONFIG,
493                         "(%d) failed to find avail ring address.\n",
494                         dev->vid);
495                 return dev;
496         }
497
498         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
499                         vq, addr->used_user_addr, sizeof(struct vring_used));
500         if (vq->used == 0) {
501                 RTE_LOG(DEBUG, VHOST_CONFIG,
502                         "(%d) failed to find used ring address.\n",
503                         dev->vid);
504                 return dev;
505         }
506
507         if (vq->last_used_idx != vq->used->idx) {
508                 RTE_LOG(WARNING, VHOST_CONFIG,
509                         "last_used_idx (%u) and vq->used->idx (%u) mismatches; "
510                         "some packets maybe resent for Tx and dropped for Rx\n",
511                         vq->last_used_idx, vq->used->idx);
512                 vq->last_used_idx  = vq->used->idx;
513                 vq->last_avail_idx = vq->used->idx;
514         }
515
516         vq->log_guest_addr = addr->log_guest_addr;
517
518         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
519                         dev->vid, vq->desc);
520         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
521                         dev->vid, vq->avail);
522         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
523                         dev->vid, vq->used);
524         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
525                         dev->vid, vq->log_guest_addr);
526
527         return dev;
528 }
529
530 /*
531  * The virtio device sends us the desc, used and avail ring addresses.
532  * This function then converts these to our address space.
533  */
534 static int
535 vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
536 {
537         struct vhost_virtqueue *vq;
538         struct vhost_vring_addr *addr = &msg->payload.addr;
539         struct virtio_net *dev = *pdev;
540
541         if (dev->mem == NULL)
542                 return -1;
543
544         /* addr->index refers to the queue index: the txq is 1, the rxq is 0. */
545         vq = dev->virtqueue[msg->payload.addr.index];
546
547         /*
548          * Ring addresses should not be interpreted as long as the ring has not
549          * been started and enabled.
550          */
551         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
552
553         vring_invalidate(dev, vq);
554
555         if (vq->enabled && (dev->features &
556                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
557                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
558                 if (!dev)
559                         return -1;
560
561                 *pdev = dev;
562         }
563
564         return 0;
565 }
566
567 /*
568  * The virtio device sends us the available ring last used index.
569  */
570 static int
571 vhost_user_set_vring_base(struct virtio_net *dev,
572                           VhostUserMsg *msg)
573 {
574         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
575                         msg->payload.state.num;
576         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
577                         msg->payload.state.num;
578
579         return 0;
580 }
581
582 static int
583 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
584                    uint64_t host_phys_addr, uint64_t size)
585 {
586         struct guest_page *page, *last_page;
587
588         if (dev->nr_guest_pages == dev->max_guest_pages) {
589                 dev->max_guest_pages *= 2;
590                 dev->guest_pages = realloc(dev->guest_pages,
591                                         dev->max_guest_pages * sizeof(*page));
592                 if (!dev->guest_pages) {
593                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
594                         return -1;
595                 }
596         }
597
598         if (dev->nr_guest_pages > 0) {
599                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
600                 /* merge if the two pages are physically contiguous */
601                 if (host_phys_addr == last_page->host_phys_addr +
602                                       last_page->size) {
603                         last_page->size += size;
604                         return 0;
605                 }
606         }
607
608         page = &dev->guest_pages[dev->nr_guest_pages++];
609         page->guest_phys_addr = guest_phys_addr;
610         page->host_phys_addr  = host_phys_addr;
611         page->size = size;
612
613         return 0;
614 }
615
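/*
 * Walk a guest memory region page by page and record the guest-physical
 * to host-physical mapping of each page, merging entries whose host
 * physical addresses are contiguous.
 */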
616 static int
617 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
618                 uint64_t page_size)
619 {
620         uint64_t reg_size = reg->size;
621         uint64_t host_user_addr  = reg->host_user_addr;
622         uint64_t guest_phys_addr = reg->guest_phys_addr;
623         uint64_t host_phys_addr;
624         uint64_t size;
625
626         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
627         size = page_size - (guest_phys_addr & (page_size - 1));
628         size = RTE_MIN(size, reg_size);
629
630         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
631                 return -1;
632
633         host_user_addr  += size;
634         guest_phys_addr += size;
635         reg_size -= size;
636
637         while (reg_size > 0) {
638                 size = RTE_MIN(reg_size, page_size);
639                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
640                                                   host_user_addr);
641                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
642                                 size) < 0)
643                         return -1;
644
645                 host_user_addr  += size;
646                 guest_phys_addr += size;
647                 reg_size -= size;
648         }
649
650         return 0;
651 }
652
653 #ifdef RTE_LIBRTE_VHOST_DEBUG
654 /* TODO: enable it only in debug mode? */
655 static void
656 dump_guest_pages(struct virtio_net *dev)
657 {
658         uint32_t i;
659         struct guest_page *page;
660
661         for (i = 0; i < dev->nr_guest_pages; i++) {
662                 page = &dev->guest_pages[i];
663
664                 RTE_LOG(INFO, VHOST_CONFIG,
665                         "guest physical page region %u\n"
666                         "\t guest_phys_addr: %" PRIx64 "\n"
667                         "\t host_phys_addr : %" PRIx64 "\n"
668                         "\t size           : %" PRIx64 "\n",
669                         i,
670                         page->guest_phys_addr,
671                         page->host_phys_addr,
672                         page->size);
673         }
674 }
675 #else
676 #define dump_guest_pages(dev)
677 #endif
678
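/*
 * Return true if the memory layout announced by the master differs from
 * the one currently mapped, in which case all regions must be remapped.
 */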
679 static bool
680 vhost_memory_changed(struct VhostUserMemory *new,
681                      struct rte_vhost_memory *old)
682 {
683         uint32_t i;
684
685         if (new->nregions != old->nregions)
686                 return true;
687
688         for (i = 0; i < new->nregions; ++i) {
689                 VhostUserMemoryRegion *new_r = &new->regions[i];
690                 struct rte_vhost_mem_region *old_r = &old->regions[i];
691
692                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
693                         return true;
694                 if (new_r->memory_size != old_r->size)
695                         return true;
696                 if (new_r->userspace_addr != old_r->guest_user_addr)
697                         return true;
698         }
699
700         return false;
701 }
702
703 static int
704 vhost_user_set_mem_table(struct virtio_net *dev, struct VhostUserMsg *pmsg)
705 {
706         struct VhostUserMemory memory = pmsg->payload.memory;
707         struct rte_vhost_mem_region *reg;
708         void *mmap_addr;
709         uint64_t mmap_size;
710         uint64_t mmap_offset;
711         uint64_t alignment;
712         uint32_t i;
713         int populate;
714         int fd;
715
716         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
717                 RTE_LOG(ERR, VHOST_CONFIG,
718                         "too many memory regions (%u)\n", memory.nregions);
719                 return -1;
720         }
721
722         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
723                 RTE_LOG(INFO, VHOST_CONFIG,
724                         "(%d) memory regions not changed\n", dev->vid);
725
726                 for (i = 0; i < memory.nregions; i++)
727                         close(pmsg->fds[i]);
728
729                 return 0;
730         }
731
732         if (dev->mem) {
733                 free_mem_region(dev);
734                 rte_free(dev->mem);
735                 dev->mem = NULL;
736         }
737
738         dev->nr_guest_pages = 0;
739         if (!dev->guest_pages) {
740                 dev->max_guest_pages = 8;
741                 dev->guest_pages = malloc(dev->max_guest_pages *
742                                                 sizeof(struct guest_page));
743                 if (dev->guest_pages == NULL) {
744                         RTE_LOG(ERR, VHOST_CONFIG,
745                                 "(%d) failed to allocate memory "
746                                 "for dev->guest_pages\n",
747                                 dev->vid);
748                         return -1;
749                 }
750         }
751
752         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
753                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
754         if (dev->mem == NULL) {
755                 RTE_LOG(ERR, VHOST_CONFIG,
756                         "(%d) failed to allocate memory for dev->mem\n",
757                         dev->vid);
758                 return -1;
759         }
760         dev->mem->nregions = memory.nregions;
761
762         for (i = 0; i < memory.nregions; i++) {
763                 fd  = pmsg->fds[i];
764                 reg = &dev->mem->regions[i];
765
766                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
767                 reg->guest_user_addr = memory.regions[i].userspace_addr;
768                 reg->size            = memory.regions[i].memory_size;
769                 reg->fd              = fd;
770
771                 mmap_offset = memory.regions[i].mmap_offset;
772
773                 /* Check that memory_size + mmap_offset does not overflow (in uint64_t arithmetic, mmap_offset >= -reg->size means the sum would wrap) */
774                 if (mmap_offset >= -reg->size) {
775                         RTE_LOG(ERR, VHOST_CONFIG,
776                                 "mmap_offset (%#"PRIx64") and memory_size "
777                                 "(%#"PRIx64") overflow\n",
778                                 mmap_offset, reg->size);
779                         goto err_mmap;
780                 }
781
782                 mmap_size = reg->size + mmap_offset;
783
784                 /* On older long-term Linux kernels (e.g. 2.6.32 and 3.2.72),
785                  * mmap() without the MAP_ANONYMOUS flag must be called with a
786                  * length argument aligned to the hugepage size, or it will
787                  * fail with EINVAL.
788                  *
789                  * To avoid that failure, make sure here in the caller that
790                  * the length stays aligned.
791                  */
792                 alignment = get_blk_size(fd);
793                 if (alignment == (uint64_t)-1) {
794                         RTE_LOG(ERR, VHOST_CONFIG,
795                                 "couldn't get hugepage size through fstat\n");
796                         goto err_mmap;
797                 }
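                /* Round the mapping length up to a multiple of the hugepage
                 * size, e.g. a 5 MB region backed by 2 MB hugepages is
                 * mapped as 6 MB.
                 */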
798                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
799
800                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
801                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
802                                  MAP_SHARED | populate, fd, 0);
803
804                 if (mmap_addr == MAP_FAILED) {
805                         RTE_LOG(ERR, VHOST_CONFIG,
806                                 "mmap region %u failed.\n", i);
807                         goto err_mmap;
808                 }
809
810                 reg->mmap_addr = mmap_addr;
811                 reg->mmap_size = mmap_size;
812                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
813                                       mmap_offset;
814
815                 if (dev->dequeue_zero_copy)
816                         if (add_guest_pages(dev, reg, alignment) < 0) {
817                                 RTE_LOG(ERR, VHOST_CONFIG,
818                                         "adding guest pages to region %u failed.\n",
819                                         i);
820                                 goto err_mmap;
821                         }
822
823                 RTE_LOG(INFO, VHOST_CONFIG,
824                         "guest memory region %u, size: 0x%" PRIx64 "\n"
825                         "\t guest physical addr: 0x%" PRIx64 "\n"
826                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
827                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
828                         "\t mmap addr : 0x%" PRIx64 "\n"
829                         "\t mmap size : 0x%" PRIx64 "\n"
830                         "\t mmap align: 0x%" PRIx64 "\n"
831                         "\t mmap off  : 0x%" PRIx64 "\n",
832                         i, reg->size,
833                         reg->guest_phys_addr,
834                         reg->guest_user_addr,
835                         reg->host_user_addr,
836                         (uint64_t)(uintptr_t)mmap_addr,
837                         mmap_size,
838                         alignment,
839                         mmap_offset);
840         }
841
842         dump_guest_pages(dev);
843
844         return 0;
845
846 err_mmap:
847         free_mem_region(dev);
848         rte_free(dev->mem);
849         dev->mem = NULL;
850         return -1;
851 }
852
853 static int
854 vq_is_ready(struct vhost_virtqueue *vq)
855 {
856         return vq && vq->desc && vq->avail && vq->used &&
857                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
858                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
859 }
860
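/*
 * A device is considered ready once at least one vring exists and every
 * allocated vring has its desc/avail/used rings mapped and has received
 * its kick and call file descriptors.
 */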
861 static int
862 virtio_is_ready(struct virtio_net *dev)
863 {
864         struct vhost_virtqueue *vq;
865         uint32_t i;
866
867         if (dev->nr_vring == 0)
868                 return 0;
869
870         for (i = 0; i < dev->nr_vring; i++) {
871                 vq = dev->virtqueue[i];
872
873                 if (!vq_is_ready(vq))
874                         return 0;
875         }
876
877         RTE_LOG(INFO, VHOST_CONFIG,
878                 "virtio is now ready for processing.\n");
879         return 1;
880 }
881
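/*
 * Store the eventfd the master passed for guest notification (the "call"
 * fd) of the given vring, closing any previously installed one.
 */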
882 static void
883 vhost_user_set_vring_call(struct virtio_net *dev, struct VhostUserMsg *pmsg)
884 {
885         struct vhost_vring_file file;
886         struct vhost_virtqueue *vq;
887
888         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
889         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
890                 file.fd = VIRTIO_INVALID_EVENTFD;
891         else
892                 file.fd = pmsg->fds[0];
893         RTE_LOG(INFO, VHOST_CONFIG,
894                 "vring call idx:%d file:%d\n", file.index, file.fd);
895
896         vq = dev->virtqueue[file.index];
897         if (vq->callfd >= 0)
898                 close(vq->callfd);
899
900         vq->callfd = file.fd;
901 }
902
903 static void
904 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
905 {
906         struct vhost_vring_file file;
907         struct vhost_virtqueue *vq;
908         struct virtio_net *dev = *pdev;
909
910         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
911         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
912                 file.fd = VIRTIO_INVALID_EVENTFD;
913         else
914                 file.fd = pmsg->fds[0];
915         RTE_LOG(INFO, VHOST_CONFIG,
916                 "vring kick idx:%d file:%d\n", file.index, file.fd);
917
918         /* Interpret ring addresses only when ring is started. */
919         dev = translate_ring_addresses(dev, file.index);
920         if (!dev)
921                 return;
922
923         *pdev = dev;
924
925         vq = dev->virtqueue[file.index];
926
927         /*
928          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
929          * the ring starts already enabled. Otherwise, it is enabled via
930          * the SET_VRING_ENABLE message.
931          */
932         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
933                 vq->enabled = 1;
934
935         if (vq->kickfd >= 0)
936                 close(vq->kickfd);
937         vq->kickfd = file.fd;
938 }
939
940 static void
941 free_zmbufs(struct vhost_virtqueue *vq)
942 {
943         struct zcopy_mbuf *zmbuf, *next;
944
945         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
946              zmbuf != NULL; zmbuf = next) {
947                 next = TAILQ_NEXT(zmbuf, next);
948
949                 rte_pktmbuf_free(zmbuf->mbuf);
950                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
951         }
952
953         rte_free(vq->zmbufs);
954 }
955
956 /*
957  * When virtio is stopped, QEMU sends us the GET_VRING_BASE message.
958  */
959 static int
960 vhost_user_get_vring_base(struct virtio_net *dev,
961                           VhostUserMsg *msg)
962 {
963         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
964         struct rte_vdpa_device *vdpa_dev;
965         int did = -1;
966
967         /* We have to stop the queue (virtio) if it is running. */
968         if (dev->flags & VIRTIO_DEV_RUNNING) {
969                 did = dev->vdpa_dev_id;
970                 vdpa_dev = rte_vdpa_get_device(did);
971                 if (vdpa_dev && vdpa_dev->ops->dev_close)
972                         vdpa_dev->ops->dev_close(dev->vid);
973                 dev->flags &= ~VIRTIO_DEV_RUNNING;
974                 dev->notify_ops->destroy_device(dev->vid);
975         }
976
977         dev->flags &= ~VIRTIO_DEV_READY;
978         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
979
980         /* Here we are safe to get the last avail index */
981         msg->payload.state.num = vq->last_avail_idx;
982
983         RTE_LOG(INFO, VHOST_CONFIG,
984                 "vring base idx:%d file:%d\n", msg->payload.state.index,
985                 msg->payload.state.num);
986         /*
987          * Based on the current QEMU vhost-user implementation, this message
988          * is only ever sent from vhost_vring_stop.
989          * TODO: clean up the vring; it isn't usable from this point on.
990          */
991         if (vq->kickfd >= 0)
992                 close(vq->kickfd);
993
994         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
995
996         if (vq->callfd >= 0)
997                 close(vq->callfd);
998
999         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1000
1001         if (dev->dequeue_zero_copy)
1002                 free_zmbufs(vq);
1003         rte_free(vq->shadow_used_ring);
1004         vq->shadow_used_ring = NULL;
1005
1006         rte_free(vq->batch_copy_elems);
1007         vq->batch_copy_elems = NULL;
1008
1009         return 0;
1010 }
1011
1012 /*
1013  * When the virtio queues are ready to work, QEMU sends us a message
1014  * to enable the virtio queue pair.
1015  */
1016 static int
1017 vhost_user_set_vring_enable(struct virtio_net *dev,
1018                             VhostUserMsg *msg)
1019 {
1020         int enable = (int)msg->payload.state.num;
1021         int index = (int)msg->payload.state.index;
1022         struct rte_vdpa_device *vdpa_dev;
1023         int did = -1;
1024
1025         RTE_LOG(INFO, VHOST_CONFIG,
1026                 "set queue enable: %d to qp idx: %d\n",
1027                 enable, index);
1028
1029         did = dev->vdpa_dev_id;
1030         vdpa_dev = rte_vdpa_get_device(did);
1031         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1032                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1033
1034         if (dev->notify_ops->vring_state_changed)
1035                 dev->notify_ops->vring_state_changed(dev->vid,
1036                                 index, enable);
1037
1038         dev->virtqueue[index]->enabled = enable;
1039
1040         return 0;
1041 }
1042
1043 static void
1044 vhost_user_get_protocol_features(struct virtio_net *dev,
1045                                  struct VhostUserMsg *msg)
1046 {
1047         uint64_t features, protocol_features;
1048
1049         rte_vhost_driver_get_features(dev->ifname, &features);
1050         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1051
1052         /*
1053          * For now, the REPLY_ACK protocol feature is only mandatory
1054          * for the IOMMU feature. If IOMMU is explicitly disabled by the
1055          * application, also disable REPLY_ACK to cope with older buggy
1056          * QEMU versions (from v2.7.0 to v2.9.0).
1057          */
1058         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1059                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1060
1061         msg->payload.u64 = protocol_features;
1062         msg->size = sizeof(msg->payload.u64);
1063 }
1064
1065 static void
1066 vhost_user_set_protocol_features(struct virtio_net *dev,
1067                                  uint64_t protocol_features)
1068 {
1069         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES)
1070                 return;
1071
1072         dev->protocol_features = protocol_features;
1073 }
1074
1075 static int
1076 vhost_user_set_log_base(struct virtio_net *dev, struct VhostUserMsg *msg)
1077 {
1078         int fd = msg->fds[0];
1079         uint64_t size, off;
1080         void *addr;
1081
1082         if (fd < 0) {
1083                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1084                 return -1;
1085         }
1086
1087         if (msg->size != sizeof(VhostUserLog)) {
1088                 RTE_LOG(ERR, VHOST_CONFIG,
1089                         "invalid log base msg size: %"PRId32" != %d\n",
1090                         msg->size, (int)sizeof(VhostUserLog));
1091                 return -1;
1092         }
1093
1094         size = msg->payload.log.mmap_size;
1095         off  = msg->payload.log.mmap_offset;
1096
1097         /* Don't allow mmap_offset to point outside the mmap region */
1098         if (off > size) {
1099                 RTE_LOG(ERR, VHOST_CONFIG,
1100                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1101                         off, size);
1102                 return -1;
1103         }
1104
1105         RTE_LOG(INFO, VHOST_CONFIG,
1106                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1107                 size, off);
1108
1109         /*
1110          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1111          * fail when the offset is not page-size aligned.
1112          */
1113         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1114         close(fd);
1115         if (addr == MAP_FAILED) {
1116                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1117                 return -1;
1118         }
1119
1120         /*
1121          * Free previously mapped log memory in case VHOST_USER_SET_LOG_BASE
1122          * is received more than once.
1123          */
1124         if (dev->log_addr) {
1125                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1126         }
1127         dev->log_addr = (uint64_t)(uintptr_t)addr;
1128         dev->log_base = dev->log_addr + off;
1129         dev->log_size = size;
1130
1131         return 0;
1132 }
1133
1134 /*
1135  * A RARP packet is constructed and broadcast to notify switches about
1136  * the new location of the migrated VM, so that packets from outside
1137  * are not lost after migration.
1138  *
1139  * However, we don't actually "send" a RARP packet here; instead, we set
1140  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1141  */
1142 static int
1143 vhost_user_send_rarp(struct virtio_net *dev, struct VhostUserMsg *msg)
1144 {
1145         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1146         struct rte_vdpa_device *vdpa_dev;
1147         int did = -1;
1148
1149         RTE_LOG(DEBUG, VHOST_CONFIG,
1150                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1151                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1152         memcpy(dev->mac.addr_bytes, mac, 6);
1153
1154         /*
1155          * Set the flag to inject a RARP broadcast packet at
1156          * rte_vhost_dequeue_burst().
1157          *
1158          * rte_smp_wmb() is for making sure the mac is copied
1159          * before the flag is set.
1160          */
1161         rte_smp_wmb();
1162         rte_atomic16_set(&dev->broadcast_rarp, 1);
1163         did = dev->vdpa_dev_id;
1164         vdpa_dev = rte_vdpa_get_device(did);
1165         if (vdpa_dev && vdpa_dev->ops->migration_done)
1166                 vdpa_dev->ops->migration_done(dev->vid);
1167
1168         return 0;
1169 }
1170
1171 static int
1172 vhost_user_net_set_mtu(struct virtio_net *dev, struct VhostUserMsg *msg)
1173 {
1174         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1175                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1176                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1177                                 msg->payload.u64);
1178
1179                 return -1;
1180         }
1181
1182         dev->mtu = msg->payload.u64;
1183
1184         return 0;
1185 }
1186
1187 static int
1188 vhost_user_set_req_fd(struct virtio_net *dev, struct VhostUserMsg *msg)
1189 {
1190         int fd = msg->fds[0];
1191
1192         if (fd < 0) {
1193                 RTE_LOG(ERR, VHOST_CONFIG,
1194                                 "Invalid file descriptor for slave channel (%d)\n",
1195                                 fd);
1196                 return -1;
1197         }
1198
1199         dev->slave_req_fd = fd;
1200
1201         return 0;
1202 }
1203
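/*
 * Return 1 if the IOTLB update covers one of the ring addresses of this
 * virtqueue, meaning the ring may now be translatable and should be
 * re-translated.
 */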
1204 static int
1205 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1206 {
1207         struct vhost_vring_addr *ra;
1208         uint64_t start, end;
1209
1210         start = imsg->iova;
1211         end = start + imsg->size;
1212
1213         ra = &vq->ring_addrs;
1214         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1215                 return 1;
1216         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1217                 return 1;
1218         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1219                 return 1;
1220
1221         return 0;
1222 }
1223
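/*
 * Return 1 if the invalidated range overlaps the memory currently used by
 * the desc, avail or used ring of this virtqueue.
 */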
1224 static int
1225 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1226                                 struct vhost_iotlb_msg *imsg)
1227 {
1228         uint64_t istart, iend, vstart, vend;
1229
1230         istart = imsg->iova;
1231         iend = istart + imsg->size - 1;
1232
1233         vstart = (uintptr_t)vq->desc;
1234         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1235         if (vstart <= iend && istart <= vend)
1236                 return 1;
1237
1238         vstart = (uintptr_t)vq->avail;
1239         vend = vstart + sizeof(struct vring_avail);
1240         vend += sizeof(uint16_t) * vq->size - 1;
1241         if (vstart <= iend && istart <= vend)
1242                 return 1;
1243
1244         vstart = (uintptr_t)vq->used;
1245         vend = vstart + sizeof(struct vring_used);
1246         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1247         if (vstart <= iend && istart <= vend)
1248                 return 1;
1249
1250         return 0;
1251 }
1252
1253 static int
1254 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1255 {
1256         struct virtio_net *dev = *pdev;
1257         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1258         uint16_t i;
1259         uint64_t vva;
1260
1261         switch (imsg->type) {
1262         case VHOST_IOTLB_UPDATE:
1263                 vva = qva_to_vva(dev, imsg->uaddr);
1264                 if (!vva)
1265                         return -1;
1266
1267                 for (i = 0; i < dev->nr_vring; i++) {
1268                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1269
1270                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1271                                         imsg->size, imsg->perm);
1272
1273                         if (is_vring_iotlb_update(vq, imsg))
1274                                 *pdev = dev = translate_ring_addresses(dev, i);
1275                 }
1276                 break;
1277         case VHOST_IOTLB_INVALIDATE:
1278                 for (i = 0; i < dev->nr_vring; i++) {
1279                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1280
1281                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1282                                         imsg->size);
1283
1284                         if (is_vring_iotlb_invalidate(vq, imsg))
1285                                 vring_invalidate(dev, vq);
1286                 }
1287                 break;
1288         default:
1289                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1290                                 imsg->type);
1291                 return -1;
1292         }
1293
1294         return 0;
1295 }
1296
1297 /* Return the number of bytes read on success, or a negative value on failure. */
1298 static int
1299 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1300 {
1301         int ret;
1302
1303         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1304                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1305         if (ret <= 0)
1306                 return ret;
1307
1308         if (msg && msg->size) {
1309                 if (msg->size > sizeof(msg->payload)) {
1310                         RTE_LOG(ERR, VHOST_CONFIG,
1311                                 "invalid msg size: %d\n", msg->size);
1312                         return -1;
1313                 }
1314                 ret = read(sockfd, &msg->payload, msg->size);
1315                 if (ret <= 0)
1316                         return ret;
1317                 if (ret != (int)msg->size) {
1318                         RTE_LOG(ERR, VHOST_CONFIG,
1319                                 "read control message failed\n");
1320                         return -1;
1321                 }
1322         }
1323
1324         return ret;
1325 }
1326
1327 static int
1328 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1329 {
1330         if (!msg)
1331                 return 0;
1332
1333         return send_fd_message(sockfd, (char *)msg,
1334                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1335 }
1336
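/*
 * Send a reply to the master: stamp the version and REPLY flags, clear
 * NEED_REPLY, and send the message without any attached fds.
 */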
1337 static int
1338 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1339 {
1340         if (!msg)
1341                 return 0;
1342
1343         msg->flags &= ~VHOST_USER_VERSION_MASK;
1344         msg->flags &= ~VHOST_USER_NEED_REPLY;
1345         msg->flags |= VHOST_USER_VERSION;
1346         msg->flags |= VHOST_USER_REPLY_MASK;
1347
1348         return send_vhost_message(sockfd, msg, NULL, 0);
1349 }
1350
1351 /*
1352  * Allocate a queue pair if it hasn't been allocated yet
1353  */
1354 static int
1355 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
1356 {
1357         uint16_t vring_idx;
1358
1359         switch (msg->request.master) {
1360         case VHOST_USER_SET_VRING_KICK:
1361         case VHOST_USER_SET_VRING_CALL:
1362         case VHOST_USER_SET_VRING_ERR:
1363                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1364                 break;
1365         case VHOST_USER_SET_VRING_NUM:
1366         case VHOST_USER_SET_VRING_BASE:
1367         case VHOST_USER_SET_VRING_ENABLE:
1368                 vring_idx = msg->payload.state.index;
1369                 break;
1370         case VHOST_USER_SET_VRING_ADDR:
1371                 vring_idx = msg->payload.addr.index;
1372                 break;
1373         default:
1374                 return 0;
1375         }
1376
1377         if (vring_idx >= VHOST_MAX_VRING) {
1378                 RTE_LOG(ERR, VHOST_CONFIG,
1379                         "invalid vring index: %u\n", vring_idx);
1380                 return -1;
1381         }
1382
1383         if (dev->virtqueue[vring_idx])
1384                 return 0;
1385
1386         return alloc_vring_queue(dev, vring_idx);
1387 }
1388
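/*
 * Take the access_lock of every allocated virtqueue so that message
 * handlers which modify vring state do not race with the datapath
 * (enqueue/dequeue) threads.
 */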
1389 static void
1390 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1391 {
1392         unsigned int i = 0;
1393         unsigned int vq_num = 0;
1394
1395         while (vq_num < dev->nr_vring) {
1396                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1397
1398                 if (vq) {
1399                         rte_spinlock_lock(&vq->access_lock);
1400                         vq_num++;
1401                 }
1402                 i++;
1403         }
1404 }
1405
1406 static void
1407 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1408 {
1409         unsigned int i = 0;
1410         unsigned int vq_num = 0;
1411
1412         while (vq_num < dev->nr_vring) {
1413                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1414
1415                 if (vq) {
1416                         rte_spinlock_unlock(&vq->access_lock);
1417                         vq_num++;
1418                 }
1419                 i++;
1420         }
1421 }
1422
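/*
 * Main vhost-user message dispatcher: read one message from the socket,
 * allocate the virtqueue it refers to if needed, take the queue locks
 * where required, then invoke the per-request handler and send a reply
 * when one is expected.
 */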
1423 int
1424 vhost_user_msg_handler(int vid, int fd)
1425 {
1426         struct virtio_net *dev;
1427         struct VhostUserMsg msg;
1428         struct rte_vdpa_device *vdpa_dev;
1429         int did = -1;
1430         int ret;
1431         int unlock_required = 0;
1432         uint32_t skip_master = 0;
1433
1434         dev = get_device(vid);
1435         if (dev == NULL)
1436                 return -1;
1437
1438         if (!dev->notify_ops) {
1439                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1440                 if (!dev->notify_ops) {
1441                         RTE_LOG(ERR, VHOST_CONFIG,
1442                                 "failed to get callback ops for driver %s\n",
1443                                 dev->ifname);
1444                         return -1;
1445                 }
1446         }
1447
1448         ret = read_vhost_message(fd, &msg);
1449         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1450                 if (ret < 0)
1451                         RTE_LOG(ERR, VHOST_CONFIG,
1452                                 "vhost read message failed\n");
1453                 else if (ret == 0)
1454                         RTE_LOG(INFO, VHOST_CONFIG,
1455                                 "vhost peer closed\n");
1456                 else
1457                         RTE_LOG(ERR, VHOST_CONFIG,
1458                                 "vhost read incorrect message\n");
1459
1460                 return -1;
1461         }
1462
1463         ret = 0;
1464         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1465                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1466                         vhost_message_str[msg.request.master]);
1467         else
1468                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1469                         vhost_message_str[msg.request.master]);
1470
1471         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1472         if (ret < 0) {
1473                 RTE_LOG(ERR, VHOST_CONFIG,
1474                         "failed to alloc queue\n");
1475                 return -1;
1476         }
1477
1478         /*
1479          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1480          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1481          * and the device is destroyed. destroy_device waits for the queues
1482          * to become inactive, so it is safe. Otherwise, taking the
1483          * access_lock would cause a deadlock.
1484          */
1485         switch (msg.request.master) {
1486         case VHOST_USER_SET_FEATURES:
1487         case VHOST_USER_SET_PROTOCOL_FEATURES:
1488         case VHOST_USER_SET_OWNER:
1489         case VHOST_USER_SET_MEM_TABLE:
1490         case VHOST_USER_SET_LOG_BASE:
1491         case VHOST_USER_SET_LOG_FD:
1492         case VHOST_USER_SET_VRING_NUM:
1493         case VHOST_USER_SET_VRING_ADDR:
1494         case VHOST_USER_SET_VRING_BASE:
1495         case VHOST_USER_SET_VRING_KICK:
1496         case VHOST_USER_SET_VRING_CALL:
1497         case VHOST_USER_SET_VRING_ERR:
1498         case VHOST_USER_SET_VRING_ENABLE:
1499         case VHOST_USER_SEND_RARP:
1500         case VHOST_USER_NET_SET_MTU:
1501         case VHOST_USER_SET_SLAVE_REQ_FD:
1502                 vhost_user_lock_all_queue_pairs(dev);
1503                 unlock_required = 1;
1504                 break;
1505         default:
1506                 break;
1507
1508         }
1509
1510         if (dev->extern_ops.pre_msg_handle) {
1511                 uint32_t need_reply;
1512
1513                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1514                                 (void *)&msg, &need_reply, &skip_master);
1515                 if (ret < 0)
1516                         goto skip_to_reply;
1517
1518                 if (need_reply)
1519                         send_vhost_reply(fd, &msg);
1520
1521                 if (skip_master)
1522                         goto skip_to_post_handle;
1523         }
1524
1525         switch (msg.request.master) {
1526         case VHOST_USER_GET_FEATURES:
1527                 msg.payload.u64 = vhost_user_get_features(dev);
1528                 msg.size = sizeof(msg.payload.u64);
1529                 send_vhost_reply(fd, &msg);
1530                 break;
1531         case VHOST_USER_SET_FEATURES:
1532                 ret = vhost_user_set_features(dev, msg.payload.u64);
1533                 if (ret)
1534                         return -1;
1535                 break;
1536
1537         case VHOST_USER_GET_PROTOCOL_FEATURES:
1538                 vhost_user_get_protocol_features(dev, &msg);
1539                 send_vhost_reply(fd, &msg);
1540                 break;
1541         case VHOST_USER_SET_PROTOCOL_FEATURES:
1542                 vhost_user_set_protocol_features(dev, msg.payload.u64);
1543                 break;
1544
1545         case VHOST_USER_SET_OWNER:
1546                 vhost_user_set_owner();
1547                 break;
1548         case VHOST_USER_RESET_OWNER:
1549                 vhost_user_reset_owner(dev);
1550                 break;
1551
1552         case VHOST_USER_SET_MEM_TABLE:
1553                 ret = vhost_user_set_mem_table(dev, &msg);
1554                 break;
1555
1556         case VHOST_USER_SET_LOG_BASE:
1557                 vhost_user_set_log_base(dev, &msg);
1558
1559                 /* it needs a reply */
1560                 msg.size = sizeof(msg.payload.u64);
1561                 send_vhost_reply(fd, &msg);
1562                 break;
1563         case VHOST_USER_SET_LOG_FD:
1564                 close(msg.fds[0]);
1565                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1566                 break;
1567
1568         case VHOST_USER_SET_VRING_NUM:
1569                 vhost_user_set_vring_num(dev, &msg);
1570                 break;
1571         case VHOST_USER_SET_VRING_ADDR:
1572                 vhost_user_set_vring_addr(&dev, &msg);
1573                 break;
1574         case VHOST_USER_SET_VRING_BASE:
1575                 vhost_user_set_vring_base(dev, &msg);
1576                 break;
1577
1578         case VHOST_USER_GET_VRING_BASE:
1579                 vhost_user_get_vring_base(dev, &msg);
1580                 msg.size = sizeof(msg.payload.state);
1581                 send_vhost_reply(fd, &msg);
1582                 break;
1583
1584         case VHOST_USER_SET_VRING_KICK:
1585                 vhost_user_set_vring_kick(&dev, &msg);
1586                 break;
1587         case VHOST_USER_SET_VRING_CALL:
1588                 vhost_user_set_vring_call(dev, &msg);
1589                 break;
1590
1591         case VHOST_USER_SET_VRING_ERR:
1592                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1593                         close(msg.fds[0]);
1594                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1595                 break;
1596
1597         case VHOST_USER_GET_QUEUE_NUM:
1598                 msg.payload.u64 = (uint64_t)vhost_user_get_queue_num(dev);
1599                 msg.size = sizeof(msg.payload.u64);
1600                 send_vhost_reply(fd, &msg);
1601                 break;
1602
1603         case VHOST_USER_SET_VRING_ENABLE:
1604                 vhost_user_set_vring_enable(dev, &msg);
1605                 break;
1606         case VHOST_USER_SEND_RARP:
1607                 vhost_user_send_rarp(dev, &msg);
1608                 break;
1609
1610         case VHOST_USER_NET_SET_MTU:
1611                 ret = vhost_user_net_set_mtu(dev, &msg);
1612                 break;
1613
1614         case VHOST_USER_SET_SLAVE_REQ_FD:
1615                 ret = vhost_user_set_req_fd(dev, &msg);
1616                 break;
1617
1618         case VHOST_USER_IOTLB_MSG:
1619                 ret = vhost_user_iotlb_msg(&dev, &msg);
1620                 break;
1621
1622         default:
1623                 ret = -1;
1624                 break;
1625         }
1626
1627 skip_to_post_handle:
1628         if (dev->extern_ops.post_msg_handle) {
1629                 uint32_t need_reply;
1630
1631                 ret = (*dev->extern_ops.post_msg_handle)(
1632                                 dev->vid, (void *)&msg, &need_reply);
1633                 if (ret < 0)
1634                         goto skip_to_reply;
1635
1636                 if (need_reply)
1637                         send_vhost_reply(fd, &msg);
1638         }
1639
1640 skip_to_reply:
1641         if (unlock_required)
1642                 vhost_user_unlock_all_queue_pairs(dev);
1643
1644         if (msg.flags & VHOST_USER_NEED_REPLY) {
1645                 msg.payload.u64 = !!ret;
1646                 msg.size = sizeof(msg.payload.u64);
1647                 send_vhost_reply(fd, &msg);
1648         }
1649
1650         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1651                 dev->flags |= VIRTIO_DEV_READY;
1652
1653                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1654                         if (dev->dequeue_zero_copy) {
1655                                 RTE_LOG(INFO, VHOST_CONFIG,
1656                                                 "dequeue zero copy is enabled\n");
1657                         }
1658
1659                         if (dev->notify_ops->new_device(dev->vid) == 0)
1660                                 dev->flags |= VIRTIO_DEV_RUNNING;
1661                 }
1662         }
1663
1664         did = dev->vdpa_dev_id;
1665         vdpa_dev = rte_vdpa_get_device(did);
1666         if (vdpa_dev && virtio_is_ready(dev) &&
1667                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1668                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1669                 if (vdpa_dev->ops->dev_conf)
1670                         vdpa_dev->ops->dev_conf(dev->vid);
1671                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1672         }
1673
1674         return 0;
1675 }
1676
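/*
 * Ask the master to translate 'iova' by sending a VHOST_IOTLB_MISS
 * message over the slave request channel.
 */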
1677 int
1678 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1679 {
1680         int ret;
1681         struct VhostUserMsg msg = {
1682                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1683                 .flags = VHOST_USER_VERSION,
1684                 .size = sizeof(msg.payload.iotlb),
1685                 .payload.iotlb = {
1686                         .iova = iova,
1687                         .perm = perm,
1688                         .type = VHOST_IOTLB_MISS,
1689                 },
1690         };
1691
1692         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1693         if (ret < 0) {
1694                 RTE_LOG(ERR, VHOST_CONFIG,
1695                                 "Failed to send IOTLB miss message (%d)\n",
1696                                 ret);
1697                 return ret;
1698         }
1699
1700         return 0;
1701 }