vhost/crypto: add session message handler
[dpdk.git] / lib / librte_vhost / vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped will then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
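/* Return the block size of the file behind fd (the hugepage size for
 * hugetlbfs-backed memory), or (uint64_t)-1 if fstat() fails.
 */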
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
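/* Unmap all guest memory regions and close their backing file descriptors. */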
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
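/* Release the resources set up over the vhost-user connection: guest memory
 * table, guest page array, dirty log mapping and slave request fd.
 */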
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * This function just returns success at the moment, provided the
127  * device has been initialised.
128  */
129 static int
130 vhost_user_set_owner(void)
131 {
132         return 0;
133 }
134
135 static int
136 vhost_user_reset_owner(struct virtio_net *dev)
137 {
138         struct rte_vdpa_device *vdpa_dev;
139         int did = -1;
140
141         if (dev->flags & VIRTIO_DEV_RUNNING) {
142                 did = dev->vdpa_dev_id;
143                 vdpa_dev = rte_vdpa_get_device(did);
144                 if (vdpa_dev && vdpa_dev->ops->dev_close)
145                         vdpa_dev->ops->dev_close(dev->vid);
146                 dev->flags &= ~VIRTIO_DEV_RUNNING;
147                 dev->notify_ops->destroy_device(dev->vid);
148         }
149
150         cleanup_device(dev, 0);
151         reset_device(dev);
152         return 0;
153 }
154
155 /*
156  * The master requests the features that we support.
157  */
158 static uint64_t
159 vhost_user_get_features(struct virtio_net *dev)
160 {
161         uint64_t features = 0;
162
163         rte_vhost_driver_get_features(dev->ifname, &features);
164         return features;
165 }
166
167 /*
168  * The master requests the number of queues that we support.
169  */
170 static uint32_t
171 vhost_user_get_queue_num(struct virtio_net *dev)
172 {
173         uint32_t queue_num = 0;
174
175         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
176         return queue_num;
177 }
178
179 /*
180  * We receive the negotiated features supported by us and the virtio device.
181  * We receive the set of features negotiated between us and the virtio device.
182 static int
183 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
184 {
185         uint64_t vhost_features = 0;
186         struct rte_vdpa_device *vdpa_dev;
187         int did = -1;
188
189         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
190         if (features & ~vhost_features) {
191                 RTE_LOG(ERR, VHOST_CONFIG,
192                         "(%d) received invalid negotiated features.\n",
193                         dev->vid);
194                 return -1;
195         }
196
197         if (dev->flags & VIRTIO_DEV_RUNNING) {
198                 if (dev->features == features)
199                         return 0;
200
201                 /*
202                  * Error out if master tries to change features while device is
203                  * in running state. The exception being VHOST_F_LOG_ALL, which
204                  * is enabled when the live-migration starts.
205                  */
206                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
207                         RTE_LOG(ERR, VHOST_CONFIG,
208                                 "(%d) features changed while device is running.\n",
209                                 dev->vid);
210                         return -1;
211                 }
212
213                 if (dev->notify_ops->features_changed)
214                         dev->notify_ops->features_changed(dev->vid, features);
215         }
216
217         did = dev->vdpa_dev_id;
218         vdpa_dev = rte_vdpa_get_device(did);
219         if (vdpa_dev && vdpa_dev->ops->set_features)
220                 vdpa_dev->ops->set_features(dev->vid);
221
222         dev->features = features;
223         if (dev->features &
224                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
225                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
226         } else {
227                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
228         }
229         VHOST_LOG_DEBUG(VHOST_CONFIG,
230                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
231                 dev->vid,
232                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
233                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
234
235         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
236             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
237                 /*
238                  * Remove all but first queue pair if MQ hasn't been
239                  * negotiated. This is safe because the device is not
240                  * running at this stage.
241                  */
242                 while (dev->nr_vring > 2) {
243                         struct vhost_virtqueue *vq;
244
245                         vq = dev->virtqueue[--dev->nr_vring];
246                         if (!vq)
247                                 continue;
248
249                         dev->virtqueue[dev->nr_vring] = NULL;
250                         cleanup_vq(vq, 1);
251                         free_vq(vq);
252                 }
253         }
254
255         return 0;
256 }
257
258 /*
259  * The virtio device sends us the size of the descriptor ring.
260  */
261 static int
262 vhost_user_set_vring_num(struct virtio_net *dev,
263                          VhostUserMsg *msg)
264 {
265         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
266
267         vq->size = msg->payload.state.num;
268
269         /* VIRTIO 1.0, 2.4 Virtqueues says:
270          *
271          *   Queue Size value is always a power of 2. The maximum Queue Size
272          *   value is 32768.
273          */
274         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
275                 RTE_LOG(ERR, VHOST_CONFIG,
276                         "invalid virtqueue size %u\n", vq->size);
277                 return -1;
278         }
279
280         if (dev->dequeue_zero_copy) {
281                 vq->nr_zmbuf = 0;
282                 vq->last_zmbuf_idx = 0;
283                 vq->zmbuf_size = vq->size;
284                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
285                                          sizeof(struct zcopy_mbuf), 0);
286                 if (vq->zmbufs == NULL) {
287                         RTE_LOG(WARNING, VHOST_CONFIG,
288                                 "failed to allocate mem for zero copy; "
289                                 "zero copy is forcibly disabled\n");
290                         dev->dequeue_zero_copy = 0;
291                 }
292                 TAILQ_INIT(&vq->zmbuf_list);
293         }
294
295         vq->shadow_used_ring = rte_malloc(NULL,
296                                 vq->size * sizeof(struct vring_used_elem),
297                                 RTE_CACHE_LINE_SIZE);
298         if (!vq->shadow_used_ring) {
299                 RTE_LOG(ERR, VHOST_CONFIG,
300                         "failed to allocate memory for shadow used ring.\n");
301                 return -1;
302         }
303
304         vq->batch_copy_elems = rte_malloc(NULL,
305                                 vq->size * sizeof(struct batch_copy_elem),
306                                 RTE_CACHE_LINE_SIZE);
307         if (!vq->batch_copy_elems) {
308                 RTE_LOG(ERR, VHOST_CONFIG,
309                         "failed to allocate memory for batching copy.\n");
310                 return -1;
311         }
312
313         return 0;
314 }
315
316 /*
317  * Reallocate the virtio_net and vhost_virtqueue data structures so that they
318  * reside on the same NUMA node as the memory backing the vring descriptors.
319  */
320 #ifdef RTE_LIBRTE_VHOST_NUMA
321 static struct virtio_net*
322 numa_realloc(struct virtio_net *dev, int index)
323 {
324         int oldnode, newnode;
325         struct virtio_net *old_dev;
326         struct vhost_virtqueue *old_vq, *vq;
327         struct zcopy_mbuf *new_zmbuf;
328         struct vring_used_elem *new_shadow_used_ring;
329         struct batch_copy_elem *new_batch_copy_elems;
330         int ret;
331
332         old_dev = dev;
333         vq = old_vq = dev->virtqueue[index];
334
335         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
336                             MPOL_F_NODE | MPOL_F_ADDR);
337
338         /* check if we need to reallocate vq */
339         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
340                              MPOL_F_NODE | MPOL_F_ADDR);
341         if (ret) {
342                 RTE_LOG(ERR, VHOST_CONFIG,
343                         "Unable to get vq numa information.\n");
344                 return dev;
345         }
346         if (oldnode != newnode) {
347                 RTE_LOG(INFO, VHOST_CONFIG,
348                         "reallocate vq from %d to %d node\n", oldnode, newnode);
349                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
350                 if (!vq)
351                         return dev;
352
353                 memcpy(vq, old_vq, sizeof(*vq));
354                 TAILQ_INIT(&vq->zmbuf_list);
355
356                 new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
357                         sizeof(struct zcopy_mbuf), 0, newnode);
358                 if (new_zmbuf) {
359                         rte_free(vq->zmbufs);
360                         vq->zmbufs = new_zmbuf;
361                 }
362
363                 new_shadow_used_ring = rte_malloc_socket(NULL,
364                         vq->size * sizeof(struct vring_used_elem),
365                         RTE_CACHE_LINE_SIZE,
366                         newnode);
367                 if (new_shadow_used_ring) {
368                         rte_free(vq->shadow_used_ring);
369                         vq->shadow_used_ring = new_shadow_used_ring;
370                 }
371
372                 new_batch_copy_elems = rte_malloc_socket(NULL,
373                         vq->size * sizeof(struct batch_copy_elem),
374                         RTE_CACHE_LINE_SIZE,
375                         newnode);
376                 if (new_batch_copy_elems) {
377                         rte_free(vq->batch_copy_elems);
378                         vq->batch_copy_elems = new_batch_copy_elems;
379                 }
380
381                 rte_free(old_vq);
382         }
383
384         /* check if we need to reallocate dev */
385         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
386                             MPOL_F_NODE | MPOL_F_ADDR);
387         if (ret) {
388                 RTE_LOG(ERR, VHOST_CONFIG,
389                         "Unable to get dev numa information.\n");
390                 goto out;
391         }
392         if (oldnode != newnode) {
393                 RTE_LOG(INFO, VHOST_CONFIG,
394                         "reallocate dev from %d to %d node\n",
395                         oldnode, newnode);
396                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
397                 if (!dev) {
398                         dev = old_dev;
399                         goto out;
400                 }
401
402                 memcpy(dev, old_dev, sizeof(*dev));
403                 rte_free(old_dev);
404         }
405
406 out:
407         dev->virtqueue[index] = vq;
408         vhost_devices[dev->vid] = dev;
409
410         if (old_vq != vq)
411                 vhost_user_iotlb_init(dev, index);
412
413         return dev;
414 }
415 #else
416 static struct virtio_net*
417 numa_realloc(struct virtio_net *dev, int index __rte_unused)
418 {
419         return dev;
420 }
421 #endif
422
423 /* Converts QEMU virtual address to Vhost virtual address. */
424 static uint64_t
425 qva_to_vva(struct virtio_net *dev, uint64_t qva)
426 {
427         struct rte_vhost_mem_region *reg;
428         uint32_t i;
429
430         /* Find the region where the address lives. */
431         for (i = 0; i < dev->mem->nregions; i++) {
432                 reg = &dev->mem->regions[i];
433
434                 if (qva >= reg->guest_user_addr &&
435                     qva <  reg->guest_user_addr + reg->size) {
436                         return qva - reg->guest_user_addr +
437                                reg->host_user_addr;
438                 }
439         }
440
441         return 0;
442 }
443
444
445 /*
446  * Converts ring address to Vhost virtual address.
447  * If IOMMU is enabled, the ring address is a guest IO virtual address,
448  * else it is a QEMU virtual address.
449  */
450 static uint64_t
451 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
452                 uint64_t ra, uint64_t size)
453 {
454         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
455                 uint64_t vva;
456
457                 vva = vhost_user_iotlb_cache_find(vq, ra,
458                                         &size, VHOST_ACCESS_RW);
459                 if (!vva)
460                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
461
462                 return vva;
463         }
464
465         return qva_to_vva(dev, ra);
466 }
467
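/* Translate the desc, avail and used ring addresses of the given virtqueue
 * into host virtual addresses, reallocating the device and virtqueue on the
 * proper NUMA node if needed. Returns the (possibly reallocated) device.
 */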
468 static struct virtio_net *
469 translate_ring_addresses(struct virtio_net *dev, int vq_index)
470 {
471         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
472         struct vhost_vring_addr *addr = &vq->ring_addrs;
473
474         /* The addresses are converted from QEMU virtual to Vhost virtual. */
475         if (vq->desc && vq->avail && vq->used)
476                 return dev;
477
478         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
479                         vq, addr->desc_user_addr, sizeof(struct vring_desc));
480         if (vq->desc == 0) {
481                 RTE_LOG(DEBUG, VHOST_CONFIG,
482                         "(%d) failed to find desc ring address.\n",
483                         dev->vid);
484                 return dev;
485         }
486
487         dev = numa_realloc(dev, vq_index);
488         vq = dev->virtqueue[vq_index];
489         addr = &vq->ring_addrs;
490
491         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
492                         vq, addr->avail_user_addr, sizeof(struct vring_avail));
493         if (vq->avail == 0) {
494                 RTE_LOG(DEBUG, VHOST_CONFIG,
495                         "(%d) failed to find avail ring address.\n",
496                         dev->vid);
497                 return dev;
498         }
499
500         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
501                         vq, addr->used_user_addr, sizeof(struct vring_used));
502         if (vq->used == 0) {
503                 RTE_LOG(DEBUG, VHOST_CONFIG,
504                         "(%d) failed to find used ring address.\n",
505                         dev->vid);
506                 return dev;
507         }
508
509         if (vq->last_used_idx != vq->used->idx) {
510                 RTE_LOG(WARNING, VHOST_CONFIG,
511                         "last_used_idx (%u) and vq->used->idx (%u) mismatch; "
512                         "some packets may be resent for Tx and dropped for Rx\n",
513                         vq->last_used_idx, vq->used->idx);
514                 vq->last_used_idx  = vq->used->idx;
515                 vq->last_avail_idx = vq->used->idx;
516         }
517
518         vq->log_guest_addr = addr->log_guest_addr;
519
520         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
521                         dev->vid, vq->desc);
522         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
523                         dev->vid, vq->avail);
524         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
525                         dev->vid, vq->used);
526         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
527                         dev->vid, vq->log_guest_addr);
528
529         return dev;
530 }
531
532 /*
533  * The virtio device sends us the desc, used and avail ring addresses.
534  * This function then converts these to our address space.
535  */
536 static int
537 vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
538 {
539         struct vhost_virtqueue *vq;
540         struct vhost_vring_addr *addr = &msg->payload.addr;
541         struct virtio_net *dev = *pdev;
542
543         if (dev->mem == NULL)
544                 return -1;
545
546         /* addr->index refers to the queue index. The txq is 1, rxq is 0. */
547         vq = dev->virtqueue[msg->payload.addr.index];
548
549         /*
550          * Ring addresses should not be interpreted as long as the ring has not
551          * been started and enabled.
552          */
553         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
554
555         vring_invalidate(dev, vq);
556
557         if (vq->enabled && (dev->features &
558                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
559                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
560                 if (!dev)
561                         return -1;
562
563                 *pdev = dev;
564         }
565
566         return 0;
567 }
568
569 /*
570  * The virtio device sends us the available ring last used index.
571  */
572 static int
573 vhost_user_set_vring_base(struct virtio_net *dev,
574                           VhostUserMsg *msg)
575 {
576         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
577                         msg->payload.state.num;
578         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
579                         msg->payload.state.num;
580
581         return 0;
582 }
583
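/* Record one guest-physical to host-physical page mapping, growing the
 * guest_pages array as needed and merging physically contiguous entries.
 */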
584 static int
585 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
586                    uint64_t host_phys_addr, uint64_t size)
587 {
588         struct guest_page *page, *last_page;
589
590         if (dev->nr_guest_pages == dev->max_guest_pages) {
591                 dev->max_guest_pages *= 2;
592                 dev->guest_pages = realloc(dev->guest_pages,
593                                         dev->max_guest_pages * sizeof(*page));
594                 if (!dev->guest_pages) {
595                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
596                         return -1;
597                 }
598         }
599
600         if (dev->nr_guest_pages > 0) {
601                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
602                 /* merge if the two pages are contiguous */
603                 if (host_phys_addr == last_page->host_phys_addr +
604                                       last_page->size) {
605                         last_page->size += size;
606                         return 0;
607                 }
608         }
609
610         page = &dev->guest_pages[dev->nr_guest_pages++];
611         page->guest_phys_addr = guest_phys_addr;
612         page->host_phys_addr  = host_phys_addr;
613         page->size = size;
614
615         return 0;
616 }
617
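/* Walk a memory region page by page and record the guest-physical to
 * host-physical mapping of each page.
 */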
618 static int
619 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
620                 uint64_t page_size)
621 {
622         uint64_t reg_size = reg->size;
623         uint64_t host_user_addr  = reg->host_user_addr;
624         uint64_t guest_phys_addr = reg->guest_phys_addr;
625         uint64_t host_phys_addr;
626         uint64_t size;
627
628         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
629         size = page_size - (guest_phys_addr & (page_size - 1));
630         size = RTE_MIN(size, reg_size);
631
632         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
633                 return -1;
634
635         host_user_addr  += size;
636         guest_phys_addr += size;
637         reg_size -= size;
638
639         while (reg_size > 0) {
640                 size = RTE_MIN(reg_size, page_size);
641                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
642                                                   host_user_addr);
643                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
644                                 size) < 0)
645                         return -1;
646
647                 host_user_addr  += size;
648                 guest_phys_addr += size;
649                 reg_size -= size;
650         }
651
652         return 0;
653 }
654
655 #ifdef RTE_LIBRTE_VHOST_DEBUG
656 /* TODO: enable it only in debug mode? */
657 static void
658 dump_guest_pages(struct virtio_net *dev)
659 {
660         uint32_t i;
661         struct guest_page *page;
662
663         for (i = 0; i < dev->nr_guest_pages; i++) {
664                 page = &dev->guest_pages[i];
665
666                 RTE_LOG(INFO, VHOST_CONFIG,
667                         "guest physical page region %u\n"
668                         "\t guest_phys_addr: %" PRIx64 "\n"
669                         "\t host_phys_addr : %" PRIx64 "\n"
670                         "\t size           : %" PRIx64 "\n",
671                         i,
672                         page->guest_phys_addr,
673                         page->host_phys_addr,
674                         page->size);
675         }
676 }
677 #else
678 #define dump_guest_pages(dev)
679 #endif
680
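/* Compare the newly received memory table with the current one; return true
 * if any region differs and the guest memory must be remapped.
 */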
681 static bool
682 vhost_memory_changed(struct VhostUserMemory *new,
683                      struct rte_vhost_memory *old)
684 {
685         uint32_t i;
686
687         if (new->nregions != old->nregions)
688                 return true;
689
690         for (i = 0; i < new->nregions; ++i) {
691                 VhostUserMemoryRegion *new_r = &new->regions[i];
692                 struct rte_vhost_mem_region *old_r = &old->regions[i];
693
694                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
695                         return true;
696                 if (new_r->memory_size != old_r->size)
697                         return true;
698                 if (new_r->userspace_addr != old_r->guest_user_addr)
699                         return true;
700         }
701
702         return false;
703 }
704
705 static int
706 vhost_user_set_mem_table(struct virtio_net *dev, struct VhostUserMsg *pmsg)
707 {
708         struct VhostUserMemory memory = pmsg->payload.memory;
709         struct rte_vhost_mem_region *reg;
710         void *mmap_addr;
711         uint64_t mmap_size;
712         uint64_t mmap_offset;
713         uint64_t alignment;
714         uint32_t i;
715         int populate;
716         int fd;
717
718         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
719                 RTE_LOG(ERR, VHOST_CONFIG,
720                         "too many memory regions (%u)\n", memory.nregions);
721                 return -1;
722         }
723
724         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
725                 RTE_LOG(INFO, VHOST_CONFIG,
726                         "(%d) memory regions not changed\n", dev->vid);
727
728                 for (i = 0; i < memory.nregions; i++)
729                         close(pmsg->fds[i]);
730
731                 return 0;
732         }
733
734         if (dev->mem) {
735                 free_mem_region(dev);
736                 rte_free(dev->mem);
737                 dev->mem = NULL;
738         }
739
740         dev->nr_guest_pages = 0;
741         if (!dev->guest_pages) {
742                 dev->max_guest_pages = 8;
743                 dev->guest_pages = malloc(dev->max_guest_pages *
744                                                 sizeof(struct guest_page));
745                 if (dev->guest_pages == NULL) {
746                         RTE_LOG(ERR, VHOST_CONFIG,
747                                 "(%d) failed to allocate memory "
748                                 "for dev->guest_pages\n",
749                                 dev->vid);
750                         return -1;
751                 }
752         }
753
754         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
755                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
756         if (dev->mem == NULL) {
757                 RTE_LOG(ERR, VHOST_CONFIG,
758                         "(%d) failed to allocate memory for dev->mem\n",
759                         dev->vid);
760                 return -1;
761         }
762         dev->mem->nregions = memory.nregions;
763
764         for (i = 0; i < memory.nregions; i++) {
765                 fd  = pmsg->fds[i];
766                 reg = &dev->mem->regions[i];
767
768                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
769                 reg->guest_user_addr = memory.regions[i].userspace_addr;
770                 reg->size            = memory.regions[i].memory_size;
771                 reg->fd              = fd;
772
773                 mmap_offset = memory.regions[i].mmap_offset;
774
775                 /* Check for memory_size + mmap_offset overflow */
776                 if (mmap_offset >= -reg->size) {
777                         RTE_LOG(ERR, VHOST_CONFIG,
778                                 "mmap_offset (%#"PRIx64") and memory_size "
779                                 "(%#"PRIx64") overflow\n",
780                                 mmap_offset, reg->size);
781                         goto err_mmap;
782                 }
783
784                 mmap_size = reg->size + mmap_offset;
785
786                 /* On older long-term Linux kernels, such as 2.6.32 and 3.2.72,
787                  * mmap() without the MAP_ANONYMOUS flag must be called with a
788                  * length argument aligned to the hugepage size, otherwise it
789                  * fails with EINVAL.
790                  *
791                  * To avoid that failure, make sure the length passed here is
792                  * kept aligned.
793                  */
794                 alignment = get_blk_size(fd);
795                 if (alignment == (uint64_t)-1) {
796                         RTE_LOG(ERR, VHOST_CONFIG,
797                                 "couldn't get hugepage size through fstat\n");
798                         goto err_mmap;
799                 }
800                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
801
802                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
803                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
804                                  MAP_SHARED | populate, fd, 0);
805
806                 if (mmap_addr == MAP_FAILED) {
807                         RTE_LOG(ERR, VHOST_CONFIG,
808                                 "mmap region %u failed.\n", i);
809                         goto err_mmap;
810                 }
811
812                 reg->mmap_addr = mmap_addr;
813                 reg->mmap_size = mmap_size;
814                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
815                                       mmap_offset;
816
817                 if (dev->dequeue_zero_copy)
818                         if (add_guest_pages(dev, reg, alignment) < 0) {
819                                 RTE_LOG(ERR, VHOST_CONFIG,
820                                         "adding guest pages to region %u failed.\n",
821                                         i);
822                                 goto err_mmap;
823                         }
824
825                 RTE_LOG(INFO, VHOST_CONFIG,
826                         "guest memory region %u, size: 0x%" PRIx64 "\n"
827                         "\t guest physical addr: 0x%" PRIx64 "\n"
828                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
829                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
830                         "\t mmap addr : 0x%" PRIx64 "\n"
831                         "\t mmap size : 0x%" PRIx64 "\n"
832                         "\t mmap align: 0x%" PRIx64 "\n"
833                         "\t mmap off  : 0x%" PRIx64 "\n",
834                         i, reg->size,
835                         reg->guest_phys_addr,
836                         reg->guest_user_addr,
837                         reg->host_user_addr,
838                         (uint64_t)(uintptr_t)mmap_addr,
839                         mmap_size,
840                         alignment,
841                         mmap_offset);
842         }
843
844         dump_guest_pages(dev);
845
846         return 0;
847
848 err_mmap:
849         free_mem_region(dev);
850         rte_free(dev->mem);
851         dev->mem = NULL;
852         return -1;
853 }
854
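/* A virtqueue is ready once its rings are mapped and its kick and call
 * file descriptors have been received.
 */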
855 static int
856 vq_is_ready(struct vhost_virtqueue *vq)
857 {
858         return vq && vq->desc && vq->avail && vq->used &&
859                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
860                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
861 }
862
863 static int
864 virtio_is_ready(struct virtio_net *dev)
865 {
866         struct vhost_virtqueue *vq;
867         uint32_t i;
868
869         if (dev->nr_vring == 0)
870                 return 0;
871
872         for (i = 0; i < dev->nr_vring; i++) {
873                 vq = dev->virtqueue[i];
874
875                 if (!vq_is_ready(vq))
876                         return 0;
877         }
878
879         RTE_LOG(INFO, VHOST_CONFIG,
880                 "virtio is now ready for processing.\n");
881         return 1;
882 }
883
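/* Store the call eventfd used to notify the guest for the given vring. */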
884 static void
885 vhost_user_set_vring_call(struct virtio_net *dev, struct VhostUserMsg *pmsg)
886 {
887         struct vhost_vring_file file;
888         struct vhost_virtqueue *vq;
889
890         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
891         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
892                 file.fd = VIRTIO_INVALID_EVENTFD;
893         else
894                 file.fd = pmsg->fds[0];
895         RTE_LOG(INFO, VHOST_CONFIG,
896                 "vring call idx:%d file:%d\n", file.index, file.fd);
897
898         vq = dev->virtqueue[file.index];
899         if (vq->callfd >= 0)
900                 close(vq->callfd);
901
902         vq->callfd = file.fd;
903 }
904
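/* Store the kick eventfd for the given vring; the ring is about to be
 * started, so its addresses are translated here as well.
 */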
905 static void
906 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
907 {
908         struct vhost_vring_file file;
909         struct vhost_virtqueue *vq;
910         struct virtio_net *dev = *pdev;
911
912         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
913         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
914                 file.fd = VIRTIO_INVALID_EVENTFD;
915         else
916                 file.fd = pmsg->fds[0];
917         RTE_LOG(INFO, VHOST_CONFIG,
918                 "vring kick idx:%d file:%d\n", file.index, file.fd);
919
920         /* Interpret ring addresses only when ring is started. */
921         dev = translate_ring_addresses(dev, file.index);
922         if (!dev)
923                 return;
924
925         *pdev = dev;
926
927         vq = dev->virtqueue[file.index];
928
929         /*
930          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
931          * the ring starts already enabled. Otherwise, it is enabled via
932          * the SET_VRING_ENABLE message.
933          */
934         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
935                 vq->enabled = 1;
936
937         if (vq->kickfd >= 0)
938                 close(vq->kickfd);
939         vq->kickfd = file.fd;
940 }
941
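/* Free all zero-copy mbufs attached to the virtqueue and release the
 * zmbuf array.
 */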
942 static void
943 free_zmbufs(struct vhost_virtqueue *vq)
944 {
945         struct zcopy_mbuf *zmbuf, *next;
946
947         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
948              zmbuf != NULL; zmbuf = next) {
949                 next = TAILQ_NEXT(zmbuf, next);
950
951                 rte_pktmbuf_free(zmbuf->mbuf);
952                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
953         }
954
955         rte_free(vq->zmbufs);
956 }
957
958 /*
959  * When virtio is stopped, QEMU sends us the GET_VRING_BASE message.
960  */
961 static int
962 vhost_user_get_vring_base(struct virtio_net *dev,
963                           VhostUserMsg *msg)
964 {
965         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
966         struct rte_vdpa_device *vdpa_dev;
967         int did = -1;
968
969         /* We have to stop the queue (virtio) if it is running. */
970         if (dev->flags & VIRTIO_DEV_RUNNING) {
971                 did = dev->vdpa_dev_id;
972                 vdpa_dev = rte_vdpa_get_device(did);
973                 if (vdpa_dev && vdpa_dev->ops->dev_close)
974                         vdpa_dev->ops->dev_close(dev->vid);
975                 dev->flags &= ~VIRTIO_DEV_RUNNING;
976                 dev->notify_ops->destroy_device(dev->vid);
977         }
978
979         dev->flags &= ~VIRTIO_DEV_READY;
980         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
981
982         /* Here it is safe to get the last avail index */
983         msg->payload.state.num = vq->last_avail_idx;
984
985         RTE_LOG(INFO, VHOST_CONFIG,
986                 "vring base idx:%d file:%d\n", msg->payload.state.index,
987                 msg->payload.state.num);
988         /*
989          * Based on the current QEMU vhost-user implementation, this message
990          * is sent, and only sent, in vhost_vring_stop.
991          * TODO: clean up the vring; it isn't usable from this point on.
992          */
993         if (vq->kickfd >= 0)
994                 close(vq->kickfd);
995
996         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
997
998         if (vq->callfd >= 0)
999                 close(vq->callfd);
1000
1001         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1002
1003         if (dev->dequeue_zero_copy)
1004                 free_zmbufs(vq);
1005         rte_free(vq->shadow_used_ring);
1006         vq->shadow_used_ring = NULL;
1007
1008         rte_free(vq->batch_copy_elems);
1009         vq->batch_copy_elems = NULL;
1010
1011         return 0;
1012 }
1013
1014 /*
1015  * When the virtio queues are ready to work, QEMU sends us a message to
1016  * enable the virtio queue pair.
1017  */
1018 static int
1019 vhost_user_set_vring_enable(struct virtio_net *dev,
1020                             VhostUserMsg *msg)
1021 {
1022         int enable = (int)msg->payload.state.num;
1023         int index = (int)msg->payload.state.index;
1024         struct rte_vdpa_device *vdpa_dev;
1025         int did = -1;
1026
1027         RTE_LOG(INFO, VHOST_CONFIG,
1028                 "set queue enable: %d to qp idx: %d\n",
1029                 enable, index);
1030
1031         did = dev->vdpa_dev_id;
1032         vdpa_dev = rte_vdpa_get_device(did);
1033         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1034                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1035
1036         if (dev->notify_ops->vring_state_changed)
1037                 dev->notify_ops->vring_state_changed(dev->vid,
1038                                 index, enable);
1039
1040         dev->virtqueue[index]->enabled = enable;
1041
1042         return 0;
1043 }
1044
1045 static void
1046 vhost_user_get_protocol_features(struct virtio_net *dev,
1047                                  struct VhostUserMsg *msg)
1048 {
1049         uint64_t features, protocol_features;
1050
1051         rte_vhost_driver_get_features(dev->ifname, &features);
1052         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1053
1054         /*
1055          * For now, the REPLY_ACK protocol feature is only mandatory for
1056          * the IOMMU feature. If IOMMU is explicitly disabled by the
1057          * application, also disable the REPLY_ACK feature to accommodate
1058          * older buggy QEMU versions (from v2.7.0 to v2.9.0).
1059          */
1060         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1061                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1062
1063         msg->payload.u64 = protocol_features;
1064         msg->size = sizeof(msg->payload.u64);
1065 }
1066
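/* Store the protocol features acked by the master; the message is silently
 * ignored if it contains unsupported bits.
 */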
1067 static void
1068 vhost_user_set_protocol_features(struct virtio_net *dev,
1069                                  uint64_t protocol_features)
1070 {
1071         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES)
1072                 return;
1073
1074         dev->protocol_features = protocol_features;
1075 }
1076
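/* Map the dirty-page logging area provided by the master, replacing any
 * previous mapping; it is used for live migration.
 */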
1077 static int
1078 vhost_user_set_log_base(struct virtio_net *dev, struct VhostUserMsg *msg)
1079 {
1080         int fd = msg->fds[0];
1081         uint64_t size, off;
1082         void *addr;
1083
1084         if (fd < 0) {
1085                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1086                 return -1;
1087         }
1088
1089         if (msg->size != sizeof(VhostUserLog)) {
1090                 RTE_LOG(ERR, VHOST_CONFIG,
1091                         "invalid log base msg size: %"PRId32" != %d\n",
1092                         msg->size, (int)sizeof(VhostUserLog));
1093                 return -1;
1094         }
1095
1096         size = msg->payload.log.mmap_size;
1097         off  = msg->payload.log.mmap_offset;
1098
1099         /* Don't allow mmap_offset to point outside the mmap region */
1100         if (off > size) {
1101                 RTE_LOG(ERR, VHOST_CONFIG,
1102                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1103                         off, size);
1104                 return -1;
1105         }
1106
1107         RTE_LOG(INFO, VHOST_CONFIG,
1108                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1109                 size, off);
1110
1111         /*
1112          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1113          * fail when the offset is not page-size aligned.
1114          */
1115         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1116         close(fd);
1117         if (addr == MAP_FAILED) {
1118                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1119                 return -1;
1120         }
1121
1122         /*
1123          * Free any previously mapped log memory, in case multiple
1124          * VHOST_USER_SET_LOG_BASE messages are received.
1125          */
1126         if (dev->log_addr) {
1127                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1128         }
1129         dev->log_addr = (uint64_t)(uintptr_t)addr;
1130         dev->log_base = dev->log_addr + off;
1131         dev->log_size = size;
1132
1133         return 0;
1134 }
1135
1136 /*
1137  * A RARP packet is constructed and broadcast to notify switches about
1138  * the new location of the migrated VM, so that packets from outside will
1139  * not be lost after migration.
1140  *
1141  * However, we don't actually "send" a RARP packet here; instead, we set
1142  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1143  */
1144 static int
1145 vhost_user_send_rarp(struct virtio_net *dev, struct VhostUserMsg *msg)
1146 {
1147         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1148         struct rte_vdpa_device *vdpa_dev;
1149         int did = -1;
1150
1151         RTE_LOG(DEBUG, VHOST_CONFIG,
1152                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1153                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1154         memcpy(dev->mac.addr_bytes, mac, 6);
1155
1156         /*
1157          * Set the flag to inject a RARP broadcast packet at
1158          * rte_vhost_dequeue_burst().
1159          *
1160          * rte_smp_wmb() makes sure the MAC address is copied
1161          * before the flag is set.
1162          */
1163         rte_smp_wmb();
1164         rte_atomic16_set(&dev->broadcast_rarp, 1);
1165         did = dev->vdpa_dev_id;
1166         vdpa_dev = rte_vdpa_get_device(did);
1167         if (vdpa_dev && vdpa_dev->ops->migration_done)
1168                 vdpa_dev->ops->migration_done(dev->vid);
1169
1170         return 0;
1171 }
1172
1173 static int
1174 vhost_user_net_set_mtu(struct virtio_net *dev, struct VhostUserMsg *msg)
1175 {
1176         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1177                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1178                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1179                                 msg->payload.u64);
1180
1181                 return -1;
1182         }
1183
1184         dev->mtu = msg->payload.u64;
1185
1186         return 0;
1187 }
1188
1189 static int
1190 vhost_user_set_req_fd(struct virtio_net *dev, struct VhostUserMsg *msg)
1191 {
1192         int fd = msg->fds[0];
1193
1194         if (fd < 0) {
1195                 RTE_LOG(ERR, VHOST_CONFIG,
1196                                 "Invalid file descriptor for slave channel (%d)\n",
1197                                 fd);
1198                 return -1;
1199         }
1200
1201         dev->slave_req_fd = fd;
1202
1203         return 0;
1204 }
1205
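/* Return 1 if the IOTLB update covers any of the vring's ring addresses,
 * meaning those addresses can now be translated.
 */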
1206 static int
1207 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1208 {
1209         struct vhost_vring_addr *ra;
1210         uint64_t start, end;
1211
1212         start = imsg->iova;
1213         end = start + imsg->size;
1214
1215         ra = &vq->ring_addrs;
1216         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1217                 return 1;
1218         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1219                 return 1;
1220         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1221                 return 1;
1222
1223         return 0;
1224 }
1225
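/* Return 1 if the IOTLB invalidation overlaps any of the vring's mapped
 * rings, in which case the ring mappings must be invalidated too.
 */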
1226 static int
1227 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1228                                 struct vhost_iotlb_msg *imsg)
1229 {
1230         uint64_t istart, iend, vstart, vend;
1231
1232         istart = imsg->iova;
1233         iend = istart + imsg->size - 1;
1234
1235         vstart = (uintptr_t)vq->desc;
1236         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1237         if (vstart <= iend && istart <= vend)
1238                 return 1;
1239
1240         vstart = (uintptr_t)vq->avail;
1241         vend = vstart + sizeof(struct vring_avail);
1242         vend += sizeof(uint16_t) * vq->size - 1;
1243         if (vstart <= iend && istart <= vend)
1244                 return 1;
1245
1246         vstart = (uintptr_t)vq->used;
1247         vend = vstart + sizeof(struct vring_used);
1248         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1249         if (vstart <= iend && istart <= vend)
1250                 return 1;
1251
1252         return 0;
1253 }
1254
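/* Handle an IOTLB message from the master: insert or remove cache entries
 * and (in)validate the ring mappings that fall inside the affected range.
 */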
1255 static int
1256 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1257 {
1258         struct virtio_net *dev = *pdev;
1259         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1260         uint16_t i;
1261         uint64_t vva;
1262
1263         switch (imsg->type) {
1264         case VHOST_IOTLB_UPDATE:
1265                 vva = qva_to_vva(dev, imsg->uaddr);
1266                 if (!vva)
1267                         return -1;
1268
1269                 for (i = 0; i < dev->nr_vring; i++) {
1270                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1271
1272                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1273                                         imsg->size, imsg->perm);
1274
1275                         if (is_vring_iotlb_update(vq, imsg))
1276                                 *pdev = dev = translate_ring_addresses(dev, i);
1277                 }
1278                 break;
1279         case VHOST_IOTLB_INVALIDATE:
1280                 for (i = 0; i < dev->nr_vring; i++) {
1281                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1282
1283                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1284                                         imsg->size);
1285
1286                         if (is_vring_iotlb_invalidate(vq, imsg))
1287                                 vring_invalidate(dev, vq);
1288                 }
1289                 break;
1290         default:
1291                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1292                                 imsg->type);
1293                 return -1;
1294         }
1295
1296         return 0;
1297 }
1298
1299 /* Return the number of bytes read on success, or a negative value on failure. */
1300 static int
1301 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1302 {
1303         int ret;
1304
1305         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1306                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1307         if (ret <= 0)
1308                 return ret;
1309
1310         if (msg && msg->size) {
1311                 if (msg->size > sizeof(msg->payload)) {
1312                         RTE_LOG(ERR, VHOST_CONFIG,
1313                                 "invalid msg size: %d\n", msg->size);
1314                         return -1;
1315                 }
1316                 ret = read(sockfd, &msg->payload, msg->size);
1317                 if (ret <= 0)
1318                         return ret;
1319                 if (ret != (int)msg->size) {
1320                         RTE_LOG(ERR, VHOST_CONFIG,
1321                                 "read control message failed\n");
1322                         return -1;
1323                 }
1324         }
1325
1326         return ret;
1327 }
1328
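/* Send a message (and optional fds) to the master over the vhost-user socket. */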
1329 static int
1330 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1331 {
1332         if (!msg)
1333                 return 0;
1334
1335         return send_fd_message(sockfd, (char *)msg,
1336                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1337 }
1338
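/* Set the version and reply flags on a message and send it back to the master. */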
1339 static int
1340 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1341 {
1342         if (!msg)
1343                 return 0;
1344
1345         msg->flags &= ~VHOST_USER_VERSION_MASK;
1346         msg->flags &= ~VHOST_USER_NEED_REPLY;
1347         msg->flags |= VHOST_USER_VERSION;
1348         msg->flags |= VHOST_USER_REPLY_MASK;
1349
1350         return send_vhost_message(sockfd, msg, NULL, 0);
1351 }
1352
1353 /*
1354  * Allocate a queue pair if it hasn't been allocated yet
1355  */
1356 static int
1357 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
1358 {
1359         uint16_t vring_idx;
1360
1361         switch (msg->request.master) {
1362         case VHOST_USER_SET_VRING_KICK:
1363         case VHOST_USER_SET_VRING_CALL:
1364         case VHOST_USER_SET_VRING_ERR:
1365                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1366                 break;
1367         case VHOST_USER_SET_VRING_NUM:
1368         case VHOST_USER_SET_VRING_BASE:
1369         case VHOST_USER_SET_VRING_ENABLE:
1370                 vring_idx = msg->payload.state.index;
1371                 break;
1372         case VHOST_USER_SET_VRING_ADDR:
1373                 vring_idx = msg->payload.addr.index;
1374                 break;
1375         default:
1376                 return 0;
1377         }
1378
1379         if (vring_idx >= VHOST_MAX_VRING) {
1380                 RTE_LOG(ERR, VHOST_CONFIG,
1381                         "invalid vring index: %u\n", vring_idx);
1382                 return -1;
1383         }
1384
1385         if (dev->virtqueue[vring_idx])
1386                 return 0;
1387
1388         return alloc_vring_queue(dev, vring_idx);
1389 }
1390
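/* Take the access_lock of every allocated virtqueue so that message handling
 * does not race with the datapath.
 */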
1391 static void
1392 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1393 {
1394         unsigned int i = 0;
1395         unsigned int vq_num = 0;
1396
1397         while (vq_num < dev->nr_vring) {
1398                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1399
1400                 if (vq) {
1401                         rte_spinlock_lock(&vq->access_lock);
1402                         vq_num++;
1403                 }
1404                 i++;
1405         }
1406 }
1407
1408 static void
1409 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1410 {
1411         unsigned int i = 0;
1412         unsigned int vq_num = 0;
1413
1414         while (vq_num < dev->nr_vring) {
1415                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1416
1417                 if (vq) {
1418                         rte_spinlock_unlock(&vq->access_lock);
1419                         vq_num++;
1420                 }
1421                 i++;
1422         }
1423 }
1424
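/* Main vhost-user message handler: read one message from the master socket,
 * dispatch it, send a reply when required, and start the device once all
 * virtqueues are ready.
 */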
1425 int
1426 vhost_user_msg_handler(int vid, int fd)
1427 {
1428         struct virtio_net *dev;
1429         struct VhostUserMsg msg;
1430         struct rte_vdpa_device *vdpa_dev;
1431         int did = -1;
1432         int ret;
1433         int unlock_required = 0;
1434         uint32_t skip_master = 0;
1435
1436         dev = get_device(vid);
1437         if (dev == NULL)
1438                 return -1;
1439
1440         if (!dev->notify_ops) {
1441                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1442                 if (!dev->notify_ops) {
1443                         RTE_LOG(ERR, VHOST_CONFIG,
1444                                 "failed to get callback ops for driver %s\n",
1445                                 dev->ifname);
1446                         return -1;
1447                 }
1448         }
1449
1450         ret = read_vhost_message(fd, &msg);
1451         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1452                 if (ret < 0)
1453                         RTE_LOG(ERR, VHOST_CONFIG,
1454                                 "vhost read message failed\n");
1455                 else if (ret == 0)
1456                         RTE_LOG(INFO, VHOST_CONFIG,
1457                                 "vhost peer closed\n");
1458                 else
1459                         RTE_LOG(ERR, VHOST_CONFIG,
1460                                 "vhost read incorrect message\n");
1461
1462                 return -1;
1463         }
1464
1465         ret = 0;
1466         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1467                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1468                         vhost_message_str[msg.request.master]);
1469         else
1470                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1471                         vhost_message_str[msg.request.master]);
1472
1473         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1474         if (ret < 0) {
1475                 RTE_LOG(ERR, VHOST_CONFIG,
1476                         "failed to alloc queue\n");
1477                 return -1;
1478         }
1479
1480         /*
1481          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1482          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1483          * and the device is destroyed. destroy_device() waits for queues to
1484          * become inactive, so it is safe. Otherwise taking the access_lock
1485          * would cause a deadlock.
1486          */
1487         switch (msg.request.master) {
1488         case VHOST_USER_SET_FEATURES:
1489         case VHOST_USER_SET_PROTOCOL_FEATURES:
1490         case VHOST_USER_SET_OWNER:
1491         case VHOST_USER_SET_MEM_TABLE:
1492         case VHOST_USER_SET_LOG_BASE:
1493         case VHOST_USER_SET_LOG_FD:
1494         case VHOST_USER_SET_VRING_NUM:
1495         case VHOST_USER_SET_VRING_ADDR:
1496         case VHOST_USER_SET_VRING_BASE:
1497         case VHOST_USER_SET_VRING_KICK:
1498         case VHOST_USER_SET_VRING_CALL:
1499         case VHOST_USER_SET_VRING_ERR:
1500         case VHOST_USER_SET_VRING_ENABLE:
1501         case VHOST_USER_SEND_RARP:
1502         case VHOST_USER_NET_SET_MTU:
1503         case VHOST_USER_SET_SLAVE_REQ_FD:
1504                 vhost_user_lock_all_queue_pairs(dev);
1505                 unlock_required = 1;
1506                 break;
1507         default:
1508                 break;
1509
1510         }
1511
1512         if (dev->extern_ops.pre_msg_handle) {
1513                 uint32_t need_reply;
1514
1515                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1516                                 (void *)&msg, &need_reply, &skip_master);
1517                 if (ret < 0)
1518                         goto skip_to_reply;
1519
1520                 if (need_reply)
1521                         send_vhost_reply(fd, &msg);
1522
1523                 if (skip_master)
1524                         goto skip_to_post_handle;
1525         }
1526
1527         switch (msg.request.master) {
1528         case VHOST_USER_GET_FEATURES:
1529                 msg.payload.u64 = vhost_user_get_features(dev);
1530                 msg.size = sizeof(msg.payload.u64);
1531                 send_vhost_reply(fd, &msg);
1532                 break;
1533         case VHOST_USER_SET_FEATURES:
1534                 ret = vhost_user_set_features(dev, msg.payload.u64);
1535                 if (ret)
1536                         return -1;
1537                 break;
1538
1539         case VHOST_USER_GET_PROTOCOL_FEATURES:
1540                 vhost_user_get_protocol_features(dev, &msg);
1541                 send_vhost_reply(fd, &msg);
1542                 break;
1543         case VHOST_USER_SET_PROTOCOL_FEATURES:
1544                 vhost_user_set_protocol_features(dev, msg.payload.u64);
1545                 break;
1546
1547         case VHOST_USER_SET_OWNER:
1548                 vhost_user_set_owner();
1549                 break;
1550         case VHOST_USER_RESET_OWNER:
1551                 vhost_user_reset_owner(dev);
1552                 break;
1553
1554         case VHOST_USER_SET_MEM_TABLE:
1555                 ret = vhost_user_set_mem_table(dev, &msg);
1556                 break;
1557
1558         case VHOST_USER_SET_LOG_BASE:
1559                 vhost_user_set_log_base(dev, &msg);
1560
1561                 /* it needs a reply */
1562                 msg.size = sizeof(msg.payload.u64);
1563                 send_vhost_reply(fd, &msg);
1564                 break;
1565         case VHOST_USER_SET_LOG_FD:
1566                 close(msg.fds[0]);
1567                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1568                 break;
1569
1570         case VHOST_USER_SET_VRING_NUM:
1571                 vhost_user_set_vring_num(dev, &msg);
1572                 break;
1573         case VHOST_USER_SET_VRING_ADDR:
1574                 vhost_user_set_vring_addr(&dev, &msg);
1575                 break;
1576         case VHOST_USER_SET_VRING_BASE:
1577                 vhost_user_set_vring_base(dev, &msg);
1578                 break;
1579
1580         case VHOST_USER_GET_VRING_BASE:
1581                 vhost_user_get_vring_base(dev, &msg);
1582                 msg.size = sizeof(msg.payload.state);
1583                 send_vhost_reply(fd, &msg);
1584                 break;
1585
1586         case VHOST_USER_SET_VRING_KICK:
1587                 vhost_user_set_vring_kick(&dev, &msg);
1588                 break;
1589         case VHOST_USER_SET_VRING_CALL:
1590                 vhost_user_set_vring_call(dev, &msg);
1591                 break;
1592
1593         case VHOST_USER_SET_VRING_ERR:
1594                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1595                         close(msg.fds[0]);
1596                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1597                 break;
1598
1599         case VHOST_USER_GET_QUEUE_NUM:
1600                 msg.payload.u64 = (uint64_t)vhost_user_get_queue_num(dev);
1601                 msg.size = sizeof(msg.payload.u64);
1602                 send_vhost_reply(fd, &msg);
1603                 break;
1604
1605         case VHOST_USER_SET_VRING_ENABLE:
1606                 vhost_user_set_vring_enable(dev, &msg);
1607                 break;
1608         case VHOST_USER_SEND_RARP:
1609                 vhost_user_send_rarp(dev, &msg);
1610                 break;
1611
1612         case VHOST_USER_NET_SET_MTU:
1613                 ret = vhost_user_net_set_mtu(dev, &msg);
1614                 break;
1615
1616         case VHOST_USER_SET_SLAVE_REQ_FD:
1617                 ret = vhost_user_set_req_fd(dev, &msg);
1618                 break;
1619
1620         case VHOST_USER_IOTLB_MSG:
1621                 ret = vhost_user_iotlb_msg(&dev, &msg);
1622                 break;
1623
1624         default:
1625                 ret = -1;
1626                 break;
1627         }
1628
1629 skip_to_post_handle:
1630         if (dev->extern_ops.post_msg_handle) {
1631                 uint32_t need_reply;
1632
1633                 ret = (*dev->extern_ops.post_msg_handle)(
1634                                 dev->vid, (void *)&msg, &need_reply);
1635                 if (ret < 0)
1636                         goto skip_to_reply;
1637
1638                 if (need_reply)
1639                         send_vhost_reply(fd, &msg);
1640         }
1641
1642 skip_to_reply:
1643         if (unlock_required)
1644                 vhost_user_unlock_all_queue_pairs(dev);
1645
1646         if (msg.flags & VHOST_USER_NEED_REPLY) {
1647                 msg.payload.u64 = !!ret;
1648                 msg.size = sizeof(msg.payload.u64);
1649                 send_vhost_reply(fd, &msg);
1650         }
1651
1652         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1653                 dev->flags |= VIRTIO_DEV_READY;
1654
1655                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1656                         if (dev->dequeue_zero_copy) {
1657                                 RTE_LOG(INFO, VHOST_CONFIG,
1658                                                 "dequeue zero copy is enabled\n");
1659                         }
1660
1661                         if (dev->notify_ops->new_device(dev->vid) == 0)
1662                                 dev->flags |= VIRTIO_DEV_RUNNING;
1663                 }
1664         }
1665
1666         did = dev->vdpa_dev_id;
1667         vdpa_dev = rte_vdpa_get_device(did);
1668         if (vdpa_dev && virtio_is_ready(dev) &&
1669                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1670                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1671                 if (vdpa_dev->ops->dev_conf)
1672                         vdpa_dev->ops->dev_conf(dev->vid);
1673                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1674         }
1675
1676         return 0;
1677 }
1678
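/* Ask the master to resolve an IOTLB miss by sending a VHOST_IOTLB_MISS
 * message over the slave request channel.
 */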
1679 int
1680 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1681 {
1682         int ret;
1683         struct VhostUserMsg msg = {
1684                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1685                 .flags = VHOST_USER_VERSION,
1686                 .size = sizeof(msg.payload.iotlb),
1687                 .payload.iotlb = {
1688                         .iova = iova,
1689                         .perm = perm,
1690                         .type = VHOST_IOTLB_MISS,
1691                 },
1692         };
1693
1694         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1695         if (ret < 0) {
1696                 RTE_LOG(ERR, VHOST_CONFIG,
1697                                 "Failed to send IOTLB miss message (%d)\n",
1698                                 ret);
1699                 return ret;
1700         }
1701
1702         return 0;
1703 }