vhost: fix vDPA set features
[dpdk.git] / lib / librte_vhost / vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped will then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
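/*
 * Return the block size reported by fstat() for the given fd (the hugepage
 * size for hugetlbfs-backed memory), or (uint64_t)-1 on failure.
 */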
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
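/*
 * Unmap and close every guest memory region currently attached to the device.
 */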
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
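/*
 * Release the resources held by the vhost-user backend for this device:
 * guest memory regions, the guest page array, the dirty log mapping and
 * the slave request fd.
 */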
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * This function just returns success at the moment unless
127  * the device hasn't been initialised.
128  */
129 static int
130 vhost_user_set_owner(void)
131 {
132         return 0;
133 }
134
135 static int
136 vhost_user_reset_owner(struct virtio_net *dev)
137 {
138         struct rte_vdpa_device *vdpa_dev;
139         int did = -1;
140
141         if (dev->flags & VIRTIO_DEV_RUNNING) {
142                 did = dev->vdpa_dev_id;
143                 vdpa_dev = rte_vdpa_get_device(did);
144                 if (vdpa_dev && vdpa_dev->ops->dev_close)
145                         vdpa_dev->ops->dev_close(dev->vid);
146                 dev->flags &= ~VIRTIO_DEV_RUNNING;
147                 dev->notify_ops->destroy_device(dev->vid);
148         }
149
150         cleanup_device(dev, 0);
151         reset_device(dev);
152         return 0;
153 }
154
155 /*
156  * The master requests the features we support.
157  */
158 static uint64_t
159 vhost_user_get_features(struct virtio_net *dev)
160 {
161         uint64_t features = 0;
162
163         rte_vhost_driver_get_features(dev->ifname, &features);
164         return features;
165 }
166
167 /*
168  * The master requests the number of queues we support.
169  */
170 static uint32_t
171 vhost_user_get_queue_num(struct virtio_net *dev)
172 {
173         uint32_t queue_num = 0;
174
175         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
176         return queue_num;
177 }
178
179 /*
180  * We receive the features negotiated between us and the virtio device.
181  */
182 static int
183 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
184 {
185         uint64_t vhost_features = 0;
186         struct rte_vdpa_device *vdpa_dev;
187         int did = -1;
188
189         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
190         if (features & ~vhost_features) {
191                 RTE_LOG(ERR, VHOST_CONFIG,
192                         "(%d) received invalid negotiated features.\n",
193                         dev->vid);
194                 return -1;
195         }
196
197         if (dev->flags & VIRTIO_DEV_RUNNING) {
198                 if (dev->features == features)
199                         return 0;
200
201                 /*
202                  * Error out if master tries to change features while device is
203                  * in running state. The exception being VHOST_F_LOG_ALL, which
204                  * is enabled when the live-migration starts.
205                  */
206                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
207                         RTE_LOG(ERR, VHOST_CONFIG,
208                                 "(%d) features changed while device is running.\n",
209                                 dev->vid);
210                         return -1;
211                 }
212
213                 if (dev->notify_ops->features_changed)
214                         dev->notify_ops->features_changed(dev->vid, features);
215         }
216
217         dev->features = features;
218         if (dev->features &
219                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
220                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
221         } else {
222                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
223         }
224         VHOST_LOG_DEBUG(VHOST_CONFIG,
225                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
226                 dev->vid,
227                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
228                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
229
230         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
231             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
232                 /*
233                  * Remove all but first queue pair if MQ hasn't been
234                  * negotiated. This is safe because the device is not
235                  * running at this stage.
236                  */
237                 while (dev->nr_vring > 2) {
238                         struct vhost_virtqueue *vq;
239
240                         vq = dev->virtqueue[--dev->nr_vring];
241                         if (!vq)
242                                 continue;
243
244                         dev->virtqueue[dev->nr_vring] = NULL;
245                         cleanup_vq(vq, 1);
246                         free_vq(vq);
247                 }
248         }
249
250         did = dev->vdpa_dev_id;
251         vdpa_dev = rte_vdpa_get_device(did);
252         if (vdpa_dev && vdpa_dev->ops->set_features)
253                 vdpa_dev->ops->set_features(dev->vid);
254
255         return 0;
256 }
257
258 /*
259  * The virtio device sends us the size of the descriptor ring.
260  */
261 static int
262 vhost_user_set_vring_num(struct virtio_net *dev,
263                          VhostUserMsg *msg)
264 {
265         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
266
267         vq->size = msg->payload.state.num;
268
269         /* VIRTIO 1.0, 2.4 Virtqueues says:
270          *
271          *   Queue Size value is always a power of 2. The maximum Queue Size
272          *   value is 32768.
273          */
274         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
275                 RTE_LOG(ERR, VHOST_CONFIG,
276                         "invalid virtqueue size %u\n", vq->size);
277                 return -1;
278         }
279
280         if (dev->dequeue_zero_copy) {
281                 vq->nr_zmbuf = 0;
282                 vq->last_zmbuf_idx = 0;
283                 vq->zmbuf_size = vq->size;
284                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
285                                          sizeof(struct zcopy_mbuf), 0);
286                 if (vq->zmbufs == NULL) {
287                         RTE_LOG(WARNING, VHOST_CONFIG,
288                                 "failed to allocate mem for zero copy; "
289                                 "zero copy is forcibly disabled\n");
290                         dev->dequeue_zero_copy = 0;
291                 }
292                 TAILQ_INIT(&vq->zmbuf_list);
293         }
294
295         vq->shadow_used_ring = rte_malloc(NULL,
296                                 vq->size * sizeof(struct vring_used_elem),
297                                 RTE_CACHE_LINE_SIZE);
298         if (!vq->shadow_used_ring) {
299                 RTE_LOG(ERR, VHOST_CONFIG,
300                         "failed to allocate memory for shadow used ring.\n");
301                 return -1;
302         }
303
304         vq->batch_copy_elems = rte_malloc(NULL,
305                                 vq->size * sizeof(struct batch_copy_elem),
306                                 RTE_CACHE_LINE_SIZE);
307         if (!vq->batch_copy_elems) {
308                 RTE_LOG(ERR, VHOST_CONFIG,
309                         "failed to allocate memory for batching copy.\n");
310                 return -1;
311         }
312
313         return 0;
314 }
315
316 /*
317  * Reallocate virtio_dev and vhost_virtqueue data structure to make them on the
318  * same numa node as the memory of vring descriptor.
319  */
320 #ifdef RTE_LIBRTE_VHOST_NUMA
321 static struct virtio_net*
322 numa_realloc(struct virtio_net *dev, int index)
323 {
324         int oldnode, newnode;
325         struct virtio_net *old_dev;
326         struct vhost_virtqueue *old_vq, *vq;
327         struct zcopy_mbuf *new_zmbuf;
328         struct vring_used_elem *new_shadow_used_ring;
329         struct batch_copy_elem *new_batch_copy_elems;
330         int ret;
331
332         old_dev = dev;
333         vq = old_vq = dev->virtqueue[index];
334
335         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
336                             MPOL_F_NODE | MPOL_F_ADDR);
337
338         /* check if we need to reallocate vq */
339         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
340                              MPOL_F_NODE | MPOL_F_ADDR);
341         if (ret) {
342                 RTE_LOG(ERR, VHOST_CONFIG,
343                         "Unable to get vq numa information.\n");
344                 return dev;
345         }
346         if (oldnode != newnode) {
347                 RTE_LOG(INFO, VHOST_CONFIG,
348                         "reallocate vq from %d to %d node\n", oldnode, newnode);
349                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
350                 if (!vq)
351                         return dev;
352
353                 memcpy(vq, old_vq, sizeof(*vq));
354                 TAILQ_INIT(&vq->zmbuf_list);
355
356                 new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
357                         sizeof(struct zcopy_mbuf), 0, newnode);
358                 if (new_zmbuf) {
359                         rte_free(vq->zmbufs);
360                         vq->zmbufs = new_zmbuf;
361                 }
362
363                 new_shadow_used_ring = rte_malloc_socket(NULL,
364                         vq->size * sizeof(struct vring_used_elem),
365                         RTE_CACHE_LINE_SIZE,
366                         newnode);
367                 if (new_shadow_used_ring) {
368                         rte_free(vq->shadow_used_ring);
369                         vq->shadow_used_ring = new_shadow_used_ring;
370                 }
371
372                 new_batch_copy_elems = rte_malloc_socket(NULL,
373                         vq->size * sizeof(struct batch_copy_elem),
374                         RTE_CACHE_LINE_SIZE,
375                         newnode);
376                 if (new_batch_copy_elems) {
377                         rte_free(vq->batch_copy_elems);
378                         vq->batch_copy_elems = new_batch_copy_elems;
379                 }
380
381                 rte_free(old_vq);
382         }
383
384         /* check if we need to reallocate dev */
385         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
386                             MPOL_F_NODE | MPOL_F_ADDR);
387         if (ret) {
388                 RTE_LOG(ERR, VHOST_CONFIG,
389                         "Unable to get dev numa information.\n");
390                 goto out;
391         }
392         if (oldnode != newnode) {
393                 RTE_LOG(INFO, VHOST_CONFIG,
394                         "reallocate dev from %d to %d node\n",
395                         oldnode, newnode);
396                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
397                 if (!dev) {
398                         dev = old_dev;
399                         goto out;
400                 }
401
402                 memcpy(dev, old_dev, sizeof(*dev));
403                 rte_free(old_dev);
404         }
405
406 out:
407         dev->virtqueue[index] = vq;
408         vhost_devices[dev->vid] = dev;
409
410         if (old_vq != vq)
411                 vhost_user_iotlb_init(dev, index);
412
413         return dev;
414 }
415 #else
416 static struct virtio_net*
417 numa_realloc(struct virtio_net *dev, int index __rte_unused)
418 {
419         return dev;
420 }
421 #endif
422
423 /* Converts QEMU virtual address to Vhost virtual address. */
424 static uint64_t
425 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
426 {
427         struct rte_vhost_mem_region *r;
428         uint32_t i;
429
430         /* Find the region where the address lives. */
431         for (i = 0; i < dev->mem->nregions; i++) {
432                 r = &dev->mem->regions[i];
433
434                 if (qva >= r->guest_user_addr &&
435                     qva <  r->guest_user_addr + r->size) {
436
437                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
438                                 *len = r->guest_user_addr + r->size - qva;
439
440                         return qva - r->guest_user_addr +
441                                r->host_user_addr;
442                 }
443         }
444         *len = 0;
445
446         return 0;
447 }
448
449
450 /*
451  * Converts ring address to Vhost virtual address.
452  * If IOMMU is enabled, the ring address is a guest IO virtual address,
453  * else it is a QEMU virtual address.
454  */
455 static uint64_t
456 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
457                 uint64_t ra, uint64_t *size)
458 {
459         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
460                 uint64_t vva;
461
462                 vva = vhost_user_iotlb_cache_find(vq, ra,
463                                         size, VHOST_ACCESS_RW);
464                 if (!vva)
465                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
466
467                 return vva;
468         }
469
470         return qva_to_vva(dev, ra, size);
471 }
472
473 static struct virtio_net *
474 translate_ring_addresses(struct virtio_net *dev, int vq_index)
475 {
476         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
477         struct vhost_vring_addr *addr = &vq->ring_addrs;
478         uint64_t len;
479
480         /* The addresses are converted from QEMU virtual to Vhost virtual. */
481         if (vq->desc && vq->avail && vq->used)
482                 return dev;
483
484         len = sizeof(struct vring_desc) * vq->size;
485         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
486                         vq, addr->desc_user_addr, &len);
487         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
488                 RTE_LOG(DEBUG, VHOST_CONFIG,
489                         "(%d) failed to map desc ring.\n",
490                         dev->vid);
491                 return dev;
492         }
493
494         dev = numa_realloc(dev, vq_index);
495         vq = dev->virtqueue[vq_index];
496         addr = &vq->ring_addrs;
497
498         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
499         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
500                         vq, addr->avail_user_addr, &len);
501         if (vq->avail == 0 ||
502                         len != sizeof(struct vring_avail) +
503                         sizeof(uint16_t) * vq->size) {
504                 RTE_LOG(DEBUG, VHOST_CONFIG,
505                         "(%d) failed to map avail ring.\n",
506                         dev->vid);
507                 return dev;
508         }
509
510         len = sizeof(struct vring_used) +
511                 sizeof(struct vring_used_elem) * vq->size;
512         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
513                         vq, addr->used_user_addr, &len);
514         if (vq->used == 0 || len != sizeof(struct vring_used) +
515                         sizeof(struct vring_used_elem) * vq->size) {
516                 RTE_LOG(DEBUG, VHOST_CONFIG,
517                         "(%d) failed to map used ring.\n",
518                         dev->vid);
519                 return dev;
520         }
521
522         if (vq->last_used_idx != vq->used->idx) {
523                 RTE_LOG(WARNING, VHOST_CONFIG,
524                         "last_used_idx (%u) and vq->used->idx (%u) mismatches; "
525                         "some packets maybe resent for Tx and dropped for Rx\n",
526                         vq->last_used_idx, vq->used->idx);
527                 vq->last_used_idx  = vq->used->idx;
528                 vq->last_avail_idx = vq->used->idx;
529         }
530
531         vq->log_guest_addr = addr->log_guest_addr;
532
533         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
534                         dev->vid, vq->desc);
535         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
536                         dev->vid, vq->avail);
537         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
538                         dev->vid, vq->used);
539         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
540                         dev->vid, vq->log_guest_addr);
541
542         return dev;
543 }
544
545 /*
546  * The virtio device sends us the desc, used and avail ring addresses.
547  * This function then converts these to our address space.
548  */
549 static int
550 vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
551 {
552         struct vhost_virtqueue *vq;
553         struct vhost_vring_addr *addr = &msg->payload.addr;
554         struct virtio_net *dev = *pdev;
555
556         if (dev->mem == NULL)
557                 return -1;
558
559         /* addr->index refers to the queue index. The txq is 1, rxq is 0. */
560         vq = dev->virtqueue[msg->payload.addr.index];
561
562         /*
563          * Ring addresses should not be interpreted as long as the ring has not
564          * been started and enabled.
565          */
566         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
567
568         vring_invalidate(dev, vq);
569
570         if (vq->enabled && (dev->features &
571                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
572                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
573                 if (!dev)
574                         return -1;
575
576                 *pdev = dev;
577         }
578
579         return 0;
580 }
581
582 /*
583  * The virtio device sends us the available ring last used index.
584  */
585 static int
586 vhost_user_set_vring_base(struct virtio_net *dev,
587                           VhostUserMsg *msg)
588 {
589         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
590                         msg->payload.state.num;
591         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
592                         msg->payload.state.num;
593
594         return 0;
595 }
596
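/*
 * Record one guest physical to host physical page mapping, growing the
 * guest_pages array when needed and merging with the previous entry when
 * the host physical addresses are contiguous.
 */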
597 static int
598 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
599                    uint64_t host_phys_addr, uint64_t size)
600 {
601         struct guest_page *page, *last_page;
602
603         if (dev->nr_guest_pages == dev->max_guest_pages) {
604                 dev->max_guest_pages *= 2;
605                 dev->guest_pages = realloc(dev->guest_pages,
606                                         dev->max_guest_pages * sizeof(*page));
607                 if (!dev->guest_pages) {
608                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
609                         return -1;
610                 }
611         }
612
613         if (dev->nr_guest_pages > 0) {
614                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
615                 /* merge if the two pages are contiguous */
616                 if (host_phys_addr == last_page->host_phys_addr +
617                                       last_page->size) {
618                         last_page->size += size;
619                         return 0;
620                 }
621         }
622
623         page = &dev->guest_pages[dev->nr_guest_pages++];
624         page->guest_phys_addr = guest_phys_addr;
625         page->host_phys_addr  = host_phys_addr;
626         page->size = size;
627
628         return 0;
629 }
630
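/*
 * Walk a guest memory region page by page and record the guest physical
 * to host physical mapping of each page.
 */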
631 static int
632 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
633                 uint64_t page_size)
634 {
635         uint64_t reg_size = reg->size;
636         uint64_t host_user_addr  = reg->host_user_addr;
637         uint64_t guest_phys_addr = reg->guest_phys_addr;
638         uint64_t host_phys_addr;
639         uint64_t size;
640
641         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
642         size = page_size - (guest_phys_addr & (page_size - 1));
643         size = RTE_MIN(size, reg_size);
644
645         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
646                 return -1;
647
648         host_user_addr  += size;
649         guest_phys_addr += size;
650         reg_size -= size;
651
652         while (reg_size > 0) {
653                 size = RTE_MIN(reg_size, page_size);
654                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
655                                                   host_user_addr);
656                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
657                                 size) < 0)
658                         return -1;
659
660                 host_user_addr  += size;
661                 guest_phys_addr += size;
662                 reg_size -= size;
663         }
664
665         return 0;
666 }
667
668 #ifdef RTE_LIBRTE_VHOST_DEBUG
669 /* TODO: enable it only in debug mode? */
670 static void
671 dump_guest_pages(struct virtio_net *dev)
672 {
673         uint32_t i;
674         struct guest_page *page;
675
676         for (i = 0; i < dev->nr_guest_pages; i++) {
677                 page = &dev->guest_pages[i];
678
679                 RTE_LOG(INFO, VHOST_CONFIG,
680                         "guest physical page region %u\n"
681                         "\t guest_phys_addr: %" PRIx64 "\n"
682                         "\t host_phys_addr : %" PRIx64 "\n"
683                         "\t size           : %" PRIx64 "\n",
684                         i,
685                         page->guest_phys_addr,
686                         page->host_phys_addr,
687                         page->size);
688         }
689 }
690 #else
691 #define dump_guest_pages(dev)
692 #endif
693
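/*
 * Return true if the memory table carried by the message differs from the
 * one currently attached to the device.
 */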
694 static bool
695 vhost_memory_changed(struct VhostUserMemory *new,
696                      struct rte_vhost_memory *old)
697 {
698         uint32_t i;
699
700         if (new->nregions != old->nregions)
701                 return true;
702
703         for (i = 0; i < new->nregions; ++i) {
704                 VhostUserMemoryRegion *new_r = &new->regions[i];
705                 struct rte_vhost_mem_region *old_r = &old->regions[i];
706
707                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
708                         return true;
709                 if (new_r->memory_size != old_r->size)
710                         return true;
711                 if (new_r->userspace_addr != old_r->guest_user_addr)
712                         return true;
713         }
714
715         return false;
716 }
717
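/*
 * Handle VHOST_USER_SET_MEM_TABLE: mmap() the guest memory regions passed
 * as file descriptors and build the table used for address translation.
 */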
718 static int
719 vhost_user_set_mem_table(struct virtio_net *dev, struct VhostUserMsg *pmsg)
720 {
721         struct VhostUserMemory memory = pmsg->payload.memory;
722         struct rte_vhost_mem_region *reg;
723         void *mmap_addr;
724         uint64_t mmap_size;
725         uint64_t mmap_offset;
726         uint64_t alignment;
727         uint32_t i;
728         int populate;
729         int fd;
730
731         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
732                 RTE_LOG(ERR, VHOST_CONFIG,
733                         "too many memory regions (%u)\n", memory.nregions);
734                 return -1;
735         }
736
737         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
738                 RTE_LOG(INFO, VHOST_CONFIG,
739                         "(%d) memory regions not changed\n", dev->vid);
740
741                 for (i = 0; i < memory.nregions; i++)
742                         close(pmsg->fds[i]);
743
744                 return 0;
745         }
746
747         if (dev->mem) {
748                 free_mem_region(dev);
749                 rte_free(dev->mem);
750                 dev->mem = NULL;
751         }
752
753         dev->nr_guest_pages = 0;
754         if (!dev->guest_pages) {
755                 dev->max_guest_pages = 8;
756                 dev->guest_pages = malloc(dev->max_guest_pages *
757                                                 sizeof(struct guest_page));
758                 if (dev->guest_pages == NULL) {
759                         RTE_LOG(ERR, VHOST_CONFIG,
760                                 "(%d) failed to allocate memory "
761                                 "for dev->guest_pages\n",
762                                 dev->vid);
763                         return -1;
764                 }
765         }
766
767         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
768                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
769         if (dev->mem == NULL) {
770                 RTE_LOG(ERR, VHOST_CONFIG,
771                         "(%d) failed to allocate memory for dev->mem\n",
772                         dev->vid);
773                 return -1;
774         }
775         dev->mem->nregions = memory.nregions;
776
777         for (i = 0; i < memory.nregions; i++) {
778                 fd  = pmsg->fds[i];
779                 reg = &dev->mem->regions[i];
780
781                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
782                 reg->guest_user_addr = memory.regions[i].userspace_addr;
783                 reg->size            = memory.regions[i].memory_size;
784                 reg->fd              = fd;
785
786                 mmap_offset = memory.regions[i].mmap_offset;
787
788                 /* Check for memory_size + mmap_offset overflow */
789                 if (mmap_offset >= -reg->size) {
790                         RTE_LOG(ERR, VHOST_CONFIG,
791                                 "mmap_offset (%#"PRIx64") and memory_size "
792                                 "(%#"PRIx64") overflow\n",
793                                 mmap_offset, reg->size);
794                         goto err_mmap;
795                 }
796
797                 mmap_size = reg->size + mmap_offset;
798
799                 /* On older long-term kernels, such as 2.6.32 and 3.2.72,
800                  * mmap() without MAP_ANONYMOUS must be called with a length
801                  * argument aligned to the hugepage size, or it will fail
802                  * with EINVAL.
803                  *
804                  * To avoid that failure, keep the length hugepage-aligned
805                  * here.
806                  */
807                 alignment = get_blk_size(fd);
808                 if (alignment == (uint64_t)-1) {
809                         RTE_LOG(ERR, VHOST_CONFIG,
810                                 "couldn't get hugepage size through fstat\n");
811                         goto err_mmap;
812                 }
813                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
814
815                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
816                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
817                                  MAP_SHARED | populate, fd, 0);
818
819                 if (mmap_addr == MAP_FAILED) {
820                         RTE_LOG(ERR, VHOST_CONFIG,
821                                 "mmap region %u failed.\n", i);
822                         goto err_mmap;
823                 }
824
825                 reg->mmap_addr = mmap_addr;
826                 reg->mmap_size = mmap_size;
827                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
828                                       mmap_offset;
829
830                 if (dev->dequeue_zero_copy)
831                         if (add_guest_pages(dev, reg, alignment) < 0) {
832                                 RTE_LOG(ERR, VHOST_CONFIG,
833                                         "adding guest pages to region %u failed.\n",
834                                         i);
835                                 goto err_mmap;
836                         }
837
838                 RTE_LOG(INFO, VHOST_CONFIG,
839                         "guest memory region %u, size: 0x%" PRIx64 "\n"
840                         "\t guest physical addr: 0x%" PRIx64 "\n"
841                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
842                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
843                         "\t mmap addr : 0x%" PRIx64 "\n"
844                         "\t mmap size : 0x%" PRIx64 "\n"
845                         "\t mmap align: 0x%" PRIx64 "\n"
846                         "\t mmap off  : 0x%" PRIx64 "\n",
847                         i, reg->size,
848                         reg->guest_phys_addr,
849                         reg->guest_user_addr,
850                         reg->host_user_addr,
851                         (uint64_t)(uintptr_t)mmap_addr,
852                         mmap_size,
853                         alignment,
854                         mmap_offset);
855         }
856
857         dump_guest_pages(dev);
858
859         return 0;
860
861 err_mmap:
862         free_mem_region(dev);
863         rte_free(dev->mem);
864         dev->mem = NULL;
865         return -1;
866 }
867
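/*
 * A virtqueue is ready once its rings are mapped and both the kick and
 * call eventfds have been received.
 */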
868 static int
869 vq_is_ready(struct vhost_virtqueue *vq)
870 {
871         return vq && vq->desc && vq->avail && vq->used &&
872                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
873                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
874 }
875
876 static int
877 virtio_is_ready(struct virtio_net *dev)
878 {
879         struct vhost_virtqueue *vq;
880         uint32_t i;
881
882         if (dev->nr_vring == 0)
883                 return 0;
884
885         for (i = 0; i < dev->nr_vring; i++) {
886                 vq = dev->virtqueue[i];
887
888                 if (!vq_is_ready(vq))
889                         return 0;
890         }
891
892         RTE_LOG(INFO, VHOST_CONFIG,
893                 "virtio is now ready for processing.\n");
894         return 1;
895 }
896
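/*
 * Handle VHOST_USER_SET_VRING_CALL: store the eventfd used to interrupt
 * the guest, closing any previously installed one.
 */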
897 static void
898 vhost_user_set_vring_call(struct virtio_net *dev, struct VhostUserMsg *pmsg)
899 {
900         struct vhost_vring_file file;
901         struct vhost_virtqueue *vq;
902
903         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
904         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
905                 file.fd = VIRTIO_INVALID_EVENTFD;
906         else
907                 file.fd = pmsg->fds[0];
908         RTE_LOG(INFO, VHOST_CONFIG,
909                 "vring call idx:%d file:%d\n", file.index, file.fd);
910
911         vq = dev->virtqueue[file.index];
912         if (vq->callfd >= 0)
913                 close(vq->callfd);
914
915         vq->callfd = file.fd;
916 }
917
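/*
 * Handle VHOST_USER_SET_VRING_KICK: translate the ring addresses now that
 * the ring is being started, and store the kick eventfd.
 */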
918 static void
919 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *pmsg)
920 {
921         struct vhost_vring_file file;
922         struct vhost_virtqueue *vq;
923         struct virtio_net *dev = *pdev;
924
925         file.index = pmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
926         if (pmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
927                 file.fd = VIRTIO_INVALID_EVENTFD;
928         else
929                 file.fd = pmsg->fds[0];
930         RTE_LOG(INFO, VHOST_CONFIG,
931                 "vring kick idx:%d file:%d\n", file.index, file.fd);
932
933         /* Interpret ring addresses only when ring is started. */
934         dev = translate_ring_addresses(dev, file.index);
935         if (!dev)
936                 return;
937
938         *pdev = dev;
939
940         vq = dev->virtqueue[file.index];
941
942         /*
943          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
944          * the ring starts already enabled. Otherwise, it is enabled via
945          * the SET_VRING_ENABLE message.
946          */
947         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
948                 vq->enabled = 1;
949
950         if (vq->kickfd >= 0)
951                 close(vq->kickfd);
952         vq->kickfd = file.fd;
953 }
954
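/*
 * Free all zero-copy mbufs still attached to the virtqueue, along with the
 * zmbuf array itself.
 */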
955 static void
956 free_zmbufs(struct vhost_virtqueue *vq)
957 {
958         struct zcopy_mbuf *zmbuf, *next;
959
960         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
961              zmbuf != NULL; zmbuf = next) {
962                 next = TAILQ_NEXT(zmbuf, next);
963
964                 rte_pktmbuf_free(zmbuf->mbuf);
965                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
966         }
967
968         rte_free(vq->zmbufs);
969 }
970
971 /*
972  * When virtio is stopped, QEMU will send us the GET_VRING_BASE message.
973  */
974 static int
975 vhost_user_get_vring_base(struct virtio_net *dev,
976                           VhostUserMsg *msg)
977 {
978         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
979         struct rte_vdpa_device *vdpa_dev;
980         int did = -1;
981
982         /* We have to stop the queue (virtio) if it is running. */
983         if (dev->flags & VIRTIO_DEV_RUNNING) {
984                 did = dev->vdpa_dev_id;
985                 vdpa_dev = rte_vdpa_get_device(did);
986                 if (vdpa_dev && vdpa_dev->ops->dev_close)
987                         vdpa_dev->ops->dev_close(dev->vid);
988                 dev->flags &= ~VIRTIO_DEV_RUNNING;
989                 dev->notify_ops->destroy_device(dev->vid);
990         }
991
992         dev->flags &= ~VIRTIO_DEV_READY;
993         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
994
995         /* Here we are safe to get the last avail index */
996         msg->payload.state.num = vq->last_avail_idx;
997
998         RTE_LOG(INFO, VHOST_CONFIG,
999                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1000                 msg->payload.state.num);
1001         /*
1002          * Based on the current QEMU vhost-user implementation, this message
1003          * is only ever sent from vhost_vring_stop.
1004          * TODO: clean up the vring, it is not usable from this point on.
1005          */
1006         if (vq->kickfd >= 0)
1007                 close(vq->kickfd);
1008
1009         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1010
1011         if (vq->callfd >= 0)
1012                 close(vq->callfd);
1013
1014         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1015
1016         if (dev->dequeue_zero_copy)
1017                 free_zmbufs(vq);
1018         rte_free(vq->shadow_used_ring);
1019         vq->shadow_used_ring = NULL;
1020
1021         rte_free(vq->batch_copy_elems);
1022         vq->batch_copy_elems = NULL;
1023
1024         return 0;
1025 }
1026
1027 /*
1028  * When the virtio queues are ready to work, QEMU sends us a message
1029  * to enable the virtio queue pair.
1030  */
1031 static int
1032 vhost_user_set_vring_enable(struct virtio_net *dev,
1033                             VhostUserMsg *msg)
1034 {
1035         int enable = (int)msg->payload.state.num;
1036         int index = (int)msg->payload.state.index;
1037         struct rte_vdpa_device *vdpa_dev;
1038         int did = -1;
1039
1040         RTE_LOG(INFO, VHOST_CONFIG,
1041                 "set queue enable: %d to qp idx: %d\n",
1042                 enable, index);
1043
1044         did = dev->vdpa_dev_id;
1045         vdpa_dev = rte_vdpa_get_device(did);
1046         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1047                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1048
1049         if (dev->notify_ops->vring_state_changed)
1050                 dev->notify_ops->vring_state_changed(dev->vid,
1051                                 index, enable);
1052
1053         dev->virtqueue[index]->enabled = enable;
1054
1055         return 0;
1056 }
1057
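/*
 * Handle VHOST_USER_GET_PROTOCOL_FEATURES: report the vhost-user protocol
 * features supported by this backend.
 */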
1058 static void
1059 vhost_user_get_protocol_features(struct virtio_net *dev,
1060                                  struct VhostUserMsg *msg)
1061 {
1062         uint64_t features, protocol_features;
1063
1064         rte_vhost_driver_get_features(dev->ifname, &features);
1065         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1066
1067         /*
1068          * For now, the REPLY_ACK protocol feature is only mandatory for
1069          * the IOMMU feature. If IOMMU is explicitly disabled by the
1070          * application, also disable the REPLY_ACK feature to cope with
1071          * older buggy QEMU versions (from v2.7.0 to v2.9.0).
1072          */
1073         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1074                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1075
1076         msg->payload.u64 = protocol_features;
1077         msg->size = sizeof(msg->payload.u64);
1078 }
1079
1080 static void
1081 vhost_user_set_protocol_features(struct virtio_net *dev,
1082                                  uint64_t protocol_features)
1083 {
1084         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES)
1085                 return;
1086
1087         dev->protocol_features = protocol_features;
1088 }
1089
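/*
 * Handle VHOST_USER_SET_LOG_BASE: mmap() the dirty page log area shared
 * by the master for live migration.
 */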
1090 static int
1091 vhost_user_set_log_base(struct virtio_net *dev, struct VhostUserMsg *msg)
1092 {
1093         int fd = msg->fds[0];
1094         uint64_t size, off;
1095         void *addr;
1096
1097         if (fd < 0) {
1098                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1099                 return -1;
1100         }
1101
1102         if (msg->size != sizeof(VhostUserLog)) {
1103                 RTE_LOG(ERR, VHOST_CONFIG,
1104                         "invalid log base msg size: %"PRId32" != %d\n",
1105                         msg->size, (int)sizeof(VhostUserLog));
1106                 return -1;
1107         }
1108
1109         size = msg->payload.log.mmap_size;
1110         off  = msg->payload.log.mmap_offset;
1111
1112         /* Don't allow mmap_offset to point outside the mmap region */
1113         if (off > size) {
1114                 RTE_LOG(ERR, VHOST_CONFIG,
1115                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1116                         off, size);
1117                 return -1;
1118         }
1119
1120         RTE_LOG(INFO, VHOST_CONFIG,
1121                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1122                 size, off);
1123
1124         /*
1125          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1126          * fail when the offset is not page-size aligned.
1127          */
1128         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1129         close(fd);
1130         if (addr == MAP_FAILED) {
1131                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1132                 return -1;
1133         }
1134
1135         /*
1136          * Free any previously mapped log memory in case multiple
1137          * VHOST_USER_SET_LOG_BASE messages are received.
1138          */
1139         if (dev->log_addr) {
1140                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1141         }
1142         dev->log_addr = (uint64_t)(uintptr_t)addr;
1143         dev->log_base = dev->log_addr + off;
1144         dev->log_size = size;
1145
1146         return 0;
1147 }
1148
1149 /*
1150  * A RARP packet is constructed and broadcast to notify switches about
1151  * the new location of the migrated VM, so that packets from outside will
1152  * not be lost after migration.
1153  *
1154  * However, we don't actually "send" a RARP packet here; instead, we set
1155  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1156  */
1157 static int
1158 vhost_user_send_rarp(struct virtio_net *dev, struct VhostUserMsg *msg)
1159 {
1160         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1161         struct rte_vdpa_device *vdpa_dev;
1162         int did = -1;
1163
1164         RTE_LOG(DEBUG, VHOST_CONFIG,
1165                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1166                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1167         memcpy(dev->mac.addr_bytes, mac, 6);
1168
1169         /*
1170          * Set the flag to inject a RARP broadcast packet at
1171          * rte_vhost_dequeue_burst().
1172          *
1173          * rte_smp_wmb() makes sure the MAC address is copied
1174          * before the flag is set.
1175          */
1176         rte_smp_wmb();
1177         rte_atomic16_set(&dev->broadcast_rarp, 1);
1178         did = dev->vdpa_dev_id;
1179         vdpa_dev = rte_vdpa_get_device(did);
1180         if (vdpa_dev && vdpa_dev->ops->migration_done)
1181                 vdpa_dev->ops->migration_done(dev->vid);
1182
1183         return 0;
1184 }
1185
1186 static int
1187 vhost_user_net_set_mtu(struct virtio_net *dev, struct VhostUserMsg *msg)
1188 {
1189         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1190                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1191                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1192                                 msg->payload.u64);
1193
1194                 return -1;
1195         }
1196
1197         dev->mtu = msg->payload.u64;
1198
1199         return 0;
1200 }
1201
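/*
 * Handle VHOST_USER_SET_SLAVE_REQ_FD: store the fd used to send
 * slave-initiated requests (e.g. IOTLB misses) to the master.
 */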
1202 static int
1203 vhost_user_set_req_fd(struct virtio_net *dev, struct VhostUserMsg *msg)
1204 {
1205         int fd = msg->fds[0];
1206
1207         if (fd < 0) {
1208                 RTE_LOG(ERR, VHOST_CONFIG,
1209                                 "Invalid file descriptor for slave channel (%d)\n",
1210                                 fd);
1211                 return -1;
1212         }
1213
1214         dev->slave_req_fd = fd;
1215
1216         return 0;
1217 }
1218
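/*
 * Return 1 if the IOTLB update covers any of the ring addresses of the
 * virtqueue, meaning the ring can now be translated.
 */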
1219 static int
1220 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1221 {
1222         struct vhost_vring_addr *ra;
1223         uint64_t start, end;
1224
1225         start = imsg->iova;
1226         end = start + imsg->size;
1227
1228         ra = &vq->ring_addrs;
1229         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1230                 return 1;
1231         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1232                 return 1;
1233         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1234                 return 1;
1235
1236         return 0;
1237 }
1238
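/*
 * Return 1 if the IOTLB invalidation overlaps any of the mapped rings of
 * the virtqueue, meaning those mappings must be dropped.
 */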
1239 static int
1240 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1241                                 struct vhost_iotlb_msg *imsg)
1242 {
1243         uint64_t istart, iend, vstart, vend;
1244
1245         istart = imsg->iova;
1246         iend = istart + imsg->size - 1;
1247
1248         vstart = (uintptr_t)vq->desc;
1249         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1250         if (vstart <= iend && istart <= vend)
1251                 return 1;
1252
1253         vstart = (uintptr_t)vq->avail;
1254         vend = vstart + sizeof(struct vring_avail);
1255         vend += sizeof(uint16_t) * vq->size - 1;
1256         if (vstart <= iend && istart <= vend)
1257                 return 1;
1258
1259         vstart = (uintptr_t)vq->used;
1260         vend = vstart + sizeof(struct vring_used);
1261         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1262         if (vstart <= iend && istart <= vend)
1263                 return 1;
1264
1265         return 0;
1266 }
1267
1268 static int
1269 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1270 {
1271         struct virtio_net *dev = *pdev;
1272         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1273         uint16_t i;
1274         uint64_t vva, len;
1275
1276         switch (imsg->type) {
1277         case VHOST_IOTLB_UPDATE:
1278                 len = imsg->size;
1279                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1280                 if (!vva)
1281                         return -1;
1282
1283                 for (i = 0; i < dev->nr_vring; i++) {
1284                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1285
1286                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1287                                         len, imsg->perm);
1288
1289                         if (is_vring_iotlb_update(vq, imsg))
1290                                 *pdev = dev = translate_ring_addresses(dev, i);
1291                 }
1292                 break;
1293         case VHOST_IOTLB_INVALIDATE:
1294                 for (i = 0; i < dev->nr_vring; i++) {
1295                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1296
1297                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1298                                         imsg->size);
1299
1300                         if (is_vring_iotlb_invalidate(vq, imsg))
1301                                 vring_invalidate(dev, vq);
1302                 }
1303                 break;
1304         default:
1305                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1306                                 imsg->type);
1307                 return -1;
1308         }
1309
1310         return 0;
1311 }
1312
1313 /* Return the number of bytes read on success, or a negative value on failure. */
1314 static int
1315 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1316 {
1317         int ret;
1318
1319         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1320                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1321         if (ret <= 0)
1322                 return ret;
1323
1324         if (msg && msg->size) {
1325                 if (msg->size > sizeof(msg->payload)) {
1326                         RTE_LOG(ERR, VHOST_CONFIG,
1327                                 "invalid msg size: %d\n", msg->size);
1328                         return -1;
1329                 }
1330                 ret = read(sockfd, &msg->payload, msg->size);
1331                 if (ret <= 0)
1332                         return ret;
1333                 if (ret != (int)msg->size) {
1334                         RTE_LOG(ERR, VHOST_CONFIG,
1335                                 "read control message failed\n");
1336                         return -1;
1337                 }
1338         }
1339
1340         return ret;
1341 }
1342
1343 static int
1344 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1345 {
1346         if (!msg)
1347                 return 0;
1348
1349         return send_fd_message(sockfd, (char *)msg,
1350                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1351 }
1352
1353 static int
1354 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1355 {
1356         if (!msg)
1357                 return 0;
1358
1359         msg->flags &= ~VHOST_USER_VERSION_MASK;
1360         msg->flags &= ~VHOST_USER_NEED_REPLY;
1361         msg->flags |= VHOST_USER_VERSION;
1362         msg->flags |= VHOST_USER_REPLY_MASK;
1363
1364         return send_vhost_message(sockfd, msg, NULL, 0);
1365 }
1366
1367 /*
1368  * Allocate a queue pair if it hasn't been allocated yet
1369  */
1370 static int
1371 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev, VhostUserMsg *msg)
1372 {
1373         uint16_t vring_idx;
1374
1375         switch (msg->request.master) {
1376         case VHOST_USER_SET_VRING_KICK:
1377         case VHOST_USER_SET_VRING_CALL:
1378         case VHOST_USER_SET_VRING_ERR:
1379                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1380                 break;
1381         case VHOST_USER_SET_VRING_NUM:
1382         case VHOST_USER_SET_VRING_BASE:
1383         case VHOST_USER_SET_VRING_ENABLE:
1384                 vring_idx = msg->payload.state.index;
1385                 break;
1386         case VHOST_USER_SET_VRING_ADDR:
1387                 vring_idx = msg->payload.addr.index;
1388                 break;
1389         default:
1390                 return 0;
1391         }
1392
1393         if (vring_idx >= VHOST_MAX_VRING) {
1394                 RTE_LOG(ERR, VHOST_CONFIG,
1395                         "invalid vring index: %u\n", vring_idx);
1396                 return -1;
1397         }
1398
1399         if (dev->virtqueue[vring_idx])
1400                 return 0;
1401
1402         return alloc_vring_queue(dev, vring_idx);
1403 }
1404
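/*
 * Take the access_lock of every allocated virtqueue so that the message
 * handler does not race with the data path.
 */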
1405 static void
1406 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1407 {
1408         unsigned int i = 0;
1409         unsigned int vq_num = 0;
1410
1411         while (vq_num < dev->nr_vring) {
1412                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1413
1414                 if (vq) {
1415                         rte_spinlock_lock(&vq->access_lock);
1416                         vq_num++;
1417                 }
1418                 i++;
1419         }
1420 }
1421
1422 static void
1423 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1424 {
1425         unsigned int i = 0;
1426         unsigned int vq_num = 0;
1427
1428         while (vq_num < dev->nr_vring) {
1429                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1430
1431                 if (vq) {
1432                         rte_spinlock_unlock(&vq->access_lock);
1433                         vq_num++;
1434                 }
1435                 i++;
1436         }
1437 }
1438
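/*
 * Main vhost-user message dispatcher, called when data is available on
 * the connection fd.
 */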
1439 int
1440 vhost_user_msg_handler(int vid, int fd)
1441 {
1442         struct virtio_net *dev;
1443         struct VhostUserMsg msg;
1444         struct rte_vdpa_device *vdpa_dev;
1445         int did = -1;
1446         int ret;
1447         int unlock_required = 0;
1448         uint32_t skip_master = 0;
1449
1450         dev = get_device(vid);
1451         if (dev == NULL)
1452                 return -1;
1453
1454         if (!dev->notify_ops) {
1455                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1456                 if (!dev->notify_ops) {
1457                         RTE_LOG(ERR, VHOST_CONFIG,
1458                                 "failed to get callback ops for driver %s\n",
1459                                 dev->ifname);
1460                         return -1;
1461                 }
1462         }
1463
1464         ret = read_vhost_message(fd, &msg);
1465         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1466                 if (ret < 0)
1467                         RTE_LOG(ERR, VHOST_CONFIG,
1468                                 "vhost read message failed\n");
1469                 else if (ret == 0)
1470                         RTE_LOG(INFO, VHOST_CONFIG,
1471                                 "vhost peer closed\n");
1472                 else
1473                         RTE_LOG(ERR, VHOST_CONFIG,
1474                                 "vhost read incorrect message\n");
1475
1476                 return -1;
1477         }
1478
1479         ret = 0;
1480         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1481                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1482                         vhost_message_str[msg.request.master]);
1483         else
1484                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1485                         vhost_message_str[msg.request.master]);
1486
1487         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1488         if (ret < 0) {
1489                 RTE_LOG(ERR, VHOST_CONFIG,
1490                         "failed to alloc queue\n");
1491                 return -1;
1492         }
1493
1494         /*
1495          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1496          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1497          * and the device is destroyed. destroy_device waits for the queues
1498          * to become inactive, so it is safe. Otherwise, taking the
1499          * access_lock would cause a deadlock.
1500          */
1501         switch (msg.request.master) {
1502         case VHOST_USER_SET_FEATURES:
1503         case VHOST_USER_SET_PROTOCOL_FEATURES:
1504         case VHOST_USER_SET_OWNER:
1505         case VHOST_USER_SET_MEM_TABLE:
1506         case VHOST_USER_SET_LOG_BASE:
1507         case VHOST_USER_SET_LOG_FD:
1508         case VHOST_USER_SET_VRING_NUM:
1509         case VHOST_USER_SET_VRING_ADDR:
1510         case VHOST_USER_SET_VRING_BASE:
1511         case VHOST_USER_SET_VRING_KICK:
1512         case VHOST_USER_SET_VRING_CALL:
1513         case VHOST_USER_SET_VRING_ERR:
1514         case VHOST_USER_SET_VRING_ENABLE:
1515         case VHOST_USER_SEND_RARP:
1516         case VHOST_USER_NET_SET_MTU:
1517         case VHOST_USER_SET_SLAVE_REQ_FD:
1518                 vhost_user_lock_all_queue_pairs(dev);
1519                 unlock_required = 1;
1520                 break;
1521         default:
1522                 break;
1523
1524         }
1525
1526         if (dev->extern_ops.pre_msg_handle) {
1527                 uint32_t need_reply;
1528
1529                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1530                                 (void *)&msg, &need_reply, &skip_master);
1531                 if (ret < 0)
1532                         goto skip_to_reply;
1533
1534                 if (need_reply)
1535                         send_vhost_reply(fd, &msg);
1536
1537                 if (skip_master)
1538                         goto skip_to_post_handle;
1539         }
1540
1541         switch (msg.request.master) {
1542         case VHOST_USER_GET_FEATURES:
1543                 msg.payload.u64 = vhost_user_get_features(dev);
1544                 msg.size = sizeof(msg.payload.u64);
1545                 send_vhost_reply(fd, &msg);
1546                 break;
1547         case VHOST_USER_SET_FEATURES:
1548                 ret = vhost_user_set_features(dev, msg.payload.u64);
1549                 if (ret)
1550                         return -1;
1551                 break;
1552
1553         case VHOST_USER_GET_PROTOCOL_FEATURES:
1554                 vhost_user_get_protocol_features(dev, &msg);
1555                 send_vhost_reply(fd, &msg);
1556                 break;
1557         case VHOST_USER_SET_PROTOCOL_FEATURES:
1558                 vhost_user_set_protocol_features(dev, msg.payload.u64);
1559                 break;
1560
1561         case VHOST_USER_SET_OWNER:
1562                 vhost_user_set_owner();
1563                 break;
1564         case VHOST_USER_RESET_OWNER:
1565                 vhost_user_reset_owner(dev);
1566                 break;
1567
1568         case VHOST_USER_SET_MEM_TABLE:
1569                 ret = vhost_user_set_mem_table(dev, &msg);
1570                 break;
1571
1572         case VHOST_USER_SET_LOG_BASE:
1573                 vhost_user_set_log_base(dev, &msg);
1574
1575                 /* it needs a reply */
1576                 msg.size = sizeof(msg.payload.u64);
1577                 send_vhost_reply(fd, &msg);
1578                 break;
1579         case VHOST_USER_SET_LOG_FD:
1580                 close(msg.fds[0]);
1581                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1582                 break;
1583
1584         case VHOST_USER_SET_VRING_NUM:
1585                 vhost_user_set_vring_num(dev, &msg);
1586                 break;
1587         case VHOST_USER_SET_VRING_ADDR:
1588                 vhost_user_set_vring_addr(&dev, &msg);
1589                 break;
1590         case VHOST_USER_SET_VRING_BASE:
1591                 vhost_user_set_vring_base(dev, &msg);
1592                 break;
1593
1594         case VHOST_USER_GET_VRING_BASE:
1595                 vhost_user_get_vring_base(dev, &msg);
1596                 msg.size = sizeof(msg.payload.state);
1597                 send_vhost_reply(fd, &msg);
1598                 break;
1599
1600         case VHOST_USER_SET_VRING_KICK:
1601                 vhost_user_set_vring_kick(&dev, &msg);
1602                 break;
1603         case VHOST_USER_SET_VRING_CALL:
1604                 vhost_user_set_vring_call(dev, &msg);
1605                 break;
1606
1607         case VHOST_USER_SET_VRING_ERR:
1608                 if (!(msg.payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1609                         close(msg.fds[0]);
1610                 RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1611                 break;
1612
1613         case VHOST_USER_GET_QUEUE_NUM:
1614                 msg.payload.u64 = (uint64_t)vhost_user_get_queue_num(dev);
1615                 msg.size = sizeof(msg.payload.u64);
1616                 send_vhost_reply(fd, &msg);
1617                 break;
1618
1619         case VHOST_USER_SET_VRING_ENABLE:
1620                 vhost_user_set_vring_enable(dev, &msg);
1621                 break;
1622         case VHOST_USER_SEND_RARP:
1623                 vhost_user_send_rarp(dev, &msg);
1624                 break;
1625
1626         case VHOST_USER_NET_SET_MTU:
1627                 ret = vhost_user_net_set_mtu(dev, &msg);
1628                 break;
1629
1630         case VHOST_USER_SET_SLAVE_REQ_FD:
1631                 ret = vhost_user_set_req_fd(dev, &msg);
1632                 break;
1633
1634         case VHOST_USER_IOTLB_MSG:
1635                 ret = vhost_user_iotlb_msg(&dev, &msg);
1636                 break;
1637
1638         default:
1639                 ret = -1;
1640                 break;
1641         }
1642
1643 skip_to_post_handle:
1644         if (dev->extern_ops.post_msg_handle) {
1645                 uint32_t need_reply;
1646
1647                 ret = (*dev->extern_ops.post_msg_handle)(
1648                                 dev->vid, (void *)&msg, &need_reply);
1649                 if (ret < 0)
1650                         goto skip_to_reply;
1651
1652                 if (need_reply)
1653                         send_vhost_reply(fd, &msg);
1654         }
1655
1656 skip_to_reply:
1657         if (unlock_required)
1658                 vhost_user_unlock_all_queue_pairs(dev);
1659
1660         if (msg.flags & VHOST_USER_NEED_REPLY) {
1661                 msg.payload.u64 = !!ret;
1662                 msg.size = sizeof(msg.payload.u64);
1663                 send_vhost_reply(fd, &msg);
1664         }
1665
1666         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1667                 dev->flags |= VIRTIO_DEV_READY;
1668
1669                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1670                         if (dev->dequeue_zero_copy) {
1671                                 RTE_LOG(INFO, VHOST_CONFIG,
1672                                                 "dequeue zero copy is enabled\n");
1673                         }
1674
1675                         if (dev->notify_ops->new_device(dev->vid) == 0)
1676                                 dev->flags |= VIRTIO_DEV_RUNNING;
1677                 }
1678         }
1679
1680         did = dev->vdpa_dev_id;
1681         vdpa_dev = rte_vdpa_get_device(did);
1682         if (vdpa_dev && virtio_is_ready(dev) &&
1683                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1684                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1685                 if (vdpa_dev->ops->dev_conf)
1686                         vdpa_dev->ops->dev_conf(dev->vid);
1687                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1688         }
1689
1690         return 0;
1691 }
1692
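/*
 * Send an IOTLB miss request to the master over the slave request channel,
 * asking it to push the missing translation.
 */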
1693 int
1694 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1695 {
1696         int ret;
1697         struct VhostUserMsg msg = {
1698                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1699                 .flags = VHOST_USER_VERSION,
1700                 .size = sizeof(msg.payload.iotlb),
1701                 .payload.iotlb = {
1702                         .iova = iova,
1703                         .perm = perm,
1704                         .type = VHOST_IOTLB_MISS,
1705                 },
1706         };
1707
1708         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1709         if (ret < 0) {
1710                 RTE_LOG(ERR, VHOST_CONFIG,
1711                                 "Failed to send IOTLB miss message (%d)\n",
1712                                 ret);
1713                 return ret;
1714         }
1715
1716         return 0;
1717 }