[dpdk.git] / lib / librte_vhost / vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped will then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
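
/*
 * 68 matches the minimum MTU required by IPv4 (RFC 791) and 65535 is the
 * largest value that fits the 16-bit virtio-net MTU field; values outside
 * this range are rejected by vhost_user_net_set_mtu() below.
 */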
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
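
/*
 * For the hugetlbfs-backed fds received with SET_MEM_TABLE, st_blksize
 * reports the hugepage size (hence the "hugepage size through fstat" error
 * message below), which is why this value is used as the mmap() length
 * alignment in vhost_user_set_mem_table().
 */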
83
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * This function just returns success at the moment; there is nothing
127  * to do here provided the device has been initialised.
128  */
129 static int
130 vhost_user_set_owner(void)
131 {
132         return 0;
133 }
134
135 static int
136 vhost_user_reset_owner(struct virtio_net *dev)
137 {
138         struct rte_vdpa_device *vdpa_dev;
139         int did = -1;
140
141         if (dev->flags & VIRTIO_DEV_RUNNING) {
142                 did = dev->vdpa_dev_id;
143                 vdpa_dev = rte_vdpa_get_device(did);
144                 if (vdpa_dev && vdpa_dev->ops->dev_close)
145                         vdpa_dev->ops->dev_close(dev->vid);
146                 dev->flags &= ~VIRTIO_DEV_RUNNING;
147                 dev->notify_ops->destroy_device(dev->vid);
148         }
149
150         cleanup_device(dev, 0);
151         reset_device(dev);
152         return 0;
153 }
154
155 /*
156  * The master requests the features that we support.
157  */
158 static uint64_t
159 vhost_user_get_features(struct virtio_net *dev)
160 {
161         uint64_t features = 0;
162
163         rte_vhost_driver_get_features(dev->ifname, &features);
164         return features;
165 }
166
167 /*
168  * The master requests the number of queues that we support.
169  */
170 static uint32_t
171 vhost_user_get_queue_num(struct virtio_net *dev)
172 {
173         uint32_t queue_num = 0;
174
175         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
176         return queue_num;
177 }
178
179 /*
180  * We receive the features negotiated between us and the virtio device.
181  */
182 static int
183 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
184 {
185         uint64_t vhost_features = 0;
186         struct rte_vdpa_device *vdpa_dev;
187         int did = -1;
188
189         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
190         if (features & ~vhost_features) {
191                 RTE_LOG(ERR, VHOST_CONFIG,
192                         "(%d) received invalid negotiated features.\n",
193                         dev->vid);
194                 return -1;
195         }
196
197         if (dev->flags & VIRTIO_DEV_RUNNING) {
198                 if (dev->features == features)
199                         return 0;
200
201                 /*
202                  * Error out if master tries to change features while device is
203                  * in running state. The exception being VHOST_F_LOG_ALL, which
204                  * is enabled when the live-migration starts.
205                  */
206                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
207                         RTE_LOG(ERR, VHOST_CONFIG,
208                                 "(%d) features changed while device is running.\n",
209                                 dev->vid);
210                         return -1;
211                 }
212
213                 if (dev->notify_ops->features_changed)
214                         dev->notify_ops->features_changed(dev->vid, features);
215         }
216
217         dev->features = features;
218         if (dev->features &
219                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
220                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
221         } else {
222                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
223         }
224         VHOST_LOG_DEBUG(VHOST_CONFIG,
225                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
226                 dev->vid,
227                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
228                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
229
230         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
231             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
232                 /*
233                  * Remove all but first queue pair if MQ hasn't been
234                  * negotiated. This is safe because the device is not
235                  * running at this stage.
236                  */
237                 while (dev->nr_vring > 2) {
238                         struct vhost_virtqueue *vq;
239
240                         vq = dev->virtqueue[--dev->nr_vring];
241                         if (!vq)
242                                 continue;
243
244                         dev->virtqueue[dev->nr_vring] = NULL;
245                         cleanup_vq(vq, 1);
246                         free_vq(vq);
247                 }
248         }
249
250         did = dev->vdpa_dev_id;
251         vdpa_dev = rte_vdpa_get_device(did);
252         if (vdpa_dev && vdpa_dev->ops->set_features)
253                 vdpa_dev->ops->set_features(dev->vid);
254
255         return 0;
256 }
257
258 /*
259  * The virtio device sends us the size of the descriptor ring.
260  */
261 static int
262 vhost_user_set_vring_num(struct virtio_net *dev,
263                          VhostUserMsg *msg)
264 {
265         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
266
267         vq->size = msg->payload.state.num;
268
269         /* VIRTIO 1.0, 2.4 Virtqueues says:
270          *
271          *   Queue Size value is always a power of 2. The maximum Queue Size
272          *   value is 32768.
273          */
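        /*
         * A power of two has a single bit set, so (size & (size - 1)) == 0
         * holds only for valid sizes: e.g. 256 & 255 == 0, while
         * 384 & 383 == 256 and is rejected.
         */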
274         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
275                 RTE_LOG(ERR, VHOST_CONFIG,
276                         "invalid virtqueue size %u\n", vq->size);
277                 return -1;
278         }
279
280         if (dev->dequeue_zero_copy) {
281                 vq->nr_zmbuf = 0;
282                 vq->last_zmbuf_idx = 0;
283                 vq->zmbuf_size = vq->size;
284                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
285                                          sizeof(struct zcopy_mbuf), 0);
286                 if (vq->zmbufs == NULL) {
287                         RTE_LOG(WARNING, VHOST_CONFIG,
288                                 "failed to allocate mem for zero copy; "
289                                 "zero copy is forcibly disabled\n");
290                         dev->dequeue_zero_copy = 0;
291                 }
292                 TAILQ_INIT(&vq->zmbuf_list);
293         }
294
295         vq->shadow_used_ring = rte_malloc(NULL,
296                                 vq->size * sizeof(struct vring_used_elem),
297                                 RTE_CACHE_LINE_SIZE);
298         if (!vq->shadow_used_ring) {
299                 RTE_LOG(ERR, VHOST_CONFIG,
300                         "failed to allocate memory for shadow used ring.\n");
301                 return -1;
302         }
303
304         vq->batch_copy_elems = rte_malloc(NULL,
305                                 vq->size * sizeof(struct batch_copy_elem),
306                                 RTE_CACHE_LINE_SIZE);
307         if (!vq->batch_copy_elems) {
308                 RTE_LOG(ERR, VHOST_CONFIG,
309                         "failed to allocate memory for batching copy.\n");
310                 return -1;
311         }
312
313         return 0;
314 }
315
316 /*
317  * Reallocate the virtio_net and vhost_virtqueue data structures so that they
318  * reside on the same NUMA node as the memory backing the vring descriptors.
319  */
320 #ifdef RTE_LIBRTE_VHOST_NUMA
321 static struct virtio_net*
322 numa_realloc(struct virtio_net *dev, int index)
323 {
324         int oldnode, newnode;
325         struct virtio_net *old_dev;
326         struct vhost_virtqueue *old_vq, *vq;
327         struct zcopy_mbuf *new_zmbuf;
328         struct vring_used_elem *new_shadow_used_ring;
329         struct batch_copy_elem *new_batch_copy_elems;
330         int ret;
331
332         old_dev = dev;
333         vq = old_vq = dev->virtqueue[index];
334
335         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
336                             MPOL_F_NODE | MPOL_F_ADDR);
337
338         /* check if we need to reallocate vq */
339         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
340                              MPOL_F_NODE | MPOL_F_ADDR);
341         if (ret) {
342                 RTE_LOG(ERR, VHOST_CONFIG,
343                         "Unable to get vq numa information.\n");
344                 return dev;
345         }
346         if (oldnode != newnode) {
347                 RTE_LOG(INFO, VHOST_CONFIG,
348                         "reallocate vq from %d to %d node\n", oldnode, newnode);
349                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
350                 if (!vq)
351                         return dev;
352
353                 memcpy(vq, old_vq, sizeof(*vq));
354                 TAILQ_INIT(&vq->zmbuf_list);
355
356                 new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
357                         sizeof(struct zcopy_mbuf), 0, newnode);
358                 if (new_zmbuf) {
359                         rte_free(vq->zmbufs);
360                         vq->zmbufs = new_zmbuf;
361                 }
362
363                 new_shadow_used_ring = rte_malloc_socket(NULL,
364                         vq->size * sizeof(struct vring_used_elem),
365                         RTE_CACHE_LINE_SIZE,
366                         newnode);
367                 if (new_shadow_used_ring) {
368                         rte_free(vq->shadow_used_ring);
369                         vq->shadow_used_ring = new_shadow_used_ring;
370                 }
371
372                 new_batch_copy_elems = rte_malloc_socket(NULL,
373                         vq->size * sizeof(struct batch_copy_elem),
374                         RTE_CACHE_LINE_SIZE,
375                         newnode);
376                 if (new_batch_copy_elems) {
377                         rte_free(vq->batch_copy_elems);
378                         vq->batch_copy_elems = new_batch_copy_elems;
379                 }
380
381                 rte_free(old_vq);
382         }
383
384         /* check if we need to reallocate dev */
385         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
386                             MPOL_F_NODE | MPOL_F_ADDR);
387         if (ret) {
388                 RTE_LOG(ERR, VHOST_CONFIG,
389                         "Unable to get dev numa information.\n");
390                 goto out;
391         }
392         if (oldnode != newnode) {
393                 RTE_LOG(INFO, VHOST_CONFIG,
394                         "reallocate dev from %d to %d node\n",
395                         oldnode, newnode);
396                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
397                 if (!dev) {
398                         dev = old_dev;
399                         goto out;
400                 }
401
402                 memcpy(dev, old_dev, sizeof(*dev));
403                 rte_free(old_dev);
404         }
405
406 out:
407         dev->virtqueue[index] = vq;
408         vhost_devices[dev->vid] = dev;
409
410         if (old_vq != vq)
411                 vhost_user_iotlb_init(dev, index);
412
413         return dev;
414 }
415 #else
416 static struct virtio_net*
417 numa_realloc(struct virtio_net *dev, int index __rte_unused)
418 {
419         return dev;
420 }
421 #endif
422
423 /* Converts QEMU virtual address to Vhost virtual address. */
424 static uint64_t
425 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
426 {
427         struct rte_vhost_mem_region *r;
428         uint32_t i;
429
430         /* Find the region where the address lives. */
431         for (i = 0; i < dev->mem->nregions; i++) {
432                 r = &dev->mem->regions[i];
433
434                 if (qva >= r->guest_user_addr &&
435                     qva <  r->guest_user_addr + r->size) {
436
437                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
438                                 *len = r->guest_user_addr + r->size - qva;
439
440                         return qva - r->guest_user_addr +
441                                r->host_user_addr;
442                 }
443         }
444         *len = 0;
445
446         return 0;
447 }
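
/*
 * In short: for an address inside region r the translation is simply
 * vva = qva - r->guest_user_addr + r->host_user_addr, and *len is clamped
 * to the bytes remaining in the region so callers can detect requests that
 * would cross a region boundary.
 */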
448
449
450 /*
451  * Converts ring address to Vhost virtual address.
452  * If IOMMU is enabled, the ring address is a guest IO virtual address,
453  * else it is a QEMU virtual address.
454  */
455 static uint64_t
456 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
457                 uint64_t ra, uint64_t *size)
458 {
459         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
460                 uint64_t vva;
461
462                 vva = vhost_user_iotlb_cache_find(vq, ra,
463                                         size, VHOST_ACCESS_RW);
464                 if (!vva)
465                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
466
467                 return vva;
468         }
469
470         return qva_to_vva(dev, ra, size);
471 }
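
/*
 * When VIRTIO_F_IOMMU_PLATFORM has been negotiated, a cache miss posts an
 * IOTLB miss request to the master and returns 0; the translation is retried
 * once the matching VHOST_IOTLB_UPDATE arrives (see vhost_user_iotlb_msg(),
 * which re-runs translate_ring_addresses() for the affected rings).
 */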
472
473 static struct virtio_net *
474 translate_ring_addresses(struct virtio_net *dev, int vq_index)
475 {
476         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
477         struct vhost_vring_addr *addr = &vq->ring_addrs;
478         uint64_t len;
479
480         /* The addresses are converted from QEMU virtual to Vhost virtual. */
481         if (vq->desc && vq->avail && vq->used)
482                 return dev;
483
484         len = sizeof(struct vring_desc) * vq->size;
485         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
486                         vq, addr->desc_user_addr, &len);
487         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
488                 RTE_LOG(DEBUG, VHOST_CONFIG,
489                         "(%d) failed to map desc ring.\n",
490                         dev->vid);
491                 return dev;
492         }
493
494         dev = numa_realloc(dev, vq_index);
495         vq = dev->virtqueue[vq_index];
496         addr = &vq->ring_addrs;
497
498         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
499         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
500                         vq, addr->avail_user_addr, &len);
501         if (vq->avail == 0 ||
502                         len != sizeof(struct vring_avail) +
503                         sizeof(uint16_t) * vq->size) {
504                 RTE_LOG(DEBUG, VHOST_CONFIG,
505                         "(%d) failed to map avail ring.\n",
506                         dev->vid);
507                 return dev;
508         }
509
510         len = sizeof(struct vring_used) +
511                 sizeof(struct vring_used_elem) * vq->size;
512         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
513                         vq, addr->used_user_addr, &len);
514         if (vq->used == 0 || len != sizeof(struct vring_used) +
515                         sizeof(struct vring_used_elem) * vq->size) {
516                 RTE_LOG(DEBUG, VHOST_CONFIG,
517                         "(%d) failed to map used ring.\n",
518                         dev->vid);
519                 return dev;
520         }
521
522         if (vq->last_used_idx != vq->used->idx) {
523                 RTE_LOG(WARNING, VHOST_CONFIG,
524                         "last_used_idx (%u) and vq->used->idx (%u) mismatch; "
525                         "some packets may be resent for Tx and dropped for Rx\n",
526                         vq->last_used_idx, vq->used->idx);
527                 vq->last_used_idx  = vq->used->idx;
528                 vq->last_avail_idx = vq->used->idx;
529         }
530
531         vq->log_guest_addr = addr->log_guest_addr;
532
533         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
534                         dev->vid, vq->desc);
535         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
536                         dev->vid, vq->avail);
537         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
538                         dev->vid, vq->used);
539         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
540                         dev->vid, vq->log_guest_addr);
541
542         return dev;
543 }
544
545 /*
546  * The virtio device sends us the desc, used and avail ring addresses.
547  * This function then converts these to our address space.
548  */
549 static int
550 vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
551 {
552         struct vhost_virtqueue *vq;
553         struct vhost_vring_addr *addr = &msg->payload.addr;
554         struct virtio_net *dev = *pdev;
555
556         if (dev->mem == NULL)
557                 return -1;
558
559         /* addr->index refers to the queue index: txq is 1, rxq is 0. */
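        /*
         * More generally, with the virtio-net queue layout, queue pair N maps
         * to ring index 2*N for rx and 2*N + 1 for tx.
         */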
560         vq = dev->virtqueue[msg->payload.addr.index];
561
562         /*
563          * Ring addresses should not be interpreted as long as the ring has not
564          * been started and enabled.
565          */
566         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
567
568         vring_invalidate(dev, vq);
569
570         if (vq->enabled && (dev->features &
571                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
572                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
573                 if (!dev)
574                         return -1;
575
576                 *pdev = dev;
577         }
578
579         return 0;
580 }
581
582 /*
583  * The virtio device sends us the starting index of the available ring.
584  */
585 static int
586 vhost_user_set_vring_base(struct virtio_net *dev,
587                           VhostUserMsg *msg)
588 {
589         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
590                         msg->payload.state.num;
591         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
592                         msg->payload.state.num;
593
594         return 0;
595 }
596
597 static int
598 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
599                    uint64_t host_phys_addr, uint64_t size)
600 {
601         struct guest_page *page, *last_page;
602
603         if (dev->nr_guest_pages == dev->max_guest_pages) {
604                 dev->max_guest_pages *= 2;
605                 dev->guest_pages = realloc(dev->guest_pages,
606                                         dev->max_guest_pages * sizeof(*page));
607                 if (!dev->guest_pages) {
608                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
609                         return -1;
610                 }
611         }
612
613         if (dev->nr_guest_pages > 0) {
614                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
615                 /* merge if the two pages are contiguous */
616                 if (host_phys_addr == last_page->host_phys_addr +
617                                       last_page->size) {
618                         last_page->size += size;
619                         return 0;
620                 }
621         }
622
623         page = &dev->guest_pages[dev->nr_guest_pages++];
624         page->guest_phys_addr = guest_phys_addr;
625         page->host_phys_addr  = host_phys_addr;
626         page->size = size;
627
628         return 0;
629 }
630
631 static int
632 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
633                 uint64_t page_size)
634 {
635         uint64_t reg_size = reg->size;
636         uint64_t host_user_addr  = reg->host_user_addr;
637         uint64_t guest_phys_addr = reg->guest_phys_addr;
638         uint64_t host_phys_addr;
639         uint64_t size;
640
641         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
642         size = page_size - (guest_phys_addr & (page_size - 1));
643         size = RTE_MIN(size, reg_size);
644
645         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
646                 return -1;
647
648         host_user_addr  += size;
649         guest_phys_addr += size;
650         reg_size -= size;
651
652         while (reg_size > 0) {
653                 size = RTE_MIN(reg_size, page_size);
654                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
655                                                   host_user_addr);
656                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
657                                 size) < 0)
658                         return -1;
659
660                 host_user_addr  += size;
661                 guest_phys_addr += size;
662                 reg_size -= size;
663         }
664
665         return 0;
666 }
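
/*
 * Example: a 2 MB-aligned region of 6 MB backed by 2 MB hugepages is recorded
 * as at most three 2 MB guest_page entries; entries whose host physical
 * addresses happen to be contiguous are merged by add_one_guest_page().
 */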
667
668 #ifdef RTE_LIBRTE_VHOST_DEBUG
669 /* TODO: enable it only in debug mode? */
670 static void
671 dump_guest_pages(struct virtio_net *dev)
672 {
673         uint32_t i;
674         struct guest_page *page;
675
676         for (i = 0; i < dev->nr_guest_pages; i++) {
677                 page = &dev->guest_pages[i];
678
679                 RTE_LOG(INFO, VHOST_CONFIG,
680                         "guest physical page region %u\n"
681                         "\t guest_phys_addr: %" PRIx64 "\n"
682                         "\t host_phys_addr : %" PRIx64 "\n"
683                         "\t size           : %" PRIx64 "\n",
684                         i,
685                         page->guest_phys_addr,
686                         page->host_phys_addr,
687                         page->size);
688         }
689 }
690 #else
691 #define dump_guest_pages(dev)
692 #endif
693
694 static bool
695 vhost_memory_changed(struct VhostUserMemory *new,
696                      struct rte_vhost_memory *old)
697 {
698         uint32_t i;
699
700         if (new->nregions != old->nregions)
701                 return true;
702
703         for (i = 0; i < new->nregions; ++i) {
704                 VhostUserMemoryRegion *new_r = &new->regions[i];
705                 struct rte_vhost_mem_region *old_r = &old->regions[i];
706
707                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
708                         return true;
709                 if (new_r->memory_size != old_r->size)
710                         return true;
711                 if (new_r->userspace_addr != old_r->guest_user_addr)
712                         return true;
713         }
714
715         return false;
716 }
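
/*
 * Only the guest physical address, size and guest virtual address of each
 * region are compared. When the table is unchanged, vhost_user_set_mem_table()
 * keeps the existing mappings and merely closes the duplicated fds.
 */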
717
718 static int
719 vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
720 {
721         struct virtio_net *dev = *pdev;
722         struct VhostUserMemory memory = pmsg->payload.memory;
723         struct rte_vhost_mem_region *reg;
724         void *mmap_addr;
725         uint64_t mmap_size;
726         uint64_t mmap_offset;
727         uint64_t alignment;
728         uint32_t i;
729         int populate;
730         int fd;
731
732         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
733                 RTE_LOG(ERR, VHOST_CONFIG,
734                         "too many memory regions (%u)\n", memory.nregions);
735                 return -1;
736         }
737
738         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
739                 RTE_LOG(INFO, VHOST_CONFIG,
740                         "(%d) memory regions not changed\n", dev->vid);
741
742                 for (i = 0; i < memory.nregions; i++)
743                         close(pmsg->fds[i]);
744
745                 return 0;
746         }
747
748         if (dev->mem) {
749                 free_mem_region(dev);
750                 rte_free(dev->mem);
751                 dev->mem = NULL;
752         }
753
754         dev->nr_guest_pages = 0;
755         if (!dev->guest_pages) {
756                 dev->max_guest_pages = 8;
757                 dev->guest_pages = malloc(dev->max_guest_pages *
758                                                 sizeof(struct guest_page));
759                 if (dev->guest_pages == NULL) {
760                         RTE_LOG(ERR, VHOST_CONFIG,
761                                 "(%d) failed to allocate memory "
762                                 "for dev->guest_pages\n",
763                                 dev->vid);
764                         return -1;
765                 }
766         }
767
768         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
769                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
770         if (dev->mem == NULL) {
771                 RTE_LOG(ERR, VHOST_CONFIG,
772                         "(%d) failed to allocate memory for dev->mem\n",
773                         dev->vid);
774                 return -1;
775         }
776         dev->mem->nregions = memory.nregions;
777
778         for (i = 0; i < memory.nregions; i++) {
779                 fd  = pmsg->fds[i];
780                 reg = &dev->mem->regions[i];
781
782                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
783                 reg->guest_user_addr = memory.regions[i].userspace_addr;
784                 reg->size            = memory.regions[i].memory_size;
785                 reg->fd              = fd;
786
787                 mmap_offset = memory.regions[i].mmap_offset;
788
789                 /* Check for memory_size + mmap_offset overflow */
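                /*
                 * In unsigned 64-bit arithmetic, -reg->size equals
                 * 2^64 - reg->size, so the condition below is true exactly
                 * when mmap_offset + reg->size would wrap around.
                 */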
790                 if (mmap_offset >= -reg->size) {
791                         RTE_LOG(ERR, VHOST_CONFIG,
792                                 "mmap_offset (%#"PRIx64") and memory_size "
793                                 "(%#"PRIx64") overflow\n",
794                                 mmap_offset, reg->size);
795                         goto err_mmap;
796                 }
797
798                 mmap_size = reg->size + mmap_offset;
799
800                 /* On older long-term Linux kernels (e.g. 2.6.32 and 3.2.72),
801                  * mmap() without the MAP_ANONYMOUS flag must be called with a
802                  * length argument aligned to the hugepage size, otherwise it
803                  * fails with EINVAL.
804                  *
805                  * To avoid that failure, make sure the length passed here is
806                  * kept aligned.
807                  */
808                 alignment = get_blk_size(fd);
809                 if (alignment == (uint64_t)-1) {
810                         RTE_LOG(ERR, VHOST_CONFIG,
811                                 "couldn't get hugepage size through fstat\n");
812                         goto err_mmap;
813                 }
814                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
815
816                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
817                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
818                                  MAP_SHARED | populate, fd, 0);
819
820                 if (mmap_addr == MAP_FAILED) {
821                         RTE_LOG(ERR, VHOST_CONFIG,
822                                 "mmap region %u failed.\n", i);
823                         goto err_mmap;
824                 }
825
826                 reg->mmap_addr = mmap_addr;
827                 reg->mmap_size = mmap_size;
828                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
829                                       mmap_offset;
830
831                 if (dev->dequeue_zero_copy)
832                         if (add_guest_pages(dev, reg, alignment) < 0) {
833                                 RTE_LOG(ERR, VHOST_CONFIG,
834                                         "adding guest pages to region %u failed.\n",
835                                         i);
836                                 goto err_mmap;
837                         }
838
839                 RTE_LOG(INFO, VHOST_CONFIG,
840                         "guest memory region %u, size: 0x%" PRIx64 "\n"
841                         "\t guest physical addr: 0x%" PRIx64 "\n"
842                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
843                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
844                         "\t mmap addr : 0x%" PRIx64 "\n"
845                         "\t mmap size : 0x%" PRIx64 "\n"
846                         "\t mmap align: 0x%" PRIx64 "\n"
847                         "\t mmap off  : 0x%" PRIx64 "\n",
848                         i, reg->size,
849                         reg->guest_phys_addr,
850                         reg->guest_user_addr,
851                         reg->host_user_addr,
852                         (uint64_t)(uintptr_t)mmap_addr,
853                         mmap_size,
854                         alignment,
855                         mmap_offset);
856         }
857
858         for (i = 0; i < dev->nr_vring; i++) {
859                 struct vhost_virtqueue *vq = dev->virtqueue[i];
860
861                 if (vq->desc || vq->avail || vq->used) {
862                         /*
863                          * If the memory table got updated, the ring addresses
864                          * need to be translated again as virtual addresses have
865                          * changed.
866                          */
867                         vring_invalidate(dev, vq);
868
869                         dev = translate_ring_addresses(dev, i);
870                         if (!dev)
871                                 return -1;
872
873                         *pdev = dev;
874                 }
875         }
876
877         dump_guest_pages(dev);
878
879         return 0;
880
881 err_mmap:
882         free_mem_region(dev);
883         rte_free(dev->mem);
884         dev->mem = NULL;
885         return -1;
886 }
887
888 static int
889 vq_is_ready(struct vhost_virtqueue *vq)
890 {
891         return vq && vq->desc && vq->avail && vq->used &&
892                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
893                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
894 }
895
896 static int
897 virtio_is_ready(struct virtio_net *dev)
898 {
899         struct vhost_virtqueue *vq;
900         uint32_t i;
901
902         if (dev->nr_vring == 0)
903                 return 0;
904
905         for (i = 0; i < dev->nr_vring; i++) {
906                 vq = dev->virtqueue[i];
907
908                 if (!vq_is_ready(vq))
909                         return 0;
910         }
911
912         RTE_LOG(INFO, VHOST_CONFIG,
913                 "virtio is now ready for processing.\n");
914         return 1;
915 }
916
917 static void
918 vhost_user_set_vring_call(struct virtio_net *dev, struct VhostUserMsg *pmsg)
919 {
920         struct vhost_vring_file file;
921         struct vhost_virtqueue *vq;
922
923         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
924         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
925                 file.fd = VIRTIO_INVALID_EVENTFD;
926         else
927                 file.fd = pmsg->fds[0];
928         RTE_LOG(INFO, VHOST_CONFIG,
929                 "vring call idx:%d file:%d\n", file.index, file.fd);
930
931         vq = dev->virtqueue[file.index];
932         if (vq->callfd >= 0)
933                 close(vq->callfd);
934
935         vq->callfd = file.fd;
936 }
937
938 static void
939 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
940 {
941         struct vhost_vring_file file;
942         struct vhost_virtqueue *vq;
943         struct virtio_net *dev = *pdev;
944
945         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
946         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
947                 file.fd = VIRTIO_INVALID_EVENTFD;
948         else
949                 file.fd = pmsg->fds[0];
950         RTE_LOG(INFO, VHOST_CONFIG,
951                 "vring kick idx:%d file:%d\n", file.index, file.fd);
952
953         /* Interpret ring addresses only when ring is started. */
954         dev = translate_ring_addresses(dev, file.index);
955         if (!dev)
956                 return;
957
958         *pdev = dev;
959
960         vq = dev->virtqueue[file.index];
961
962         /*
963          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
964          * the ring starts already enabled. Otherwise, it is enabled via
965          * the SET_VRING_ENABLE message.
966          */
967         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
968                 vq->enabled = 1;
969
970         if (vq->kickfd >= 0)
971                 close(vq->kickfd);
972         vq->kickfd = file.fd;
973 }
974
975 static void
976 free_zmbufs(struct vhost_virtqueue *vq)
977 {
978         struct zcopy_mbuf *zmbuf, *next;
979
980         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
981              zmbuf != NULL; zmbuf = next) {
982                 next = TAILQ_NEXT(zmbuf, next);
983
984                 rte_pktmbuf_free(zmbuf->mbuf);
985                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
986         }
987
988         rte_free(vq->zmbufs);
989 }
990
991 /*
992  * When virtio is stopped, QEMU sends us the GET_VRING_BASE message.
993  */
994 static int
995 vhost_user_get_vring_base(struct virtio_net *dev,
996                           VhostUserMsg *msg)
997 {
998         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
999         struct rte_vdpa_device *vdpa_dev;
1000         int did = -1;
1001
1002         /* We have to stop the queue (virtio) if it is running. */
1003         if (dev->flags & VIRTIO_DEV_RUNNING) {
1004                 did = dev->vdpa_dev_id;
1005                 vdpa_dev = rte_vdpa_get_device(did);
1006                 if (vdpa_dev && vdpa_dev->ops->dev_close)
1007                         vdpa_dev->ops->dev_close(dev->vid);
1008                 dev->flags &= ~VIRTIO_DEV_RUNNING;
1009                 dev->notify_ops->destroy_device(dev->vid);
1010         }
1011
1012         dev->flags &= ~VIRTIO_DEV_READY;
1013         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
1014
1015         /* Here we are safe to get the last avail index */
1016         msg->payload.state.num = vq->last_avail_idx;
1017
1018         RTE_LOG(INFO, VHOST_CONFIG,
1019                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1020                 msg->payload.state.num);
1021         /*
1022          * Based on the current QEMU vhost-user implementation, this message
1023          * is sent from, and only from, vhost_vring_stop().
1024          * TODO: clean up the vring; it is not usable from this point on.
1025          */
1026         if (vq->kickfd >= 0)
1027                 close(vq->kickfd);
1028
1029         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1030
1031         if (vq->callfd >= 0)
1032                 close(vq->callfd);
1033
1034         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1035
1036         if (dev->dequeue_zero_copy)
1037                 free_zmbufs(vq);
1038         rte_free(vq->shadow_used_ring);
1039         vq->shadow_used_ring = NULL;
1040
1041         rte_free(vq->batch_copy_elems);
1042         vq->batch_copy_elems = NULL;
1043
1044         return 0;
1045 }
1046
1047 /*
1048  * When the virtio queues are ready to work, QEMU sends us a message
1049  * to enable the virtio queue pair.
1050  */
1051 static int
1052 vhost_user_set_vring_enable(struct virtio_net *dev,
1053                             VhostUserMsg *msg)
1054 {
1055         int enable = (int)msg->payload.state.num;
1056         int index = (int)msg->payload.state.index;
1057         struct rte_vdpa_device *vdpa_dev;
1058         int did = -1;
1059
1060         RTE_LOG(INFO, VHOST_CONFIG,
1061                 "set queue enable: %d to qp idx: %d\n",
1062                 enable, index);
1063
1064         did = dev->vdpa_dev_id;
1065         vdpa_dev = rte_vdpa_get_device(did);
1066         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1067                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1068
1069         if (dev->notify_ops->vring_state_changed)
1070                 dev->notify_ops->vring_state_changed(dev->vid,
1071                                 index, enable);
1072
1073         dev->virtqueue[index]->enabled = enable;
1074
1075         return 0;
1076 }
1077
1078 static void
1079 vhost_user_get_protocol_features(struct virtio_net *dev,
1080                                  struct VhostUserMsg *msg)
1081 {
1082         uint64_t features, protocol_features;
1083
1084         rte_vhost_driver_get_features(dev->ifname, &features);
1085         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1086
1087         /*
1088          * The REPLY_ACK protocol feature is for now only mandatory for the
1089          * IOMMU feature. If IOMMU is explicitly disabled by the application,
1090          * also disable the REPLY_ACK feature to work around older buggy
1091          * QEMU versions (from v2.7.0 to v2.9.0).
1092          */
1093         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1094                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1095
1096         msg->payload.u64 = protocol_features;
1097         msg->size = sizeof(msg->payload.u64);
1098 }
1099
1100 static void
1101 vhost_user_set_protocol_features(struct virtio_net *dev,
1102                                  uint64_t protocol_features)
1103 {
1104         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES)
1105                 return;
1106
1107         dev->protocol_features = protocol_features;
1108 }
1109
1110 static int
1111 vhost_user_set_log_base(struct virtio_net *dev, struct VhostUserMsg *msg)
1112 {
1113         int fd = msg->fds[0];
1114         uint64_t size, off;
1115         void *addr;
1116
1117         if (fd < 0) {
1118                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1119                 return -1;
1120         }
1121
1122         if (msg->size != sizeof(VhostUserLog)) {
1123                 RTE_LOG(ERR, VHOST_CONFIG,
1124                         "invalid log base msg size: %"PRId32" != %d\n",
1125                         msg->size, (int)sizeof(VhostUserLog));
1126                 return -1;
1127         }
1128
1129         size = msg->payload.log.mmap_size;
1130         off  = msg->payload.log.mmap_offset;
1131
1132         /* Don't allow mmap_offset to point outside the mmap region */
1133         if (off > size) {
1134                 RTE_LOG(ERR, VHOST_CONFIG,
1135                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1136                         off, size);
1137                 return -1;
1138         }
1139
1140         RTE_LOG(INFO, VHOST_CONFIG,
1141                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1142                 size, off);
1143
1144         /*
1145          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1146          * fail when the offset is not page-size aligned.
1147          */
1148         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1149         close(fd);
1150         if (addr == MAP_FAILED) {
1151                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1152                 return -1;
1153         }
1154
1155         /*
1156          * Free any previously mapped log memory in case multiple
1157          * VHOST_USER_SET_LOG_BASE messages are received.
1158          */
1159         if (dev->log_addr) {
1160                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1161         }
1162         dev->log_addr = (uint64_t)(uintptr_t)addr;
1163         dev->log_base = dev->log_addr + off;
1164         dev->log_size = size;
1165
1166         return 0;
1167 }
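
/*
 * The mapped area serves as the dirty-page log used during live migration
 * (cf. the VHOST_F_LOG_ALL note in vhost_user_set_features()); dev->log_base
 * points 'off' bytes into the mapping.
 */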
1168
1169 /*
1170  * A RARP packet is constructed and broadcast to notify switches about
1171  * the new location of the migrated VM, so that packets from outside will
1172  * not be lost after migration.
1173  *
1174  * However, we don't actually "send" a RARP packet here; instead, we set
1175  * a flag 'broadcast_rarp' to let rte_vhost_dequeue_burst() inject it.
1176  */
1177 static int
1178 vhost_user_send_rarp(struct virtio_net *dev, struct VhostUserMsg *msg)
1179 {
1180         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1181         struct rte_vdpa_device *vdpa_dev;
1182         int did = -1;
1183
1184         RTE_LOG(DEBUG, VHOST_CONFIG,
1185                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1186                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1187         memcpy(dev->mac.addr_bytes, mac, 6);
1188
1189         /*
1190          * Set the flag to inject a RARP broadcast packet at
1191          * rte_vhost_dequeue_burst().
1192          *
1193          * rte_smp_wmb() is for making sure the mac is copied
1194          * before the flag is set.
1195          */
1196         rte_smp_wmb();
1197         rte_atomic16_set(&dev->broadcast_rarp, 1);
1198         did = dev->vdpa_dev_id;
1199         vdpa_dev = rte_vdpa_get_device(did);
1200         if (vdpa_dev && vdpa_dev->ops->migration_done)
1201                 vdpa_dev->ops->migration_done(dev->vid);
1202
1203         return 0;
1204 }
1205
1206 static int
1207 vhost_user_net_set_mtu(struct virtio_net *dev, struct VhostUserMsg *msg)
1208 {
1209         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1210                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1211                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1212                                 msg->payload.u64);
1213
1214                 return -1;
1215         }
1216
1217         dev->mtu = msg->payload.u64;
1218
1219         return 0;
1220 }
1221
1222 static int
1223 vhost_user_set_req_fd(struct virtio_net *dev, struct VhostUserMsg *msg)
1224 {
1225         int fd = msg->fds[0];
1226
1227         if (fd < 0) {
1228                 RTE_LOG(ERR, VHOST_CONFIG,
1229                                 "Invalid file descriptor for slave channel (%d)\n",
1230                                 fd);
1231                 return -1;
1232         }
1233
1234         dev->slave_req_fd = fd;
1235
1236         return 0;
1237 }
1238
1239 static int
1240 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1241 {
1242         struct vhost_vring_addr *ra;
1243         uint64_t start, end;
1244
1245         start = imsg->iova;
1246         end = start + imsg->size;
1247
1248         ra = &vq->ring_addrs;
1249         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1250                 return 1;
1251         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1252                 return 1;
1253         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1254                 return 1;
1255
1256         return 0;
1257 }
1258
1259 static int
1260 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1261                                 struct vhost_iotlb_msg *imsg)
1262 {
1263         uint64_t istart, iend, vstart, vend;
1264
1265         istart = imsg->iova;
1266         iend = istart + imsg->size - 1;
1267
1268         vstart = (uintptr_t)vq->desc;
1269         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1270         if (vstart <= iend && istart <= vend)
1271                 return 1;
1272
1273         vstart = (uintptr_t)vq->avail;
1274         vend = vstart + sizeof(struct vring_avail);
1275         vend += sizeof(uint16_t) * vq->size - 1;
1276         if (vstart <= iend && istart <= vend)
1277                 return 1;
1278
1279         vstart = (uintptr_t)vq->used;
1280         vend = vstart + sizeof(struct vring_used);
1281         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1282         if (vstart <= iend && istart <= vend)
1283                 return 1;
1284
1285         return 0;
1286 }
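
/*
 * Each test above is the usual closed-interval overlap check: ranges
 * [istart, iend] and [vstart, vend] intersect iff vstart <= iend and
 * istart <= vend.
 */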
1287
1288 static int
1289 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1290 {
1291         struct virtio_net *dev = *pdev;
1292         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1293         uint16_t i;
1294         uint64_t vva, len;
1295
1296         switch (imsg->type) {
1297         case VHOST_IOTLB_UPDATE:
1298                 len = imsg->size;
1299                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1300                 if (!vva)
1301                         return -1;
1302
1303                 for (i = 0; i < dev->nr_vring; i++) {
1304                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1305
1306                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1307                                         len, imsg->perm);
1308
1309                         if (is_vring_iotlb_update(vq, imsg))
1310                                 *pdev = dev = translate_ring_addresses(dev, i);
1311                 }
1312                 break;
1313         case VHOST_IOTLB_INVALIDATE:
1314                 for (i = 0; i < dev->nr_vring; i++) {
1315                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1316
1317                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1318                                         imsg->size);
1319
1320                         if (is_vring_iotlb_invalidate(vq, imsg))
1321                                 vring_invalidate(dev, vq);
1322                 }
1323                 break;
1324         default:
1325                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1326                                 imsg->type);
1327                 return -1;
1328         }
1329
1330         return 0;
1331 }
1332
1333 /* Return the number of bytes read on success, or a negative value on failure. */
1334 static int
1335 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1336 {
1337         int ret;
1338
1339         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1340                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1341         if (ret <= 0)
1342                 return ret;
1343
1344         if (msg && msg->size) {
1345                 if (msg->size > sizeof(msg->payload)) {
1346                         RTE_LOG(ERR, VHOST_CONFIG,
1347                                 "invalid msg size: %d\n", msg->size);
1348                         return -1;
1349                 }
1350                 ret = read(sockfd, &msg->payload, msg->size);
1351                 if (ret <= 0)
1352                         return ret;
1353                 if (ret != (int)msg->size) {
1354                         RTE_LOG(ERR, VHOST_CONFIG,
1355                                 "read control message failed\n");
1356                         return -1;
1357                 }
1358         }
1359
1360         return ret;
1361 }
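
/*
 * The fixed-size header (plus any ancillary fds) is read first, then the
 * payload; capping msg->size at sizeof(msg->payload) prevents a malicious
 * master from overflowing the message buffer, per the security model above.
 */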
1362
1363 static int
1364 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1365 {
1366         if (!msg)
1367                 return 0;
1368
1369         return send_fd_message(sockfd, (char *)msg,
1370                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1371 }
1372
1373 static int
1374 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1375 {
1376         if (!msg)
1377                 return 0;
1378
1379         msg->flags &= ~VHOST_USER_VERSION_MASK;
1380         msg->flags &= ~VHOST_USER_NEED_REPLY;
1381         msg->flags |= VHOST_USER_VERSION;
1382         msg->flags |= VHOST_USER_REPLY_MASK;
1383
1384         return send_vhost_message(sockfd, msg, NULL, 0);
1385 }
1386
1387 static int
1388 send_vhost_slave_message(struct virtio_net *dev, struct VhostUserMsg *msg,
1389                          int *fds, int fd_num)
1390 {
1391         int ret;
1392
1393         if (msg->flags & VHOST_USER_NEED_REPLY)
1394                 rte_spinlock_lock(&dev->slave_req_lock);
1395
1396         ret = send_vhost_message(dev->slave_req_fd, msg, fds, fd_num);
1397         if (ret < 0 && (msg->flags & VHOST_USER_NEED_REPLY))
1398                 rte_spinlock_unlock(&dev->slave_req_lock);
1399
1400         return ret;
1401 }
1402
1403 /*
1404  * Allocate a queue pair if it hasn't been allocated yet
1405  */
1406 static int
1407 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
1408 {
1409         uint16_t vring_idx;
1410
1411         switch (msg->request.master) {
1412         case VHOST_USER_SET_VRING_KICK:
1413         case VHOST_USER_SET_VRING_CALL:
1414         case VHOST_USER_SET_VRING_ERR:
1415                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1416                 break;
1417         case VHOST_USER_SET_VRING_NUM:
1418         case VHOST_USER_SET_VRING_BASE:
1419         case VHOST_USER_SET_VRING_ENABLE:
1420                 vring_idx = msg->payload.state.index;
1421                 break;
1422         case VHOST_USER_SET_VRING_ADDR:
1423                 vring_idx = msg->payload.addr.index;
1424                 break;
1425         default:
1426                 return 0;
1427         }
1428
1429         if (vring_idx >= VHOST_MAX_VRING) {
1430                 RTE_LOG(ERR, VHOST_CONFIG,
1431                         "invalid vring index: %u\n", vring_idx);
1432                 return -1;
1433         }
1434
1435         if (dev->virtqueue[vring_idx])
1436                 return 0;
1437
1438         return alloc_vring_queue(dev, vring_idx);
1439 }
1440
1441 static void
1442 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1443 {
1444         unsigned int i = 0;
1445         unsigned int vq_num = 0;
1446
1447         while (vq_num < dev->nr_vring) {
1448                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1449
1450                 if (vq) {
1451                         rte_spinlock_lock(&vq->access_lock);
1452                         vq_num++;
1453                 }
1454                 i++;
1455         }
1456 }
1457
1458 static void
1459 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1460 {
1461         unsigned int i = 0;
1462         unsigned int vq_num = 0;
1463
1464         while (vq_num < dev->nr_vring) {
1465                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1466
1467                 if (vq) {
1468                         rte_spinlock_unlock(&vq->access_lock);
1469                         vq_num++;
1470                 }
1471                 i++;
1472         }
1473 }
1474
1475 int
1476 vhost_user_msg_handler(int vid, int fd)
1477 {
1478         struct virtio_net *dev;
1479         struct VhostUserMsg msg;
1480         struct rte_vdpa_device *vdpa_dev;
1481         int did = -1;
1482         int ret;
1483         int unlock_required = 0;
1484         uint32_t skip_master = 0;
1485
1486         dev = get_device(vid);
1487         if (dev == NULL)
1488                 return -1;
1489
1490         if (!dev->notify_ops) {
1491                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1492                 if (!dev->notify_ops) {
1493                         RTE_LOG(ERR, VHOST_CONFIG,
1494                                 "failed to get callback ops for driver %s\n",
1495                                 dev->ifname);
1496                         return -1;
1497                 }
1498         }
1499
1500         ret = read_vhost_message(fd, &msg);
1501         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1502                 if (ret < 0)
1503                         RTE_LOG(ERR, VHOST_CONFIG,
1504                                 "vhost read message failed\n");
1505                 else if (ret == 0)
1506                         RTE_LOG(INFO, VHOST_CONFIG,
1507                                 "vhost peer closed\n");
1508                 else
1509                         RTE_LOG(ERR, VHOST_CONFIG,
1510                                 "vhost read incorrect message\n");
1511
1512                 return -1;
1513         }
1514
1515         ret = 0;
1516         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1517                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1518                         vhost_message_str[msg.request.master]);
1519         else
1520                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1521                         vhost_message_str[msg.request.master]);
1522
1523         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1524         if (ret < 0) {
1525                 RTE_LOG(ERR, VHOST_CONFIG,
1526                         "failed to alloc queue\n");
1527                 return -1;
1528         }
1529
1530         /*
1531          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1532          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1533          * and the device is destroyed. destroy_device() waits for the queues
1534          * to become inactive, so it is safe. Otherwise, taking access_lock
1535          * would cause a deadlock.
1536          */
1537         switch (msg.request.master) {
1538         case VHOST_USER_SET_FEATURES:
1539         case VHOST_USER_SET_PROTOCOL_FEATURES:
1540         case VHOST_USER_SET_OWNER:
1541         case VHOST_USER_SET_MEM_TABLE:
1542         case VHOST_USER_SET_LOG_BASE:
1543         case VHOST_USER_SET_LOG_FD:
1544         case VHOST_USER_SET_VRING_NUM:
1545         case VHOST_USER_SET_VRING_ADDR:
1546         case VHOST_USER_SET_VRING_BASE:
1547         case VHOST_USER_SET_VRING_KICK:
1548         case VHOST_USER_SET_VRING_CALL:
1549         case VHOST_USER_SET_VRING_ERR:
1550         case VHOST_USER_SET_VRING_ENABLE:
1551         case VHOST_USER_SEND_RARP:
1552         case VHOST_USER_NET_SET_MTU:
1553         case VHOST_USER_SET_SLAVE_REQ_FD:
1554                 vhost_user_lock_all_queue_pairs(dev);
1555                 unlock_required = 1;
1556                 break;
1557         default:
1558                 break;
1559         }
1561
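             /*
              * External pre/post handlers (dev->extern_ops), e.g. the ones
              * registered by the vhost crypto backend, may process the
              * request before and after the built-in handling below; when
              * the pre-handler sets skip_master it takes over the request
              * entirely.
              */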
1562         if (dev->extern_ops.pre_msg_handle) {
1563                 uint32_t need_reply;
1564
1565                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1566                                 (void *)&msg, &need_reply, &skip_master);
1567                 if (ret < 0)
1568                         goto skip_to_reply;
1569
1570                 if (need_reply)
1571                         send_vhost_reply(fd, &msg);
1572
1573                 if (skip_master)
1574                         goto skip_to_post_handle;
1575         }
1576
1577         switch (msg.request.master) {
1578         case VHOST_USER_GET_FEATURES:
1579                 msg.payload.u64 = vhost_user_get_features(dev);
1580                 msg.size = sizeof(msg.payload.u64);
1581                 send_vhost_reply(fd, &msg);
1582                 break;
1583         case VHOST_USER_SET_FEATURES:
1584                 ret = vhost_user_set_features(dev, msg.payload.u64);
1585                 if (ret) {
                             /* drop the queue locks taken above before bailing out */
                             vhost_user_unlock_all_queue_pairs(dev);
1586                         return -1;
                     }
1587                 break;
1588
1589         case VHOST_USER_GET_PROTOCOL_FEATURES:
1590                 vhost_user_get_protocol_features(dev, &msg);
1591                 send_vhost_reply(fd, &msg);
1592                 break;
1593         case VHOST_USER_SET_PROTOCOL_FEATURES:
1594                 vhost_user_set_protocol_features(dev, msg.payload.u64);
1595                 break;
1596
1597         case VHOST_USER_SET_OWNER:
1598                 vhost_user_set_owner();
1599                 break;
1600         case VHOST_USER_RESET_OWNER:
1601                 vhost_user_reset_owner(dev);
1602                 break;
1603
1604         case VHOST_USER_SET_MEM_TABLE:
1605                 ret = vhost_user_set_mem_table(&dev, &msg);
1606                 break;
1607
1608         case VHOST_USER_SET_LOG_BASE:
1609                 vhost_user_set_log_base(dev, &msg);
1610
1611                 /* it needs a reply */
1612                 msg.size = sizeof(msg.payload.u64);
1613                 send_vhost_reply(fd, &msg);
1614                 break;
1615         case VHOST_USER_SET_LOG_FD:
1616                 close(msg.fds[0]);
1617                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1618                 break;
1619
1620         case VHOST_USER_SET_VRING_NUM:
1621                 vhost_user_set_vring_num(dev, &msg);
1622                 break;
1623         case VHOST_USER_SET_VRING_ADDR:
1624                 vhost_user_set_vring_addr(&dev, &msg);
1625                 break;
1626         case VHOST_USER_SET_VRING_BASE:
1627                 vhost_user_set_vring_base(dev, &msg);
1628                 break;
1629
1630         case VHOST_USER_GET_VRING_BASE:
1631                 vhost_user_get_vring_base(dev, &msg);
1632                 msg.size = sizeof(msg.payload.state);
1633                 send_vhost_reply(fd, &msg);
1634                 break;
1635
1636         case VHOST_USER_SET_VRING_KICK:
1637                 vhost_user_set_vring_kick(&dev, &msg);
1638                 break;
1639         case VHOST_USER_SET_VRING_CALL:
1640                 vhost_user_set_vring_call(dev, &msg);
1641                 break;
1642
1643         case VHOST_USER_SET_VRING_ERR:
1644                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1645                         close(msg.fds[0]);
1646                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1647                 break;
1648
1649         case VHOST_USER_GET_QUEUE_NUM:
1650                 msg.payload.u64 = (uint64_t)vhost_user_get_queue_num(dev);
1651                 msg.size = sizeof(msg.payload.u64);
1652                 send_vhost_reply(fd, &msg);
1653                 break;
1654
1655         case VHOST_USER_SET_VRING_ENABLE:
1656                 vhost_user_set_vring_enable(dev, &msg);
1657                 break;
1658         case VHOST_USER_SEND_RARP:
1659                 vhost_user_send_rarp(dev, &msg);
1660                 break;
1661
1662         case VHOST_USER_NET_SET_MTU:
1663                 ret = vhost_user_net_set_mtu(dev, &msg);
1664                 break;
1665
1666         case VHOST_USER_SET_SLAVE_REQ_FD:
1667                 ret = vhost_user_set_req_fd(dev, &msg);
1668                 break;
1669
1670         case VHOST_USER_IOTLB_MSG:
1671                 ret = vhost_user_iotlb_msg(&dev, &msg);
1672                 break;
1673
1674         default:
1675                 ret = -1;
1676                 break;
1677         }
1678
1679 skip_to_post_handle:
1680         if (dev->extern_ops.post_msg_handle) {
1681                 uint32_t need_reply;
1682
1683                 ret = (*dev->extern_ops.post_msg_handle)(
1684                                 dev->vid, (void *)&msg, &need_reply);
1685                 if (ret < 0)
1686                         goto skip_to_reply;
1687
1688                 if (need_reply)
1689                         send_vhost_reply(fd, &msg);
1690         }
1691
1692 skip_to_reply:
1693         if (unlock_required)
1694                 vhost_user_unlock_all_queue_pairs(dev);
1695
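             /*
              * When the master requested an explicit ack
              * (VHOST_USER_NEED_REPLY), report 0 on success and 1 on
              * failure in the payload.
              */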
1696         if (msg.flags & VHOST_USER_NEED_REPLY) {
1697                 msg.payload.u64 = !!ret;
1698                 msg.size = sizeof(msg.payload.u64);
1699                 send_vhost_reply(fd, &msg);
1700         }
1701
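             /*
              * Once every ring is configured (virtio_is_ready()), mark the
              * device ready and notify the application through the
              * new_device() callback so it can start the datapath.
              */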
1702         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1703                 dev->flags |= VIRTIO_DEV_READY;
1704
1705                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1706                         if (dev->dequeue_zero_copy) {
1707                                 RTE_LOG(INFO, VHOST_CONFIG,
1708                                                 "dequeue zero copy is enabled\n");
1709                         }
1710
1711                         if (dev->notify_ops->new_device(dev->vid) == 0)
1712                                 dev->flags |= VIRTIO_DEV_RUNNING;
1713                 }
1714         }
1715
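             /*
              * For a device bound to a vDPA backend, apply the backend
              * configuration once the device is ready and SET_VRING_ENABLE
              * has arrived, then try to install host notifiers; if that
              * fails, traffic goes through the slower software relay.
              */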
1716         did = dev->vdpa_dev_id;
1717         vdpa_dev = rte_vdpa_get_device(did);
1718         if (vdpa_dev && virtio_is_ready(dev) &&
1719                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1720                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1721                 if (vdpa_dev->ops->dev_conf)
1722                         vdpa_dev->ops->dev_conf(dev->vid);
1723                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1724                 if (vhost_user_host_notifier_ctrl(dev->vid, true) != 0) {
1725                         RTE_LOG(INFO, VHOST_CONFIG,
1726                                 "(%d) software relay is used for vDPA, performance may be low.\n",
1727                                 dev->vid);
1728                 }
1729         }
1730
1731         return 0;
1732 }
1733
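     /*
      * When the slave request was sent with VHOST_USER_NEED_REPLY, wait
      * for the master's ack on the slave channel and release the
      * slave_req_lock held across the request/reply exchange.
      */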
1734 static int process_slave_message_reply(struct virtio_net *dev,
1735                                        const VhostUserMsg *msg)
1736 {
1737         VhostUserMsg msg_reply;
1738         int ret;
1739
1740         if ((msg->flags & VHOST_USER_NEED_REPLY) == 0)
1741                 return 0;
1742
1743         if (read_vhost_message(dev->slave_req_fd, &msg_reply) < 0) {
1744                 ret = -1;
1745                 goto out;
1746         }
1747
1748         if (msg_reply.request.slave != msg->request.slave) {
1749                 RTE_LOG(ERR, VHOST_CONFIG,
1750                         "Received unexpected msg type (%u), expected %u\n",
1751                         msg_reply.request.slave, msg->request.slave);
1752                 ret = -1;
1753                 goto out;
1754         }
1755
1756         ret = msg_reply.payload.u64 ? -1 : 0;
1757
1758 out:
1759         rte_spinlock_unlock(&dev->slave_req_lock);
1760         return ret;
1761 }
1762
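     /*
      * Report an IOTLB miss to the master over the slave channel so that
      * it can provide the missing IOVA translation with the requested
      * permissions.
      */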
1763 int
1764 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1765 {
1766         int ret;
1767         struct VhostUserMsg msg = {
1768                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1769                 .flags = VHOST_USER_VERSION,
1770                 .size = sizeof(msg.payload.iotlb),
1771                 .payload.iotlb = {
1772                         .iova = iova,
1773                         .perm = perm,
1774                         .type = VHOST_IOTLB_MISS,
1775                 },
1776         };
1777
1778         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1779         if (ret < 0) {
1780                 RTE_LOG(ERR, VHOST_CONFIG,
1781                                 "Failed to send IOTLB miss message (%d)\n",
1782                                 ret);
1783                 return ret;
1784         }
1785
1786         return 0;
1787 }
1788
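     /*
      * Ask the master to map (or unmap, when fd is -1) the host notifier
      * area of one vring, so guest notifications can reach the vDPA
      * hardware directly.  The request is sent with VHOST_USER_NEED_REPLY
      * and the master's answer is checked.
      */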
1789 static int vhost_user_slave_set_vring_host_notifier(struct virtio_net *dev,
1790                                                     int index, int fd,
1791                                                     uint64_t offset,
1792                                                     uint64_t size)
1793 {
1794         int *fdp = NULL;
1795         size_t fd_num = 0;
1796         int ret;
1797         struct VhostUserMsg msg = {
1798                 .request.slave = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1799                 .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY,
1800                 .size = sizeof(msg.payload.area),
1801                 .payload.area = {
1802                         .u64 = index & VHOST_USER_VRING_IDX_MASK,
1803                         .size = size,
1804                         .offset = offset,
1805                 },
1806         };
1807
1808         if (fd < 0)
1809                 msg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1810         else {
1811                 fdp = &fd;
1812                 fd_num = 1;
1813         }
1814
1815         ret = send_vhost_slave_message(dev, &msg, fdp, fd_num);
1816         if (ret < 0) {
1817                 RTE_LOG(ERR, VHOST_CONFIG,
1818                         "Failed to set host notifier (%d)\n", ret);
1819                 return ret;
1820         }
1821
1822         return process_slave_message_reply(dev, &msg);
1823 }
1824
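     /*
      * Enable or disable host notifiers for all vrings of a vDPA device.
      * Requires the slave channel plus the HOST_NOTIFIER and SLAVE_SEND_FD
      * protocol features; if enabling fails for any ring, the notifiers
      * are disabled again for all rings.
      */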
1825 int vhost_user_host_notifier_ctrl(int vid, bool enable)
1826 {
1827         struct virtio_net *dev;
1828         struct rte_vdpa_device *vdpa_dev;
1829         int vfio_device_fd, did, ret = 0;
1830         uint64_t offset, size;
1831         unsigned int i;
1832
1833         dev = get_device(vid);
1834         if (!dev)
1835                 return -ENODEV;
1836
1837         did = dev->vdpa_dev_id;
1838         if (did < 0)
1839                 return -EINVAL;
1840
1841         if (!(dev->features & (1ULL << VIRTIO_F_VERSION_1)) ||
1842             !(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)) ||
1843             !(dev->protocol_features &
1844                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) ||
1845             !(dev->protocol_features &
1846                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) ||
1847             !(dev->protocol_features &
1848                         (1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER)))
1849                 return -ENOTSUP;
1850
1851         vdpa_dev = rte_vdpa_get_device(did);
             if (!vdpa_dev)
                     return -ENODEV;
1852
1853         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_vfio_device_fd, -ENOTSUP);
1854         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_notify_area, -ENOTSUP);
1855
1856         vfio_device_fd = vdpa_dev->ops->get_vfio_device_fd(vid);
1857         if (vfio_device_fd < 0)
1858                 return -ENOTSUP;
1859
1860         if (enable) {
1861                 for (i = 0; i < dev->nr_vring; i++) {
1862                         if (vdpa_dev->ops->get_notify_area(vid, i, &offset,
1863                                         &size) < 0) {
1864                                 ret = -ENOTSUP;
1865                                 goto disable;
1866                         }
1867
1868                         if (vhost_user_slave_set_vring_host_notifier(dev, i,
1869                                         vfio_device_fd, offset, size) < 0) {
1870                                 ret = -EFAULT;
1871                                 goto disable;
1872                         }
1873                 }
1874         } else {
1875 disable:
1876                 for (i = 0; i < dev->nr_vring; i++) {
1877                         vhost_user_slave_set_vring_host_notifier(dev, i, -1,
1878                                         0, 0);
1879                 }
1880         }
1881
1882         return ret;
1883 }