vhost: fix error handling when mem table gets updated
[dpdk.git] / lib / librte_vhost / vhost_user.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped can then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
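/* Note: for a hugetlbfs-backed fd, fstat() reports the hugepage size in
 * st_blksize; vhost_user_set_mem_table() relies on this to align the
 * mmap() length.
 */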
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * At the moment, this function just returns success as long as
127  * the device has been initialised.
128  */
129 static int
130 vhost_user_set_owner(struct virtio_net **pdev __rte_unused,
131                 struct VhostUserMsg *msg __rte_unused)
132 {
133         return VH_RESULT_OK;
134 }
135
136 static int
137 vhost_user_reset_owner(struct virtio_net **pdev,
138                 struct VhostUserMsg *msg __rte_unused)
139 {
140         struct virtio_net *dev = *pdev;
141         vhost_destroy_device_notify(dev);
142
143         cleanup_device(dev, 0);
144         reset_device(dev);
145         return VH_RESULT_OK;
146 }
147
148 /*
149  * The features that we support are requested.
150  */
151 static int
152 vhost_user_get_features(struct virtio_net **pdev, struct VhostUserMsg *msg)
153 {
154         struct virtio_net *dev = *pdev;
155         uint64_t features = 0;
156
157         rte_vhost_driver_get_features(dev->ifname, &features);
158
159         msg->payload.u64 = features;
160         msg->size = sizeof(msg->payload.u64);
161
162         return VH_RESULT_REPLY;
163 }
164
165 /*
166  * The number of queues that we support is requested.
167  */
168 static int
169 vhost_user_get_queue_num(struct virtio_net **pdev, struct VhostUserMsg *msg)
170 {
171         struct virtio_net *dev = *pdev;
172         uint32_t queue_num = 0;
173
174         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
175
176         msg->payload.u64 = (uint64_t)queue_num;
177         msg->size = sizeof(msg->payload.u64);
178
179         return VH_RESULT_REPLY;
180 }
181
182 /*
183  * We receive the negotiated features supported by us and the virtio device.
184  */
185 static int
186 vhost_user_set_features(struct virtio_net **pdev, struct VhostUserMsg *msg)
187 {
188         struct virtio_net *dev = *pdev;
189         uint64_t features = msg->payload.u64;
190         uint64_t vhost_features = 0;
191         struct rte_vdpa_device *vdpa_dev;
192         int did = -1;
193
194         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
195         if (features & ~vhost_features) {
196                 RTE_LOG(ERR, VHOST_CONFIG,
197                         "(%d) received invalid negotiated features.\n",
198                         dev->vid);
199                 return VH_RESULT_ERR;
200         }
201
202         if (dev->flags & VIRTIO_DEV_RUNNING) {
203                 if (dev->features == features)
204                         return VH_RESULT_OK;
205
206                 /*
207                  * Error out if master tries to change features while device is
208                  * in running state. The exception being VHOST_F_LOG_ALL, which
209                  * is enabled when the live-migration starts.
210                  */
211                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
212                         RTE_LOG(ERR, VHOST_CONFIG,
213                                 "(%d) features changed while device is running.\n",
214                                 dev->vid);
215                         return VH_RESULT_ERR;
216                 }
217
218                 if (dev->notify_ops->features_changed)
219                         dev->notify_ops->features_changed(dev->vid, features);
220         }
221
222         dev->features = features;
223         if (dev->features &
224                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
225                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
226         } else {
227                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
228         }
229         VHOST_LOG_DEBUG(VHOST_CONFIG,
230                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
231                 dev->vid,
232                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
233                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
234
235         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
236             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
237                 /*
238                  * Remove all but first queue pair if MQ hasn't been
239                  * negotiated. This is safe because the device is not
240                  * running at this stage.
241                  */
242                 while (dev->nr_vring > 2) {
243                         struct vhost_virtqueue *vq;
244
245                         vq = dev->virtqueue[--dev->nr_vring];
246                         if (!vq)
247                                 continue;
248
249                         dev->virtqueue[dev->nr_vring] = NULL;
250                         cleanup_vq(vq, 1);
251                         free_vq(dev, vq);
252                 }
253         }
254
255         did = dev->vdpa_dev_id;
256         vdpa_dev = rte_vdpa_get_device(did);
257         if (vdpa_dev && vdpa_dev->ops->set_features)
258                 vdpa_dev->ops->set_features(dev->vid);
259
260         return VH_RESULT_OK;
261 }
262
263 /*
264  * The virtio device sends us the size of the descriptor ring.
265  */
266 static int
267 vhost_user_set_vring_num(struct virtio_net **pdev,
268                          struct VhostUserMsg *msg)
269 {
270         struct virtio_net *dev = *pdev;
271         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
272
273         vq->size = msg->payload.state.num;
274
275         /* VIRTIO 1.0, 2.4 Virtqueues says:
276          *
277          *   Queue Size value is always a power of 2. The maximum Queue Size
278          *   value is 32768.
279          */
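        /* A non-zero size is a power of two iff (size & (size - 1)) == 0. */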
280         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
281                 RTE_LOG(ERR, VHOST_CONFIG,
282                         "invalid virtqueue size %u\n", vq->size);
283                 return VH_RESULT_ERR;
284         }
285
286         if (dev->dequeue_zero_copy) {
287                 vq->nr_zmbuf = 0;
288                 vq->last_zmbuf_idx = 0;
289                 vq->zmbuf_size = vq->size;
290                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
291                                          sizeof(struct zcopy_mbuf), 0);
292                 if (vq->zmbufs == NULL) {
293                         RTE_LOG(WARNING, VHOST_CONFIG,
294                                 "failed to allocate mem for zero copy; "
295                                 "zero copy is force disabled\n");
296                         dev->dequeue_zero_copy = 0;
297                 }
298                 TAILQ_INIT(&vq->zmbuf_list);
299         }
300
301         if (vq_is_packed(dev)) {
302                 vq->shadow_used_packed = rte_malloc(NULL,
303                                 vq->size *
304                                 sizeof(struct vring_used_elem_packed),
305                                 RTE_CACHE_LINE_SIZE);
306                 if (!vq->shadow_used_packed) {
307                         RTE_LOG(ERR, VHOST_CONFIG,
308                                         "failed to allocate memory for shadow used ring.\n");
309                         return VH_RESULT_ERR;
310                 }
311
312         } else {
313                 vq->shadow_used_split = rte_malloc(NULL,
314                                 vq->size * sizeof(struct vring_used_elem),
315                                 RTE_CACHE_LINE_SIZE);
316                 if (!vq->shadow_used_split) {
317                         RTE_LOG(ERR, VHOST_CONFIG,
318                                         "failed to allocate memory for shadow used ring.\n");
319                         return VH_RESULT_ERR;
320                 }
321         }
322
323         vq->batch_copy_elems = rte_malloc(NULL,
324                                 vq->size * sizeof(struct batch_copy_elem),
325                                 RTE_CACHE_LINE_SIZE);
326         if (!vq->batch_copy_elems) {
327                 RTE_LOG(ERR, VHOST_CONFIG,
328                         "failed to allocate memory for batching copy.\n");
329                 return VH_RESULT_ERR;
330         }
331
332         return VH_RESULT_OK;
333 }
334
335 /*
336  * Reallocate virtio_dev and vhost_virtqueue data structures to place them on
337  * the same NUMA node as the vring descriptor memory.
338  */
339 #ifdef RTE_LIBRTE_VHOST_NUMA
340 static struct virtio_net*
341 numa_realloc(struct virtio_net *dev, int index)
342 {
343         int oldnode, newnode;
344         struct virtio_net *old_dev;
345         struct vhost_virtqueue *old_vq, *vq;
346         struct zcopy_mbuf *new_zmbuf;
347         struct vring_used_elem *new_shadow_used_split;
348         struct vring_used_elem_packed *new_shadow_used_packed;
349         struct batch_copy_elem *new_batch_copy_elems;
350         int ret;
351
352         old_dev = dev;
353         vq = old_vq = dev->virtqueue[index];
354
355         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
356                             MPOL_F_NODE | MPOL_F_ADDR);
357
358         /* check if we need to reallocate vq */
359         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
360                              MPOL_F_NODE | MPOL_F_ADDR);
361         if (ret) {
362                 RTE_LOG(ERR, VHOST_CONFIG,
363                         "Unable to get vq numa information.\n");
364                 return dev;
365         }
366         if (oldnode != newnode) {
367                 RTE_LOG(INFO, VHOST_CONFIG,
368                         "reallocate vq from %d to %d node\n", oldnode, newnode);
369                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
370                 if (!vq)
371                         return dev;
372
373                 memcpy(vq, old_vq, sizeof(*vq));
374                 TAILQ_INIT(&vq->zmbuf_list);
375
376                 if (dev->dequeue_zero_copy) {
377                         new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
378                                         sizeof(struct zcopy_mbuf), 0, newnode);
379                         if (new_zmbuf) {
380                                 rte_free(vq->zmbufs);
381                                 vq->zmbufs = new_zmbuf;
382                         }
383                 }
384
385                 if (vq_is_packed(dev)) {
386                         new_shadow_used_packed = rte_malloc_socket(NULL,
387                                         vq->size *
388                                         sizeof(struct vring_used_elem_packed),
389                                         RTE_CACHE_LINE_SIZE,
390                                         newnode);
391                         if (new_shadow_used_packed) {
392                                 rte_free(vq->shadow_used_packed);
393                                 vq->shadow_used_packed = new_shadow_used_packed;
394                         }
395                 } else {
396                         new_shadow_used_split = rte_malloc_socket(NULL,
397                                         vq->size *
398                                         sizeof(struct vring_used_elem),
399                                         RTE_CACHE_LINE_SIZE,
400                                         newnode);
401                         if (new_shadow_used_split) {
402                                 rte_free(vq->shadow_used_split);
403                                 vq->shadow_used_split = new_shadow_used_split;
404                         }
405                 }
406
407                 new_batch_copy_elems = rte_malloc_socket(NULL,
408                         vq->size * sizeof(struct batch_copy_elem),
409                         RTE_CACHE_LINE_SIZE,
410                         newnode);
411                 if (new_batch_copy_elems) {
412                         rte_free(vq->batch_copy_elems);
413                         vq->batch_copy_elems = new_batch_copy_elems;
414                 }
415
416                 rte_free(old_vq);
417         }
418
419         /* check if we need to reallocate dev */
420         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
421                             MPOL_F_NODE | MPOL_F_ADDR);
422         if (ret) {
423                 RTE_LOG(ERR, VHOST_CONFIG,
424                         "Unable to get dev numa information.\n");
425                 goto out;
426         }
427         if (oldnode != newnode) {
428                 RTE_LOG(INFO, VHOST_CONFIG,
429                         "reallocate dev from %d to %d node\n",
430                         oldnode, newnode);
431                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
432                 if (!dev) {
433                         dev = old_dev;
434                         goto out;
435                 }
436
437                 memcpy(dev, old_dev, sizeof(*dev));
438                 rte_free(old_dev);
439         }
440
441 out:
442         dev->virtqueue[index] = vq;
443         vhost_devices[dev->vid] = dev;
444
445         if (old_vq != vq)
446                 vhost_user_iotlb_init(dev, index);
447
448         return dev;
449 }
450 #else
451 static struct virtio_net*
452 numa_realloc(struct virtio_net *dev, int index __rte_unused)
453 {
454         return dev;
455 }
456 #endif
457
458 /* Converts QEMU virtual address to Vhost virtual address. */
459 static uint64_t
460 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
461 {
462         struct rte_vhost_mem_region *r;
463         uint32_t i;
464
465         /* Find the region where the address lives. */
466         for (i = 0; i < dev->mem->nregions; i++) {
467                 r = &dev->mem->regions[i];
468
469                 if (qva >= r->guest_user_addr &&
470                     qva <  r->guest_user_addr + r->size) {
471
472                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
473                                 *len = r->guest_user_addr + r->size - qva;
474
475                         return qva - r->guest_user_addr +
476                                r->host_user_addr;
477                 }
478         }
479         *len = 0;
480
481         return 0;
482 }
483
484
485 /*
486  * Converts ring address to Vhost virtual address.
487  * If IOMMU is enabled, the ring address is a guest IO virtual address,
488  * else it is a QEMU virtual address.
489  */
490 static uint64_t
491 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
492                 uint64_t ra, uint64_t *size)
493 {
494         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
495                 uint64_t vva;
496
497                 vva = vhost_user_iotlb_cache_find(vq, ra,
498                                         size, VHOST_ACCESS_RW);
499                 if (!vva)
500                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
501
502                 return vva;
503         }
504
505         return qva_to_vva(dev, ra, size);
506 }
507
508 static struct virtio_net *
509 translate_ring_addresses(struct virtio_net *dev, int vq_index)
510 {
511         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
512         struct vhost_vring_addr *addr = &vq->ring_addrs;
513         uint64_t len;
514
515         if (vq_is_packed(dev)) {
516                 len = sizeof(struct vring_packed_desc) * vq->size;
517                 vq->desc_packed = (struct vring_packed_desc *)(uintptr_t)
518                         ring_addr_to_vva(dev, vq, addr->desc_user_addr, &len);
519                 vq->log_guest_addr = 0;
520                 if (vq->desc_packed == NULL ||
521                                 len != sizeof(struct vring_packed_desc) *
522                                 vq->size) {
523                         RTE_LOG(DEBUG, VHOST_CONFIG,
524                                 "(%d) failed to map desc_packed ring.\n",
525                                 dev->vid);
526                         return dev;
527                 }
528
529                 dev = numa_realloc(dev, vq_index);
530                 vq = dev->virtqueue[vq_index];
531                 addr = &vq->ring_addrs;
532
533                 len = sizeof(struct vring_packed_desc_event);
534                 vq->driver_event = (struct vring_packed_desc_event *)
535                                         (uintptr_t)ring_addr_to_vva(dev,
536                                         vq, addr->avail_user_addr, &len);
537                 if (vq->driver_event == NULL ||
538                                 len != sizeof(struct vring_packed_desc_event)) {
539                         RTE_LOG(DEBUG, VHOST_CONFIG,
540                                 "(%d) failed to find driver area address.\n",
541                                 dev->vid);
542                         return dev;
543                 }
544
545                 len = sizeof(struct vring_packed_desc_event);
546                 vq->device_event = (struct vring_packed_desc_event *)
547                                         (uintptr_t)ring_addr_to_vva(dev,
548                                         vq, addr->used_user_addr, &len);
549                 if (vq->device_event == NULL ||
550                                 len != sizeof(struct vring_packed_desc_event)) {
551                         RTE_LOG(DEBUG, VHOST_CONFIG,
552                                 "(%d) failed to find device area address.\n",
553                                 dev->vid);
554                         return dev;
555                 }
556
557                 return dev;
558         }
559
560         /* The addresses are converted from QEMU virtual to Vhost virtual. */
561         if (vq->desc && vq->avail && vq->used)
562                 return dev;
563
564         len = sizeof(struct vring_desc) * vq->size;
565         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
566                         vq, addr->desc_user_addr, &len);
567         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
568                 RTE_LOG(DEBUG, VHOST_CONFIG,
569                         "(%d) failed to map desc ring.\n",
570                         dev->vid);
571                 return dev;
572         }
573
574         dev = numa_realloc(dev, vq_index);
575         vq = dev->virtqueue[vq_index];
576         addr = &vq->ring_addrs;
577
578         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
579         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
580                         vq, addr->avail_user_addr, &len);
581         if (vq->avail == 0 ||
582                         len != sizeof(struct vring_avail) +
583                         sizeof(uint16_t) * vq->size) {
584                 RTE_LOG(DEBUG, VHOST_CONFIG,
585                         "(%d) failed to map avail ring.\n",
586                         dev->vid);
587                 return dev;
588         }
589
590         len = sizeof(struct vring_used) +
591                 sizeof(struct vring_used_elem) * vq->size;
592         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
593                         vq, addr->used_user_addr, &len);
594         if (vq->used == 0 || len != sizeof(struct vring_used) +
595                         sizeof(struct vring_used_elem) * vq->size) {
596                 RTE_LOG(DEBUG, VHOST_CONFIG,
597                         "(%d) failed to map used ring.\n",
598                         dev->vid);
599                 return dev;
600         }
601
602         if (vq->last_used_idx != vq->used->idx) {
603                 RTE_LOG(WARNING, VHOST_CONFIG,
604                         "last_used_idx (%u) and vq->used->idx (%u) mismatch; "
605                         "some packets may be resent for Tx and dropped for Rx\n",
606                         vq->last_used_idx, vq->used->idx);
607                 vq->last_used_idx  = vq->used->idx;
608                 vq->last_avail_idx = vq->used->idx;
609         }
610
611         vq->log_guest_addr = addr->log_guest_addr;
612
613         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
614                         dev->vid, vq->desc);
615         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
616                         dev->vid, vq->avail);
617         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
618                         dev->vid, vq->used);
619         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
620                         dev->vid, vq->log_guest_addr);
621
622         return dev;
623 }
624
625 /*
626  * The virtio device sends us the desc, used and avail ring addresses.
627  * This function then converts these to our address space.
628  */
629 static int
630 vhost_user_set_vring_addr(struct virtio_net **pdev, struct VhostUserMsg *msg)
631 {
632         struct virtio_net *dev = *pdev;
633         struct vhost_virtqueue *vq;
634         struct vhost_vring_addr *addr = &msg->payload.addr;
635
636         if (dev->mem == NULL)
637                 return VH_RESULT_ERR;
638
639         /* addr->index refers to the queue index. The txq is 1, rxq is 0. */
640         vq = dev->virtqueue[msg->payload.addr.index];
641
642         /*
643          * Ring addresses should not be interpreted as long as the ring is not
644          * started and enabled
645          */
646         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
647
648         vring_invalidate(dev, vq);
649
650         if (vq->enabled && (dev->features &
651                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
652                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
653                 if (!dev)
654                         return VH_RESULT_ERR;
655
656                 *pdev = dev;
657         }
658
659         return VH_RESULT_OK;
660 }
661
662 /*
663  * The virtio device sends us the available ring last used index.
664  */
665 static int
666 vhost_user_set_vring_base(struct virtio_net **pdev,
667                           struct VhostUserMsg *msg)
668 {
669         struct virtio_net *dev = *pdev;
670         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
671                         msg->payload.state.num;
672         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
673                         msg->payload.state.num;
674
675         return VH_RESULT_OK;
676 }
677
678 static int
679 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
680                    uint64_t host_phys_addr, uint64_t size)
681 {
682         struct guest_page *page, *last_page;
683
684         if (dev->nr_guest_pages == dev->max_guest_pages) {
685                 dev->max_guest_pages *= 2;
686                 dev->guest_pages = realloc(dev->guest_pages,
687                                         dev->max_guest_pages * sizeof(*page));
688                 if (!dev->guest_pages) {
689                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
690                         return -1;
691                 }
692         }
693
694         if (dev->nr_guest_pages > 0) {
695                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
696                 /* merge if the two pages are contiguous */
697                 if (host_phys_addr == last_page->host_phys_addr +
698                                       last_page->size) {
699                         last_page->size += size;
700                         return 0;
701                 }
702         }
703
704         page = &dev->guest_pages[dev->nr_guest_pages++];
705         page->guest_phys_addr = guest_phys_addr;
706         page->host_phys_addr  = host_phys_addr;
707         page->size = size;
708
709         return 0;
710 }
711
712 static int
713 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
714                 uint64_t page_size)
715 {
716         uint64_t reg_size = reg->size;
717         uint64_t host_user_addr  = reg->host_user_addr;
718         uint64_t guest_phys_addr = reg->guest_phys_addr;
719         uint64_t host_phys_addr;
720         uint64_t size;
721
722         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
723         size = page_size - (guest_phys_addr & (page_size - 1));
724         size = RTE_MIN(size, reg_size);
725
726         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
727                 return -1;
728
729         host_user_addr  += size;
730         guest_phys_addr += size;
731         reg_size -= size;
732
733         while (reg_size > 0) {
734                 size = RTE_MIN(reg_size, page_size);
735                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
736                                                   host_user_addr);
737                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
738                                 size) < 0)
739                         return -1;
740
741                 host_user_addr  += size;
742                 guest_phys_addr += size;
743                 reg_size -= size;
744         }
745
746         return 0;
747 }
748
749 #ifdef RTE_LIBRTE_VHOST_DEBUG
750 /* TODO: enable it only in debug mode? */
751 static void
752 dump_guest_pages(struct virtio_net *dev)
753 {
754         uint32_t i;
755         struct guest_page *page;
756
757         for (i = 0; i < dev->nr_guest_pages; i++) {
758                 page = &dev->guest_pages[i];
759
760                 RTE_LOG(INFO, VHOST_CONFIG,
761                         "guest physical page region %u\n"
762                         "\t guest_phys_addr: %" PRIx64 "\n"
763                         "\t host_phys_addr : %" PRIx64 "\n"
764                         "\t size           : %" PRIx64 "\n",
765                         i,
766                         page->guest_phys_addr,
767                         page->host_phys_addr,
768                         page->size);
769         }
770 }
771 #else
772 #define dump_guest_pages(dev)
773 #endif
774
775 static bool
776 vhost_memory_changed(struct VhostUserMemory *new,
777                      struct rte_vhost_memory *old)
778 {
779         uint32_t i;
780
781         if (new->nregions != old->nregions)
782                 return true;
783
784         for (i = 0; i < new->nregions; ++i) {
785                 VhostUserMemoryRegion *new_r = &new->regions[i];
786                 struct rte_vhost_mem_region *old_r = &old->regions[i];
787
788                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
789                         return true;
790                 if (new_r->memory_size != old_r->size)
791                         return true;
792                 if (new_r->userspace_addr != old_r->guest_user_addr)
793                         return true;
794         }
795
796         return false;
797 }
798
799 static int
800 vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *msg)
801 {
802         struct virtio_net *dev = *pdev;
803         struct VhostUserMemory memory = msg->payload.memory;
804         struct rte_vhost_mem_region *reg;
805         void *mmap_addr;
806         uint64_t mmap_size;
807         uint64_t mmap_offset;
808         uint64_t alignment;
809         uint32_t i;
810         int populate;
811         int fd;
812
813         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
814                 RTE_LOG(ERR, VHOST_CONFIG,
815                         "too many memory regions (%u)\n", memory.nregions);
816                 return VH_RESULT_ERR;
817         }
818
819         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
820                 RTE_LOG(INFO, VHOST_CONFIG,
821                         "(%d) memory regions not changed\n", dev->vid);
822
823                 for (i = 0; i < memory.nregions; i++)
824                         close(msg->fds[i]);
825
826                 return VH_RESULT_OK;
827         }
828
829         if (dev->mem) {
830                 free_mem_region(dev);
831                 rte_free(dev->mem);
832                 dev->mem = NULL;
833         }
834
835         /* Flush IOTLB cache as previous HVAs are now invalid */
836         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))
837                 for (i = 0; i < dev->nr_vring; i++)
838                         vhost_user_iotlb_flush_all(dev->virtqueue[i]);
839
840         dev->nr_guest_pages = 0;
841         if (!dev->guest_pages) {
842                 dev->max_guest_pages = 8;
843                 dev->guest_pages = malloc(dev->max_guest_pages *
844                                                 sizeof(struct guest_page));
845                 if (dev->guest_pages == NULL) {
846                         RTE_LOG(ERR, VHOST_CONFIG,
847                                 "(%d) failed to allocate memory "
848                                 "for dev->guest_pages\n",
849                                 dev->vid);
850                         return VH_RESULT_ERR;
851                 }
852         }
853
854         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
855                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
856         if (dev->mem == NULL) {
857                 RTE_LOG(ERR, VHOST_CONFIG,
858                         "(%d) failed to allocate memory for dev->mem\n",
859                         dev->vid);
860                 return VH_RESULT_ERR;
861         }
862         dev->mem->nregions = memory.nregions;
863
864         for (i = 0; i < memory.nregions; i++) {
865                 fd  = msg->fds[i];
866                 reg = &dev->mem->regions[i];
867
868                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
869                 reg->guest_user_addr = memory.regions[i].userspace_addr;
870                 reg->size            = memory.regions[i].memory_size;
871                 reg->fd              = fd;
872
873                 mmap_offset = memory.regions[i].mmap_offset;
874
875                 /* Check for memory_size + mmap_offset overflow */
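                /* Relies on unsigned wrap-around: for a non-zero reg->size,
                 * -reg->size equals 2^64 - reg->size, so the test is true
                 * exactly when mmap_offset + reg->size would overflow uint64_t.
                 */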
876                 if (mmap_offset >= -reg->size) {
877                         RTE_LOG(ERR, VHOST_CONFIG,
878                                 "mmap_offset (%#"PRIx64") and memory_size "
879                                 "(%#"PRIx64") overflow\n",
880                                 mmap_offset, reg->size);
881                         goto err_mmap;
882                 }
883
884                 mmap_size = reg->size + mmap_offset;
885
886                 /* Without the MAP_ANONYMOUS flag, mmap() must be called with a
887                  * length argument aligned to the hugepage size on older long-term
888                  * Linux kernels (such as 2.6.32 and 3.2.72), or it will fail
889                  * with EINVAL.
890                  *
891                  * To avoid that failure, make sure the length passed by the
892                  * caller stays aligned.
893                  */
894                 alignment = get_blk_size(fd);
895                 if (alignment == (uint64_t)-1) {
896                         RTE_LOG(ERR, VHOST_CONFIG,
897                                 "couldn't get hugepage size through fstat\n");
898                         goto err_mmap;
899                 }
900                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
901
902                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
903                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
904                                  MAP_SHARED | populate, fd, 0);
905
906                 if (mmap_addr == MAP_FAILED) {
907                         RTE_LOG(ERR, VHOST_CONFIG,
908                                 "mmap region %u failed.\n", i);
909                         goto err_mmap;
910                 }
911
912                 reg->mmap_addr = mmap_addr;
913                 reg->mmap_size = mmap_size;
914                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
915                                       mmap_offset;
916
917                 if (dev->dequeue_zero_copy)
918                         if (add_guest_pages(dev, reg, alignment) < 0) {
919                                 RTE_LOG(ERR, VHOST_CONFIG,
920                                         "adding guest pages to region %u failed.\n",
921                                         i);
922                                 goto err_mmap;
923                         }
924
925                 RTE_LOG(INFO, VHOST_CONFIG,
926                         "guest memory region %u, size: 0x%" PRIx64 "\n"
927                         "\t guest physical addr: 0x%" PRIx64 "\n"
928                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
929                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
930                         "\t mmap addr : 0x%" PRIx64 "\n"
931                         "\t mmap size : 0x%" PRIx64 "\n"
932                         "\t mmap align: 0x%" PRIx64 "\n"
933                         "\t mmap off  : 0x%" PRIx64 "\n",
934                         i, reg->size,
935                         reg->guest_phys_addr,
936                         reg->guest_user_addr,
937                         reg->host_user_addr,
938                         (uint64_t)(uintptr_t)mmap_addr,
939                         mmap_size,
940                         alignment,
941                         mmap_offset);
942         }
943
944         for (i = 0; i < dev->nr_vring; i++) {
945                 struct vhost_virtqueue *vq = dev->virtqueue[i];
946
947                 if (vq->desc || vq->avail || vq->used) {
948                         /*
949                          * If the memory table got updated, the ring addresses
950                          * need to be translated again as virtual addresses have
951                          * changed.
952                          */
953                         vring_invalidate(dev, vq);
954
955                         dev = translate_ring_addresses(dev, i);
956                         if (!dev) {
957                                 dev = *pdev;
958                                 goto err_mmap;
959                         }
960
961                         *pdev = dev;
962                 }
963         }
964
965         dump_guest_pages(dev);
966
967         return VH_RESULT_OK;
968
969 err_mmap:
970         free_mem_region(dev);
971         rte_free(dev->mem);
972         dev->mem = NULL;
973         return VH_RESULT_ERR;
974 }
975
976 static bool
977 vq_is_ready(struct virtio_net *dev, struct vhost_virtqueue *vq)
978 {
979         bool rings_ok;
980
981         if (!vq)
982                 return false;
983
984         if (vq_is_packed(dev))
985                 rings_ok = !!vq->desc_packed;
986         else
987                 rings_ok = vq->desc && vq->avail && vq->used;
988
989         return rings_ok &&
990                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
991                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
992 }
993
994 static int
995 virtio_is_ready(struct virtio_net *dev)
996 {
997         struct vhost_virtqueue *vq;
998         uint32_t i;
999
1000         if (dev->nr_vring == 0)
1001                 return 0;
1002
1003         for (i = 0; i < dev->nr_vring; i++) {
1004                 vq = dev->virtqueue[i];
1005
1006                 if (!vq_is_ready(dev, vq))
1007                         return 0;
1008         }
1009
1010         RTE_LOG(INFO, VHOST_CONFIG,
1011                 "virtio is now ready for processing.\n");
1012         return 1;
1013 }
1014
1015 static int
1016 vhost_user_set_vring_call(struct virtio_net **pdev, struct VhostUserMsg *msg)
1017 {
1018         struct virtio_net *dev = *pdev;
1019         struct vhost_vring_file file;
1020         struct vhost_virtqueue *vq;
1021
1022         file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1023         if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
1024                 file.fd = VIRTIO_INVALID_EVENTFD;
1025         else
1026                 file.fd = msg->fds[0];
1027         RTE_LOG(INFO, VHOST_CONFIG,
1028                 "vring call idx:%d file:%d\n", file.index, file.fd);
1029
1030         vq = dev->virtqueue[file.index];
1031         if (vq->callfd >= 0)
1032                 close(vq->callfd);
1033
1034         vq->callfd = file.fd;
1035
1036         return VH_RESULT_OK;
1037 }
1038
1039 static int vhost_user_set_vring_err(struct virtio_net **pdev __rte_unused,
1040                         struct VhostUserMsg *msg)
1041 {
1042         if (!(msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1043                 close(msg->fds[0]);
1044         RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1045
1046         return VH_RESULT_OK;
1047 }
1048
1049 static int
1050 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *msg)
1051 {
1052         struct virtio_net *dev = *pdev;
1053         struct vhost_vring_file file;
1054         struct vhost_virtqueue *vq;
1055
1056         file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1057         if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
1058                 file.fd = VIRTIO_INVALID_EVENTFD;
1059         else
1060                 file.fd = msg->fds[0];
1061         RTE_LOG(INFO, VHOST_CONFIG,
1062                 "vring kick idx:%d file:%d\n", file.index, file.fd);
1063
1064         /* Interpret ring addresses only when ring is started. */
1065         dev = translate_ring_addresses(dev, file.index);
1066         if (!dev)
1067                 return VH_RESULT_ERR;
1068
1069         *pdev = dev;
1070
1071         vq = dev->virtqueue[file.index];
1072
1073         /*
1074          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
1075          * the ring starts already enabled. Otherwise, it is enabled via
1076          * the SET_VRING_ENABLE message.
1077          */
1078         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
1079                 vq->enabled = 1;
1080
1081         if (vq->kickfd >= 0)
1082                 close(vq->kickfd);
1083         vq->kickfd = file.fd;
1084
1085         return VH_RESULT_OK;
1086 }
1087
1088 static void
1089 free_zmbufs(struct vhost_virtqueue *vq)
1090 {
1091         struct zcopy_mbuf *zmbuf, *next;
1092
1093         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
1094              zmbuf != NULL; zmbuf = next) {
1095                 next = TAILQ_NEXT(zmbuf, next);
1096
1097                 rte_pktmbuf_free(zmbuf->mbuf);
1098                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
1099         }
1100
1101         rte_free(vq->zmbufs);
1102 }
1103
1104 /*
1105  * When virtio is stopped, QEMU will send us the GET_VRING_BASE message.
1106  */
1107 static int
1108 vhost_user_get_vring_base(struct virtio_net **pdev,
1109                           struct VhostUserMsg *msg)
1110 {
1111         struct virtio_net *dev = *pdev;
1112         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
1113
1114         /* We have to stop the queue (virtio) if it is running. */
1115         vhost_destroy_device_notify(dev);
1116
1117         dev->flags &= ~VIRTIO_DEV_READY;
1118         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
1119
1120         /* Here we are safe to get the last avail index */
1121         msg->payload.state.num = vq->last_avail_idx;
1122
1123         RTE_LOG(INFO, VHOST_CONFIG,
1124                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1125                 msg->payload.state.num);
1126         /*
1127          * Based on the current QEMU vhost-user implementation, this message
1128          * is sent, and only sent, in vhost_vring_stop.
1129          * TODO: clean up the vring; it isn't usable from this point on.
1130          */
1131         if (vq->kickfd >= 0)
1132                 close(vq->kickfd);
1133
1134         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1135
1136         if (vq->callfd >= 0)
1137                 close(vq->callfd);
1138
1139         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1140
1141         if (dev->dequeue_zero_copy)
1142                 free_zmbufs(vq);
1143         if (vq_is_packed(dev)) {
1144                 rte_free(vq->shadow_used_packed);
1145                 vq->shadow_used_packed = NULL;
1146         } else {
1147                 rte_free(vq->shadow_used_split);
1148                 vq->shadow_used_split = NULL;
1149         }
1150
1151         rte_free(vq->batch_copy_elems);
1152         vq->batch_copy_elems = NULL;
1153
1154         msg->size = sizeof(msg->payload.state);
1155
1156         return VH_RESULT_REPLY;
1157 }
1158
1159 /*
1160  * When the virtio queues are ready to work, QEMU will send us this
1161  * message to enable the virtio queue pair.
1162  */
1163 static int
1164 vhost_user_set_vring_enable(struct virtio_net **pdev,
1165                             struct VhostUserMsg *msg)
1166 {
1167         struct virtio_net *dev = *pdev;
1168         int enable = (int)msg->payload.state.num;
1169         int index = (int)msg->payload.state.index;
1170         struct rte_vdpa_device *vdpa_dev;
1171         int did = -1;
1172
1173         RTE_LOG(INFO, VHOST_CONFIG,
1174                 "set queue enable: %d to qp idx: %d\n",
1175                 enable, index);
1176
1177         did = dev->vdpa_dev_id;
1178         vdpa_dev = rte_vdpa_get_device(did);
1179         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1180                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1181
1182         if (dev->notify_ops->vring_state_changed)
1183                 dev->notify_ops->vring_state_changed(dev->vid,
1184                                 index, enable);
1185
1186         dev->virtqueue[index]->enabled = enable;
1187
1188         return VH_RESULT_OK;
1189 }
1190
1191 static int
1192 vhost_user_get_protocol_features(struct virtio_net **pdev,
1193                                  struct VhostUserMsg *msg)
1194 {
1195         struct virtio_net *dev = *pdev;
1196         uint64_t features, protocol_features;
1197
1198         rte_vhost_driver_get_features(dev->ifname, &features);
1199         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1200
1201         /*
1202          * The REPLY_ACK protocol feature is, for now, only mandatory for the
1203          * IOMMU feature. If IOMMU is explicitly disabled by the application,
1204          * also disable the REPLY_ACK feature to cope with older buggy
1205          * QEMU versions (from v2.7.0 to v2.9.0).
1206          */
1207         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1208                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1209
1210         msg->payload.u64 = protocol_features;
1211         msg->size = sizeof(msg->payload.u64);
1212
1213         return VH_RESULT_REPLY;
1214 }
1215
1216 static int
1217 vhost_user_set_protocol_features(struct virtio_net **pdev,
1218                                  struct VhostUserMsg *msg)
1219 {
1220         struct virtio_net *dev = *pdev;
1221         uint64_t protocol_features = msg->payload.u64;
1222         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES) {
1223                 RTE_LOG(ERR, VHOST_CONFIG,
1224                         "(%d) received invalid protocol features.\n",
1225                         dev->vid);
1226                 return VH_RESULT_ERR;
1227         }
1228
1229         dev->protocol_features = protocol_features;
1230
1231         return VH_RESULT_OK;
1232 }
1233
1234 static int
1235 vhost_user_set_log_base(struct virtio_net **pdev, struct VhostUserMsg *msg)
1236 {
1237         struct virtio_net *dev = *pdev;
1238         int fd = msg->fds[0];
1239         uint64_t size, off;
1240         void *addr;
1241
1242         if (fd < 0) {
1243                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1244                 return VH_RESULT_ERR;
1245         }
1246
1247         if (msg->size != sizeof(VhostUserLog)) {
1248                 RTE_LOG(ERR, VHOST_CONFIG,
1249                         "invalid log base msg size: %"PRId32" != %d\n",
1250                         msg->size, (int)sizeof(VhostUserLog));
1251                 return VH_RESULT_ERR;
1252         }
1253
1254         size = msg->payload.log.mmap_size;
1255         off  = msg->payload.log.mmap_offset;
1256
1257         /* Don't allow mmap_offset to point outside the mmap region */
1258         if (off > size) {
1259                 RTE_LOG(ERR, VHOST_CONFIG,
1260                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1261                         off, size);
1262                 return VH_RESULT_ERR;
1263         }
1264
1265         RTE_LOG(INFO, VHOST_CONFIG,
1266                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1267                 size, off);
1268
1269         /*
1270          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1271          * fail when the offset is not page-size aligned.
1272          */
1273         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1274         close(fd);
1275         if (addr == MAP_FAILED) {
1276                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1277                 return VH_RESULT_ERR;
1278         }
1279
1280         /*
1281          * Free previously mapped log memory in case multiple
1282          * VHOST_USER_SET_LOG_BASE messages are received.
1283          */
1284         if (dev->log_addr) {
1285                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1286         }
1287         dev->log_addr = (uint64_t)(uintptr_t)addr;
1288         dev->log_base = dev->log_addr + off;
1289         dev->log_size = size;
1290
1291         /*
1292          * The spec is not clear about it (yet), but QEMU doesn't expect
1293          * any payload in the reply.
1294          */
1295         msg->size = 0;
1296
1297         return VH_RESULT_REPLY;
1298 }
1299
1300 static int vhost_user_set_log_fd(struct virtio_net **pdev __rte_unused,
1301                         struct VhostUserMsg *msg)
1302 {
1303         close(msg->fds[0]);
1304         RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1305
1306         return VH_RESULT_OK;
1307 }
1308
1309 /*
1310  * A RARP packet is constructed and broadcast to notify switches about
1311  * the new location of the migrated VM, so that packets from outside will
1312  * not be lost after migration.
1313  *
1314  * However, we don't actually "send" a RARP packet here; instead, we set
1315  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1316  */
1317 static int
1318 vhost_user_send_rarp(struct virtio_net **pdev, struct VhostUserMsg *msg)
1319 {
1320         struct virtio_net *dev = *pdev;
1321         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1322         struct rte_vdpa_device *vdpa_dev;
1323         int did = -1;
1324
1325         RTE_LOG(DEBUG, VHOST_CONFIG,
1326                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1327                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1328         memcpy(dev->mac.addr_bytes, mac, 6);
1329
1330         /*
1331          * Set the flag to inject a RARP broadcast packet at
1332          * rte_vhost_dequeue_burst().
1333          *
1334          * rte_smp_wmb() makes sure the MAC address is copied
1335          * before the flag is set.
1336          */
1337         rte_smp_wmb();
1338         rte_atomic16_set(&dev->broadcast_rarp, 1);
1339         did = dev->vdpa_dev_id;
1340         vdpa_dev = rte_vdpa_get_device(did);
1341         if (vdpa_dev && vdpa_dev->ops->migration_done)
1342                 vdpa_dev->ops->migration_done(dev->vid);
1343
1344         return VH_RESULT_OK;
1345 }
1346
1347 static int
1348 vhost_user_net_set_mtu(struct virtio_net **pdev, struct VhostUserMsg *msg)
1349 {
1350         struct virtio_net *dev = *pdev;
1351         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1352                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1353                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1354                                 msg->payload.u64);
1355
1356                 return VH_RESULT_ERR;
1357         }
1358
1359         dev->mtu = msg->payload.u64;
1360
1361         return VH_RESULT_OK;
1362 }
1363
1364 static int
1365 vhost_user_set_req_fd(struct virtio_net **pdev, struct VhostUserMsg *msg)
1366 {
1367         struct virtio_net *dev = *pdev;
1368         int fd = msg->fds[0];
1369
1370         if (fd < 0) {
1371                 RTE_LOG(ERR, VHOST_CONFIG,
1372                                 "Invalid file descriptor for slave channel (%d)\n",
1373                                 fd);
1374                 return VH_RESULT_ERR;
1375         }
1376
1377         dev->slave_req_fd = fd;
1378
1379         return VH_RESULT_OK;
1380 }
1381
1382 static int
1383 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1384 {
1385         struct vhost_vring_addr *ra;
1386         uint64_t start, end;
1387
1388         start = imsg->iova;
1389         end = start + imsg->size;
1390
1391         ra = &vq->ring_addrs;
1392         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1393                 return 1;
1394         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1395                 return 1;
1396         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1397                 return 1;
1398
1399         return 0;
1400 }
1401
1402 static int
1403 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1404                                 struct vhost_iotlb_msg *imsg)
1405 {
1406         uint64_t istart, iend, vstart, vend;
1407
1408         istart = imsg->iova;
1409         iend = istart + imsg->size - 1;
1410
1411         vstart = (uintptr_t)vq->desc;
1412         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1413         if (vstart <= iend && istart <= vend)
1414                 return 1;
1415
1416         vstart = (uintptr_t)vq->avail;
1417         vend = vstart + sizeof(struct vring_avail);
1418         vend += sizeof(uint16_t) * vq->size - 1;
1419         if (vstart <= iend && istart <= vend)
1420                 return 1;
1421
1422         vstart = (uintptr_t)vq->used;
1423         vend = vstart + sizeof(struct vring_used);
1424         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1425         if (vstart <= iend && istart <= vend)
1426                 return 1;
1427
1428         return 0;
1429 }
1430
1431 static int
1432 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1433 {
1434         struct virtio_net *dev = *pdev;
1435         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1436         uint16_t i;
1437         uint64_t vva, len;
1438
1439         switch (imsg->type) {
1440         case VHOST_IOTLB_UPDATE:
1441                 len = imsg->size;
1442                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1443                 if (!vva)
1444                         return VH_RESULT_ERR;
1445
1446                 for (i = 0; i < dev->nr_vring; i++) {
1447                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1448
1449                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1450                                         len, imsg->perm);
1451
1452                         if (is_vring_iotlb_update(vq, imsg))
1453                                 *pdev = dev = translate_ring_addresses(dev, i);
1454                 }
1455                 break;
1456         case VHOST_IOTLB_INVALIDATE:
1457                 for (i = 0; i < dev->nr_vring; i++) {
1458                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1459
1460                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1461                                         imsg->size);
1462
1463                         if (is_vring_iotlb_invalidate(vq, imsg))
1464                                 vring_invalidate(dev, vq);
1465                 }
1466                 break;
1467         default:
1468                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1469                                 imsg->type);
1470                 return VH_RESULT_ERR;
1471         }
1472
1473         return VH_RESULT_OK;
1474 }
1475
1476 typedef int (*vhost_message_handler_t)(struct virtio_net **pdev,
1477                                         struct VhostUserMsg *msg);
1478 static vhost_message_handler_t vhost_message_handlers[VHOST_USER_MAX] = {
1479         [VHOST_USER_NONE] = NULL,
1480         [VHOST_USER_GET_FEATURES] = vhost_user_get_features,
1481         [VHOST_USER_SET_FEATURES] = vhost_user_set_features,
1482         [VHOST_USER_SET_OWNER] = vhost_user_set_owner,
1483         [VHOST_USER_RESET_OWNER] = vhost_user_reset_owner,
1484         [VHOST_USER_SET_MEM_TABLE] = vhost_user_set_mem_table,
1485         [VHOST_USER_SET_LOG_BASE] = vhost_user_set_log_base,
1486         [VHOST_USER_SET_LOG_FD] = vhost_user_set_log_fd,
1487         [VHOST_USER_SET_VRING_NUM] = vhost_user_set_vring_num,
1488         [VHOST_USER_SET_VRING_ADDR] = vhost_user_set_vring_addr,
1489         [VHOST_USER_SET_VRING_BASE] = vhost_user_set_vring_base,
1490         [VHOST_USER_GET_VRING_BASE] = vhost_user_get_vring_base,
1491         [VHOST_USER_SET_VRING_KICK] = vhost_user_set_vring_kick,
1492         [VHOST_USER_SET_VRING_CALL] = vhost_user_set_vring_call,
1493         [VHOST_USER_SET_VRING_ERR] = vhost_user_set_vring_err,
1494         [VHOST_USER_GET_PROTOCOL_FEATURES] = vhost_user_get_protocol_features,
1495         [VHOST_USER_SET_PROTOCOL_FEATURES] = vhost_user_set_protocol_features,
1496         [VHOST_USER_GET_QUEUE_NUM] = vhost_user_get_queue_num,
1497         [VHOST_USER_SET_VRING_ENABLE] = vhost_user_set_vring_enable,
1498         [VHOST_USER_SEND_RARP] = vhost_user_send_rarp,
1499         [VHOST_USER_NET_SET_MTU] = vhost_user_net_set_mtu,
1500         [VHOST_USER_SET_SLAVE_REQ_FD] = vhost_user_set_req_fd,
1501         [VHOST_USER_IOTLB_MSG] = vhost_user_iotlb_msg,
1502 };
1503
1504
1505 /* Return bytes read on success, 0 if the peer closed, or a negative value on failure. */
1506 static int
1507 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1508 {
1509         int ret;
1510
1511         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1512                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1513         if (ret <= 0)
1514                 return ret;
1515
1516         if (msg->size) {
1517                 if (msg->size > sizeof(msg->payload)) {
1518                         RTE_LOG(ERR, VHOST_CONFIG,
1519                                 "invalid msg size: %d\n", msg->size);
1520                         return -1;
1521                 }
1522                 ret = read(sockfd, &msg->payload, msg->size);
1523                 if (ret <= 0)
1524                         return ret;
1525                 if (ret != (int)msg->size) {
1526                         RTE_LOG(ERR, VHOST_CONFIG,
1527                                 "read control message failed\n");
1528                         return -1;
1529                 }
1530         }
1531
1532         return ret;
1533 }
1534
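/* Send a vhost-user message, with optional file descriptors, on the socket. */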
1535 static int
1536 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1537 {
1538         if (!msg)
1539                 return 0;
1540
1541         return send_fd_message(sockfd, (char *)msg,
1542                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1543 }
1544
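/*
 * Send a reply to the master: set the version and reply flags and clear
 * VHOST_USER_NEED_REPLY before sending the message back on the socket.
 */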
1545 static int
1546 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1547 {
1548         if (!msg)
1549                 return 0;
1550
1551         msg->flags &= ~VHOST_USER_VERSION_MASK;
1552         msg->flags &= ~VHOST_USER_NEED_REPLY;
1553         msg->flags |= VHOST_USER_VERSION;
1554         msg->flags |= VHOST_USER_REPLY_MASK;
1555
1556         return send_vhost_message(sockfd, msg, NULL, 0);
1557 }
1558
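/*
 * Send a request on the slave channel. If a reply is expected, take the
 * slave_req_lock; it is released once the reply has been handled in
 * process_slave_message_reply(), or here if sending fails.
 */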
1559 static int
1560 send_vhost_slave_message(struct virtio_net *dev, struct VhostUserMsg *msg,
1561                          int *fds, int fd_num)
1562 {
1563         int ret;
1564
1565         if (msg->flags & VHOST_USER_NEED_REPLY)
1566                 rte_spinlock_lock(&dev->slave_req_lock);
1567
1568         ret = send_vhost_message(dev->slave_req_fd, msg, fds, fd_num);
1569         if (ret < 0 && (msg->flags & VHOST_USER_NEED_REPLY))
1570                 rte_spinlock_unlock(&dev->slave_req_lock);
1571
1572         return ret;
1573 }
1574
1575 /*
1576  * Allocate a queue pair if it hasn't been allocated yet
1577  */
1578 static int
1579 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev,
1580                         struct VhostUserMsg *msg)
1581 {
1582         uint16_t vring_idx;
1583
1584         switch (msg->request.master) {
1585         case VHOST_USER_SET_VRING_KICK:
1586         case VHOST_USER_SET_VRING_CALL:
1587         case VHOST_USER_SET_VRING_ERR:
1588                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1589                 break;
1590         case VHOST_USER_SET_VRING_NUM:
1591         case VHOST_USER_SET_VRING_BASE:
1592         case VHOST_USER_SET_VRING_ENABLE:
1593                 vring_idx = msg->payload.state.index;
1594                 break;
1595         case VHOST_USER_SET_VRING_ADDR:
1596                 vring_idx = msg->payload.addr.index;
1597                 break;
1598         default:
1599                 return 0;
1600         }
1601
1602         if (vring_idx >= VHOST_MAX_VRING) {
1603                 RTE_LOG(ERR, VHOST_CONFIG,
1604                         "invalid vring index: %u\n", vring_idx);
1605                 return -1;
1606         }
1607
1608         if (dev->virtqueue[vring_idx])
1609                 return 0;
1610
1611         return alloc_vring_queue(dev, vring_idx);
1612 }
1613
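/*
 * Take the access_lock of every allocated virtqueue before handling a
 * message that may modify the device or ring state; released by
 * vhost_user_unlock_all_queue_pairs().
 */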
1614 static void
1615 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1616 {
1617         unsigned int i = 0;
1618         unsigned int vq_num = 0;
1619
1620         while (vq_num < dev->nr_vring) {
1621                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1622
1623                 if (vq) {
1624                         rte_spinlock_lock(&vq->access_lock);
1625                         vq_num++;
1626                 }
1627                 i++;
1628         }
1629 }
1630
1631 static void
1632 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1633 {
1634         unsigned int i = 0;
1635         unsigned int vq_num = 0;
1636
1637         while (vq_num < dev->nr_vring) {
1638                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1639
1640                 if (vq) {
1641                         rte_spinlock_unlock(&vq->access_lock);
1642                         vq_num++;
1643                 }
1644                 i++;
1645         }
1646 }
1647
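/*
 * Read and process one vhost-user message from the given socket.
 * Dispatches the request to its handler, sends the reply or reply-ack when
 * required, and starts the device (or configures the vDPA device) once the
 * virtio device is ready. Returns 0 on success, -1 on failure.
 */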
1648 int
1649 vhost_user_msg_handler(int vid, int fd)
1650 {
1651         struct virtio_net *dev;
1652         struct VhostUserMsg msg;
1653         struct rte_vdpa_device *vdpa_dev;
1654         int did = -1;
1655         int ret;
1656         int unlock_required = 0;
1657         uint32_t skip_master = 0;
1658         int request;
1659
1660         dev = get_device(vid);
1661         if (dev == NULL)
1662                 return -1;
1663
1664         if (!dev->notify_ops) {
1665                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1666                 if (!dev->notify_ops) {
1667                         RTE_LOG(ERR, VHOST_CONFIG,
1668                                 "failed to get callback ops for driver %s\n",
1669                                 dev->ifname);
1670                         return -1;
1671                 }
1672         }
1673
1674         ret = read_vhost_message(fd, &msg);
1675         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1676                 if (ret < 0)
1677                         RTE_LOG(ERR, VHOST_CONFIG,
1678                                 "vhost read message failed\n");
1679                 else if (ret == 0)
1680                         RTE_LOG(INFO, VHOST_CONFIG,
1681                                 "vhost peer closed\n");
1682                 else
1683                         RTE_LOG(ERR, VHOST_CONFIG,
1684                                 "vhost read incorrect message\n");
1685
1686                 return -1;
1687         }
1688
1689         ret = 0;
1690         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1691                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1692                         vhost_message_str[msg.request.master]);
1693         else
1694                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1695                         vhost_message_str[msg.request.master]);
1696
1697         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1698         if (ret < 0) {
1699                 RTE_LOG(ERR, VHOST_CONFIG,
1700                         "failed to alloc queue\n");
1701                 return -1;
1702         }
1703
1704         /*
1705          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1706          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1707          * and the device is destroyed. destroy_device waits for the queues
1708          * to become inactive, so it is safe. Otherwise taking the
1709          * access_lock would cause a deadlock.
1710          */
1711         switch (msg.request.master) {
1712         case VHOST_USER_SET_FEATURES:
1713         case VHOST_USER_SET_PROTOCOL_FEATURES:
1714         case VHOST_USER_SET_OWNER:
1715         case VHOST_USER_SET_MEM_TABLE:
1716         case VHOST_USER_SET_LOG_BASE:
1717         case VHOST_USER_SET_LOG_FD:
1718         case VHOST_USER_SET_VRING_NUM:
1719         case VHOST_USER_SET_VRING_ADDR:
1720         case VHOST_USER_SET_VRING_BASE:
1721         case VHOST_USER_SET_VRING_KICK:
1722         case VHOST_USER_SET_VRING_CALL:
1723         case VHOST_USER_SET_VRING_ERR:
1724         case VHOST_USER_SET_VRING_ENABLE:
1725         case VHOST_USER_SEND_RARP:
1726         case VHOST_USER_NET_SET_MTU:
1727         case VHOST_USER_SET_SLAVE_REQ_FD:
1728                 vhost_user_lock_all_queue_pairs(dev);
1729                 unlock_required = 1;
1730                 break;
1731         default:
1732                 break;
1733
1734         }
1735
1736         if (dev->extern_ops.pre_msg_handle) {
1737                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1738                                 (void *)&msg, &skip_master);
1739                 if (ret == VH_RESULT_ERR)
1740                         goto skip_to_reply;
1741                 else if (ret == VH_RESULT_REPLY)
1742                         send_vhost_reply(fd, &msg);
1743
1744                 if (skip_master)
1745                         goto skip_to_post_handle;
1746         }
1747
1748         request = msg.request.master;
1749         if (request > VHOST_USER_NONE && request < VHOST_USER_MAX) {
1750                 if (!vhost_message_handlers[request])
1751                         goto skip_to_post_handle;
1752                 ret = vhost_message_handlers[request](&dev, &msg);
1753
1754                 switch (ret) {
1755                 case VH_RESULT_ERR:
1756                         RTE_LOG(ERR, VHOST_CONFIG,
1757                                 "Processing %s failed.\n",
1758                                 vhost_message_str[request]);
1759                         break;
1760                 case VH_RESULT_OK:
1761                         RTE_LOG(DEBUG, VHOST_CONFIG,
1762                                 "Processing %s succeeded.\n",
1763                                 vhost_message_str[request]);
1764                         break;
1765                 case VH_RESULT_REPLY:
1766                         RTE_LOG(DEBUG, VHOST_CONFIG,
1767                                 "Processing %s succeeded and needs reply.\n",
1768                                 vhost_message_str[request]);
1769                         send_vhost_reply(fd, &msg);
1770                         break;
1771                 }
1772         } else {
1773                 RTE_LOG(ERR, VHOST_CONFIG,
1774                         "Requested invalid message type %d.\n", request);
1775                 ret = VH_RESULT_ERR;
1776         }
1777
1778 skip_to_post_handle:
1779         if (ret != VH_RESULT_ERR && dev->extern_ops.post_msg_handle) {
1780                 ret = (*dev->extern_ops.post_msg_handle)(
1781                                 dev->vid, (void *)&msg);
1782                 if (ret == VH_RESULT_ERR)
1783                         goto skip_to_reply;
1784                 else if (ret == VH_RESULT_REPLY)
1785                         send_vhost_reply(fd, &msg);
1786         }
1787
1788 skip_to_reply:
1789         if (unlock_required)
1790                 vhost_user_unlock_all_queue_pairs(dev);
1791
1792         /*
1793          * If the request required a reply that was already sent,
1794          * this optional reply-ack won't be sent, because the
1795          * VHOST_USER_NEED_REPLY flag was cleared in send_vhost_reply().
1796          */
1797         if (msg.flags & VHOST_USER_NEED_REPLY) {
1798                 msg.payload.u64 = ret == VH_RESULT_ERR;
1799                 msg.size = sizeof(msg.payload.u64);
1800                 send_vhost_reply(fd, &msg);
1801         } else if (ret == VH_RESULT_ERR) {
1802                 RTE_LOG(ERR, VHOST_CONFIG,
1803                         "vhost message handling failed.\n");
1804                 return -1;
1805         }
1806
1807         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1808                 dev->flags |= VIRTIO_DEV_READY;
1809
1810                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1811                         if (dev->dequeue_zero_copy) {
1812                                 RTE_LOG(INFO, VHOST_CONFIG,
1813                                                 "dequeue zero copy is enabled\n");
1814                         }
1815
1816                         if (dev->notify_ops->new_device(dev->vid) == 0)
1817                                 dev->flags |= VIRTIO_DEV_RUNNING;
1818                 }
1819         }
1820
1821         did = dev->vdpa_dev_id;
1822         vdpa_dev = rte_vdpa_get_device(did);
1823         if (vdpa_dev && virtio_is_ready(dev) &&
1824                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1825                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1826                 if (vdpa_dev->ops->dev_conf)
1827                         vdpa_dev->ops->dev_conf(dev->vid);
1828                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1829                 if (vhost_user_host_notifier_ctrl(dev->vid, true) != 0) {
1830                         RTE_LOG(INFO, VHOST_CONFIG,
1831                                 "(%d) software relay is used for vDPA, performance may be low.\n",
1832                                 dev->vid);
1833                 }
1834         }
1835
1836         return 0;
1837 }
1838
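/*
 * If VHOST_USER_NEED_REPLY was set on a slave channel request, wait for the
 * reply-ack and release the slave_req_lock taken in
 * send_vhost_slave_message(). Return 0 on success, -1 on error.
 */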
1839 static int process_slave_message_reply(struct virtio_net *dev,
1840                                        const struct VhostUserMsg *msg)
1841 {
1842         struct VhostUserMsg msg_reply;
1843         int ret;
1844
1845         if ((msg->flags & VHOST_USER_NEED_REPLY) == 0)
1846                 return 0;
1847
1848         if (read_vhost_message(dev->slave_req_fd, &msg_reply) < 0) {
1849                 ret = -1;
1850                 goto out;
1851         }
1852
1853         if (msg_reply.request.slave != msg->request.slave) {
1854                 RTE_LOG(ERR, VHOST_CONFIG,
1855                         "Received unexpected msg type (%u), expected %u\n",
1856                         msg_reply.request.slave, msg->request.slave);
1857                 ret = -1;
1858                 goto out;
1859         }
1860
1861         ret = msg_reply.payload.u64 ? -1 : 0;
1862
1863 out:
1864         rte_spinlock_unlock(&dev->slave_req_lock);
1865         return ret;
1866 }
1867
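/*
 * Send a VHOST_IOTLB_MISS request on the slave channel to ask the master
 * for the translation of the given IOVA with the requested permissions.
 */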
1868 int
1869 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1870 {
1871         int ret;
1872         struct VhostUserMsg msg = {
1873                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1874                 .flags = VHOST_USER_VERSION,
1875                 .size = sizeof(msg.payload.iotlb),
1876                 .payload.iotlb = {
1877                         .iova = iova,
1878                         .perm = perm,
1879                         .type = VHOST_IOTLB_MISS,
1880                 },
1881         };
1882
1883         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1884         if (ret < 0) {
1885                 RTE_LOG(ERR, VHOST_CONFIG,
1886                                 "Failed to send IOTLB miss message (%d)\n",
1887                                 ret);
1888                 return ret;
1889         }
1890
1891         return 0;
1892 }
1893
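/*
 * Ask the master to map the host notifier area of a vring (or to unmap it
 * when fd is negative) and wait for its reply-ack.
 */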
1894 static int vhost_user_slave_set_vring_host_notifier(struct virtio_net *dev,
1895                                                     int index, int fd,
1896                                                     uint64_t offset,
1897                                                     uint64_t size)
1898 {
1899         int *fdp = NULL;
1900         size_t fd_num = 0;
1901         int ret;
1902         struct VhostUserMsg msg = {
1903                 .request.slave = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1904                 .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY,
1905                 .size = sizeof(msg.payload.area),
1906                 .payload.area = {
1907                         .u64 = index & VHOST_USER_VRING_IDX_MASK,
1908                         .size = size,
1909                         .offset = offset,
1910                 },
1911         };
1912
1913         if (fd < 0)
1914                 msg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1915         else {
1916                 fdp = &fd;
1917                 fd_num = 1;
1918         }
1919
1920         ret = send_vhost_slave_message(dev, &msg, fdp, fd_num);
1921         if (ret < 0) {
1922                 RTE_LOG(ERR, VHOST_CONFIG,
1923                         "Failed to set host notifier (%d)\n", ret);
1924                 return ret;
1925         }
1926
1927         return process_slave_message_reply(dev, &msg);
1928 }
1929
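/*
 * Enable or disable host notifier mappings for all vrings of a vDPA device.
 * Requires the slave channel, slave-send-fd and host-notifier protocol
 * features. If enabling any vring fails, all notifiers are unmapped and an
 * error is returned.
 */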
1930 int vhost_user_host_notifier_ctrl(int vid, bool enable)
1931 {
1932         struct virtio_net *dev;
1933         struct rte_vdpa_device *vdpa_dev;
1934         int vfio_device_fd, did, ret = 0;
1935         uint64_t offset, size;
1936         unsigned int i;
1937
1938         dev = get_device(vid);
1939         if (!dev)
1940                 return -ENODEV;
1941
1942         did = dev->vdpa_dev_id;
1943         if (did < 0)
1944                 return -EINVAL;
1945
1946         if (!(dev->features & (1ULL << VIRTIO_F_VERSION_1)) ||
1947             !(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)) ||
1948             !(dev->protocol_features &
1949                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) ||
1950             !(dev->protocol_features &
1951                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) ||
1952             !(dev->protocol_features &
1953                         (1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER)))
1954                 return -ENOTSUP;
1955
1956         vdpa_dev = rte_vdpa_get_device(did);
1957         if (!vdpa_dev)
1958                 return -ENODEV;
1959
1960         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_vfio_device_fd, -ENOTSUP);
1961         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_notify_area, -ENOTSUP);
1962
1963         vfio_device_fd = vdpa_dev->ops->get_vfio_device_fd(vid);
1964         if (vfio_device_fd < 0)
1965                 return -ENOTSUP;
1966
1967         if (enable) {
1968                 for (i = 0; i < dev->nr_vring; i++) {
1969                         if (vdpa_dev->ops->get_notify_area(vid, i, &offset,
1970                                         &size) < 0) {
1971                                 ret = -ENOTSUP;
1972                                 goto disable;
1973                         }
1974
1975                         if (vhost_user_slave_set_vring_host_notifier(dev, i,
1976                                         vfio_device_fd, offset, size) < 0) {
1977                                 ret = -EFAULT;
1978                                 goto disable;
1979                         }
1980                 }
1981         } else {
1982 disable:
1983                 for (i = 0; i < dev->nr_vring; i++) {
1984                         vhost_user_slave_set_vring_host_notifier(dev, i, -1,
1985                                         0, 0);
1986                 }
1987         }
1988
1989         return ret;
1990 }