vhost: fix return code of messages requiring replies
lib/librte_vhost/vhost_user.c (dpdk.git)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  Malicious guests that have escaped will then launch further
12  * attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
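   /*
    * For hugetlbfs-backed fds, fstat() reports the hugepage size in st_blksize.
    * That value is used as the mmap() alignment in vhost_user_set_mem_table().
    */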
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * This function currently just returns success, provided the device
127  * has been initialised.
128  */
129 static int
130 vhost_user_set_owner(struct virtio_net **pdev __rte_unused,
131                 struct VhostUserMsg *msg __rte_unused)
132 {
133         return VH_RESULT_OK;
134 }
135
136 static int
137 vhost_user_reset_owner(struct virtio_net **pdev,
138                 struct VhostUserMsg *msg __rte_unused)
139 {
140         struct virtio_net *dev = *pdev;
141         vhost_destroy_device_notify(dev);
142
143         cleanup_device(dev, 0);
144         reset_device(dev);
145         return VH_RESULT_OK;
146 }
147
148 /*
149  * The features that we support are requested.
150  */
151 static int
152 vhost_user_get_features(struct virtio_net **pdev, struct VhostUserMsg *msg)
153 {
154         struct virtio_net *dev = *pdev;
155         uint64_t features = 0;
156
157         rte_vhost_driver_get_features(dev->ifname, &features);
158
159         msg->payload.u64 = features;
160         msg->size = sizeof(msg->payload.u64);
161
162         return VH_RESULT_REPLY;
163 }
164
165 /*
166  * The number of queues that we support is requested.
167  */
168 static int
169 vhost_user_get_queue_num(struct virtio_net **pdev, struct VhostUserMsg *msg)
170 {
171         struct virtio_net *dev = *pdev;
172         uint32_t queue_num = 0;
173
174         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
175
176         msg->payload.u64 = (uint64_t)queue_num;
177         msg->size = sizeof(msg->payload.u64);
178
179         return VH_RESULT_REPLY;
180 }
181
182 /*
183  * We receive the negotiated features supported by us and the virtio device.
184  */
185 static int
186 vhost_user_set_features(struct virtio_net **pdev, struct VhostUserMsg *msg)
187 {
188         struct virtio_net *dev = *pdev;
189         uint64_t features = msg->payload.u64;
190         uint64_t vhost_features = 0;
191         struct rte_vdpa_device *vdpa_dev;
192         int did = -1;
193
194         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
195         if (features & ~vhost_features) {
196                 RTE_LOG(ERR, VHOST_CONFIG,
197                         "(%d) received invalid negotiated features.\n",
198                         dev->vid);
199                 return VH_RESULT_ERR;
200         }
201
202         if (dev->flags & VIRTIO_DEV_RUNNING) {
203                 if (dev->features == features)
204                         return VH_RESULT_OK;
205
206                 /*
207                  * Error out if master tries to change features while device is
208                  * in running state. The exception being VHOST_F_LOG_ALL, which
209                  * is enabled when the live-migration starts.
210                  */
211                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
212                         RTE_LOG(ERR, VHOST_CONFIG,
213                                 "(%d) features changed while device is running.\n",
214                                 dev->vid);
215                         return VH_RESULT_ERR;
216                 }
217
218                 if (dev->notify_ops->features_changed)
219                         dev->notify_ops->features_changed(dev->vid, features);
220         }
221
222         dev->features = features;
223         if (dev->features &
224                 ((1ULL << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
225                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
226         } else {
227                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
228         }
229         VHOST_LOG_DEBUG(VHOST_CONFIG,
230                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
231                 dev->vid,
232                 (dev->features & (1ULL << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
233                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
234
235         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
236             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
237                 /*
238                  * Remove all but first queue pair if MQ hasn't been
239                  * negotiated. This is safe because the device is not
240                  * running at this stage.
241                  */
242                 while (dev->nr_vring > 2) {
243                         struct vhost_virtqueue *vq;
244
245                         vq = dev->virtqueue[--dev->nr_vring];
246                         if (!vq)
247                                 continue;
248
249                         dev->virtqueue[dev->nr_vring] = NULL;
250                         cleanup_vq(vq, 1);
251                         free_vq(dev, vq);
252                 }
253         }
254
255         did = dev->vdpa_dev_id;
256         vdpa_dev = rte_vdpa_get_device(did);
257         if (vdpa_dev && vdpa_dev->ops->set_features)
258                 vdpa_dev->ops->set_features(dev->vid);
259
260         return VH_RESULT_OK;
261 }
262
263 /*
264  * The virtio device sends us the size of the descriptor ring.
265  */
266 static int
267 vhost_user_set_vring_num(struct virtio_net **pdev,
268                          struct VhostUserMsg *msg)
269 {
270         struct virtio_net *dev = *pdev;
271         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
272
273         vq->size = msg->payload.state.num;
274
275         /* VIRTIO 1.0, section 2.4 Virtqueues, says:
276          *
277          *   Queue Size value is always a power of 2. The maximum Queue Size
278          *   value is 32768.
279          */
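            /*
             * A power of two shares no bits with (itself - 1), so a non-zero
             * result of the AND below means the size is not a power of two.
             */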
280         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
281                 RTE_LOG(ERR, VHOST_CONFIG,
282                         "invalid virtqueue size %u\n", vq->size);
283                 return VH_RESULT_ERR;
284         }
285
286         if (dev->dequeue_zero_copy) {
287                 vq->nr_zmbuf = 0;
288                 vq->last_zmbuf_idx = 0;
289                 vq->zmbuf_size = vq->size;
290                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
291                                          sizeof(struct zcopy_mbuf), 0);
292                 if (vq->zmbufs == NULL) {
293                         RTE_LOG(WARNING, VHOST_CONFIG,
294                                 "failed to allocate mem for zero copy; "
295                                 "zero copy is force disabled\n");
296                         dev->dequeue_zero_copy = 0;
297                 }
298                 TAILQ_INIT(&vq->zmbuf_list);
299         }
300
301         if (vq_is_packed(dev)) {
302                 vq->shadow_used_packed = rte_malloc(NULL,
303                                 vq->size *
304                                 sizeof(struct vring_used_elem_packed),
305                                 RTE_CACHE_LINE_SIZE);
306                 if (!vq->shadow_used_packed) {
307                         RTE_LOG(ERR, VHOST_CONFIG,
308                                         "failed to allocate memory for shadow used ring.\n");
309                         return VH_RESULT_ERR;
310                 }
311
312         } else {
313                 vq->shadow_used_split = rte_malloc(NULL,
314                                 vq->size * sizeof(struct vring_used_elem),
315                                 RTE_CACHE_LINE_SIZE);
316                 if (!vq->shadow_used_split) {
317                         RTE_LOG(ERR, VHOST_CONFIG,
318                                         "failed to allocate memory for shadow used ring.\n");
319                         return VH_RESULT_ERR;
320                 }
321         }
322
323         vq->batch_copy_elems = rte_malloc(NULL,
324                                 vq->size * sizeof(struct batch_copy_elem),
325                                 RTE_CACHE_LINE_SIZE);
326         if (!vq->batch_copy_elems) {
327                 RTE_LOG(ERR, VHOST_CONFIG,
328                         "failed to allocate memory for batching copy.\n");
329                 return VH_RESULT_ERR;
330         }
331
332         return VH_RESULT_OK;
333 }
334
335 /*
336  * Reallocate the virtio_dev and vhost_virtqueue data structures so that they
337  * reside on the same NUMA node as the memory backing the vring descriptors.
338  */
339 #ifdef RTE_LIBRTE_VHOST_NUMA
340 static struct virtio_net*
341 numa_realloc(struct virtio_net *dev, int index)
342 {
343         int oldnode, newnode;
344         struct virtio_net *old_dev;
345         struct vhost_virtqueue *old_vq, *vq;
346         struct zcopy_mbuf *new_zmbuf;
347         struct vring_used_elem *new_shadow_used_split;
348         struct vring_used_elem_packed *new_shadow_used_packed;
349         struct batch_copy_elem *new_batch_copy_elems;
350         int ret;
351
352         old_dev = dev;
353         vq = old_vq = dev->virtqueue[index];
354
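            /*
             * With MPOL_F_NODE | MPOL_F_ADDR, get_mempolicy() returns the NUMA
             * node backing the given address: first the node holding the guest
             * descriptor ring, then the node holding the virtqueue structure.
             */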
355         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
356                             MPOL_F_NODE | MPOL_F_ADDR);
357
358         /* check if we need to reallocate vq */
359         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
360                              MPOL_F_NODE | MPOL_F_ADDR);
361         if (ret) {
362                 RTE_LOG(ERR, VHOST_CONFIG,
363                         "Unable to get vq numa information.\n");
364                 return dev;
365         }
366         if (oldnode != newnode) {
367                 RTE_LOG(INFO, VHOST_CONFIG,
368                         "reallocate vq from %d to %d node\n", oldnode, newnode);
369                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
370                 if (!vq)
371                         return dev;
372
373                 memcpy(vq, old_vq, sizeof(*vq));
374                 TAILQ_INIT(&vq->zmbuf_list);
375
376                 if (dev->dequeue_zero_copy) {
377                         new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
378                                         sizeof(struct zcopy_mbuf), 0, newnode);
379                         if (new_zmbuf) {
380                                 rte_free(vq->zmbufs);
381                                 vq->zmbufs = new_zmbuf;
382                         }
383                 }
384
385                 if (vq_is_packed(dev)) {
386                         new_shadow_used_packed = rte_malloc_socket(NULL,
387                                         vq->size *
388                                         sizeof(struct vring_used_elem_packed),
389                                         RTE_CACHE_LINE_SIZE,
390                                         newnode);
391                         if (new_shadow_used_packed) {
392                                 rte_free(vq->shadow_used_packed);
393                                 vq->shadow_used_packed = new_shadow_used_packed;
394                         }
395                 } else {
396                         new_shadow_used_split = rte_malloc_socket(NULL,
397                                         vq->size *
398                                         sizeof(struct vring_used_elem),
399                                         RTE_CACHE_LINE_SIZE,
400                                         newnode);
401                         if (new_shadow_used_split) {
402                                 rte_free(vq->shadow_used_split);
403                                 vq->shadow_used_split = new_shadow_used_split;
404                         }
405                 }
406
407                 new_batch_copy_elems = rte_malloc_socket(NULL,
408                         vq->size * sizeof(struct batch_copy_elem),
409                         RTE_CACHE_LINE_SIZE,
410                         newnode);
411                 if (new_batch_copy_elems) {
412                         rte_free(vq->batch_copy_elems);
413                         vq->batch_copy_elems = new_batch_copy_elems;
414                 }
415
416                 rte_free(old_vq);
417         }
418
419         /* check if we need to reallocate dev */
420         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
421                             MPOL_F_NODE | MPOL_F_ADDR);
422         if (ret) {
423                 RTE_LOG(ERR, VHOST_CONFIG,
424                         "Unable to get dev numa information.\n");
425                 goto out;
426         }
427         if (oldnode != newnode) {
428                 RTE_LOG(INFO, VHOST_CONFIG,
429                         "reallocate dev from %d to %d node\n",
430                         oldnode, newnode);
431                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
432                 if (!dev) {
433                         dev = old_dev;
434                         goto out;
435                 }
436
437                 memcpy(dev, old_dev, sizeof(*dev));
438                 rte_free(old_dev);
439         }
440
441 out:
442         dev->virtqueue[index] = vq;
443         vhost_devices[dev->vid] = dev;
444
445         if (old_vq != vq)
446                 vhost_user_iotlb_init(dev, index);
447
448         return dev;
449 }
450 #else
451 static struct virtio_net*
452 numa_realloc(struct virtio_net *dev, int index __rte_unused)
453 {
454         return dev;
455 }
456 #endif
457
458 /* Converts QEMU virtual address to Vhost virtual address. */
459 static uint64_t
460 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
461 {
462         struct rte_vhost_mem_region *r;
463         uint32_t i;
464
465         /* Find the region where the address lives. */
466         for (i = 0; i < dev->mem->nregions; i++) {
467                 r = &dev->mem->regions[i];
468
469                 if (qva >= r->guest_user_addr &&
470                     qva <  r->guest_user_addr + r->size) {
471
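                            /*
                             * Clamp *len so the translated range does not run
                             * past the end of this memory region.
                             */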
472                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
473                                 *len = r->guest_user_addr + r->size - qva;
474
475                         return qva - r->guest_user_addr +
476                                r->host_user_addr;
477                 }
478         }
479         *len = 0;
480
481         return 0;
482 }
483
484
485 /*
486  * Converts ring address to Vhost virtual address.
487  * If IOMMU is enabled, the ring address is a guest IO virtual address,
488  * else it is a QEMU virtual address.
489  */
490 static uint64_t
491 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
492                 uint64_t ra, uint64_t *size)
493 {
494         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
495                 uint64_t vva;
496
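                    /*
                     * Look up the address in the per-virtqueue IOTLB cache. On a
                     * miss, request the translation from the frontend and return
                     * 0 so the caller can retry once the entry is available.
                     */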
497                 vva = vhost_user_iotlb_cache_find(vq, ra,
498                                         size, VHOST_ACCESS_RW);
499                 if (!vva)
500                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
501
502                 return vva;
503         }
504
505         return qva_to_vva(dev, ra, size);
506 }
507
508 static struct virtio_net *
509 translate_ring_addresses(struct virtio_net *dev, int vq_index)
510 {
511         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
512         struct vhost_vring_addr *addr = &vq->ring_addrs;
513         uint64_t len;
514
515         if (vq_is_packed(dev)) {
516                 len = sizeof(struct vring_packed_desc) * vq->size;
517                 vq->desc_packed = (struct vring_packed_desc *)(uintptr_t)
518                         ring_addr_to_vva(dev, vq, addr->desc_user_addr, &len);
519                 vq->log_guest_addr = 0;
520                 if (vq->desc_packed == NULL ||
521                                 len != sizeof(struct vring_packed_desc) *
522                                 vq->size) {
523                         RTE_LOG(DEBUG, VHOST_CONFIG,
524                                 "(%d) failed to map desc_packed ring.\n",
525                                 dev->vid);
526                         return dev;
527                 }
528
529                 dev = numa_realloc(dev, vq_index);
530                 vq = dev->virtqueue[vq_index];
531                 addr = &vq->ring_addrs;
532
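                    /*
                     * For packed rings, the avail/used address fields carry the
                     * driver and device event suppression areas rather than
                     * split avail/used rings.
                     */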
533                 len = sizeof(struct vring_packed_desc_event);
534                 vq->driver_event = (struct vring_packed_desc_event *)
535                                         (uintptr_t)ring_addr_to_vva(dev,
536                                         vq, addr->avail_user_addr, &len);
537                 if (vq->driver_event == NULL ||
538                                 len != sizeof(struct vring_packed_desc_event)) {
539                         RTE_LOG(DEBUG, VHOST_CONFIG,
540                                 "(%d) failed to find driver area address.\n",
541                                 dev->vid);
542                         return dev;
543                 }
544
545                 len = sizeof(struct vring_packed_desc_event);
546                 vq->device_event = (struct vring_packed_desc_event *)
547                                         (uintptr_t)ring_addr_to_vva(dev,
548                                         vq, addr->used_user_addr, &len);
549                 if (vq->device_event == NULL ||
550                                 len != sizeof(struct vring_packed_desc_event)) {
551                         RTE_LOG(DEBUG, VHOST_CONFIG,
552                                 "(%d) failed to find device area address.\n",
553                                 dev->vid);
554                         return dev;
555                 }
556
557                 return dev;
558         }
559
560         /* The addresses are converted from QEMU virtual to Vhost virtual. */
561         if (vq->desc && vq->avail && vq->used)
562                 return dev;
563
564         len = sizeof(struct vring_desc) * vq->size;
565         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
566                         vq, addr->desc_user_addr, &len);
567         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
568                 RTE_LOG(DEBUG, VHOST_CONFIG,
569                         "(%d) failed to map desc ring.\n",
570                         dev->vid);
571                 return dev;
572         }
573
574         dev = numa_realloc(dev, vq_index);
575         vq = dev->virtqueue[vq_index];
576         addr = &vq->ring_addrs;
577
578         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
579         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
580                         vq, addr->avail_user_addr, &len);
581         if (vq->avail == 0 ||
582                         len != sizeof(struct vring_avail) +
583                         sizeof(uint16_t) * vq->size) {
584                 RTE_LOG(DEBUG, VHOST_CONFIG,
585                         "(%d) failed to map avail ring.\n",
586                         dev->vid);
587                 return dev;
588         }
589
590         len = sizeof(struct vring_used) +
591                 sizeof(struct vring_used_elem) * vq->size;
592         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
593                         vq, addr->used_user_addr, &len);
594         if (vq->used == 0 || len != sizeof(struct vring_used) +
595                         sizeof(struct vring_used_elem) * vq->size) {
596                 RTE_LOG(DEBUG, VHOST_CONFIG,
597                         "(%d) failed to map used ring.\n",
598                         dev->vid);
599                 return dev;
600         }
601
602         if (vq->last_used_idx != vq->used->idx) {
603                 RTE_LOG(WARNING, VHOST_CONFIG,
604                         "last_used_idx (%u) and vq->used->idx (%u) mismatch; "
605                         "some packets may be resent for Tx and dropped for Rx\n",
606                         vq->last_used_idx, vq->used->idx);
607                 vq->last_used_idx  = vq->used->idx;
608                 vq->last_avail_idx = vq->used->idx;
609         }
610
611         vq->log_guest_addr = addr->log_guest_addr;
612
613         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
614                         dev->vid, vq->desc);
615         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
616                         dev->vid, vq->avail);
617         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
618                         dev->vid, vq->used);
619         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
620                         dev->vid, vq->log_guest_addr);
621
622         return dev;
623 }
624
625 /*
626  * The virtio device sends us the desc, used and avail ring addresses.
627  * This function then converts these to our address space.
628  */
629 static int
630 vhost_user_set_vring_addr(struct virtio_net **pdev, struct VhostUserMsg *msg)
631 {
632         struct virtio_net *dev = *pdev;
633         struct vhost_virtqueue *vq;
634         struct vhost_vring_addr *addr = &msg->payload.addr;
635
636         if (dev->mem == NULL)
637                 return VH_RESULT_ERR;
638
639         /* addr->index refers to the queue index. The txq is 1, the rxq is 0. */
640         vq = dev->virtqueue[msg->payload.addr.index];
641
642         /*
643          * Ring addresses should not be interpreted as long as the ring is not
644          * started and enabled
645          */
646         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
647
648         vring_invalidate(dev, vq);
649
650         if (vq->enabled && (dev->features &
651                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
652                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
653                 if (!dev)
654                         return VH_RESULT_ERR;
655
656                 *pdev = dev;
657         }
658
659         return VH_RESULT_OK;
660 }
661
662 /*
663  * The virtio device sends us the last used index of the available ring.
664  */
665 static int
666 vhost_user_set_vring_base(struct virtio_net **pdev,
667                           struct VhostUserMsg *msg)
668 {
669         struct virtio_net *dev = *pdev;
670         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
671                         msg->payload.state.num;
672         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
673                         msg->payload.state.num;
674
675         return VH_RESULT_OK;
676 }
677
678 static int
679 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
680                    uint64_t host_phys_addr, uint64_t size)
681 {
682         struct guest_page *page, *last_page;
683
684         if (dev->nr_guest_pages == dev->max_guest_pages) {
685                 dev->max_guest_pages *= 2;
686                 dev->guest_pages = realloc(dev->guest_pages,
687                                         dev->max_guest_pages * sizeof(*page));
688                 if (!dev->guest_pages) {
689                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
690                         return -1;
691                 }
692         }
693
694         if (dev->nr_guest_pages > 0) {
695                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
696                 /* merge if the two pages are contiguous */
697                 if (host_phys_addr == last_page->host_phys_addr +
698                                       last_page->size) {
699                         last_page->size += size;
700                         return 0;
701                 }
702         }
703
704         page = &dev->guest_pages[dev->nr_guest_pages++];
705         page->guest_phys_addr = guest_phys_addr;
706         page->host_phys_addr  = host_phys_addr;
707         page->size = size;
708
709         return 0;
710 }
711
712 static int
713 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
714                 uint64_t page_size)
715 {
716         uint64_t reg_size = reg->size;
717         uint64_t host_user_addr  = reg->host_user_addr;
718         uint64_t guest_phys_addr = reg->guest_phys_addr;
719         uint64_t host_phys_addr;
720         uint64_t size;
721
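            /*
             * The first chunk covers the range from the region start up to the
             * next page boundary; the loop below then walks the rest of the
             * region one page at a time, recording host-physical addresses.
             */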
722         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
723         size = page_size - (guest_phys_addr & (page_size - 1));
724         size = RTE_MIN(size, reg_size);
725
726         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
727                 return -1;
728
729         host_user_addr  += size;
730         guest_phys_addr += size;
731         reg_size -= size;
732
733         while (reg_size > 0) {
734                 size = RTE_MIN(reg_size, page_size);
735                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
736                                                   host_user_addr);
737                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
738                                 size) < 0)
739                         return -1;
740
741                 host_user_addr  += size;
742                 guest_phys_addr += size;
743                 reg_size -= size;
744         }
745
746         return 0;
747 }
748
749 #ifdef RTE_LIBRTE_VHOST_DEBUG
750 /* TODO: enable it only in debug mode? */
751 static void
752 dump_guest_pages(struct virtio_net *dev)
753 {
754         uint32_t i;
755         struct guest_page *page;
756
757         for (i = 0; i < dev->nr_guest_pages; i++) {
758                 page = &dev->guest_pages[i];
759
760                 RTE_LOG(INFO, VHOST_CONFIG,
761                         "guest physical page region %u\n"
762                         "\t guest_phys_addr: %" PRIx64 "\n"
763                         "\t host_phys_addr : %" PRIx64 "\n"
764                         "\t size           : %" PRIx64 "\n",
765                         i,
766                         page->guest_phys_addr,
767                         page->host_phys_addr,
768                         page->size);
769         }
770 }
771 #else
772 #define dump_guest_pages(dev)
773 #endif
774
775 static bool
776 vhost_memory_changed(struct VhostUserMemory *new,
777                      struct rte_vhost_memory *old)
778 {
779         uint32_t i;
780
781         if (new->nregions != old->nregions)
782                 return true;
783
784         for (i = 0; i < new->nregions; ++i) {
785                 VhostUserMemoryRegion *new_r = &new->regions[i];
786                 struct rte_vhost_mem_region *old_r = &old->regions[i];
787
788                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
789                         return true;
790                 if (new_r->memory_size != old_r->size)
791                         return true;
792                 if (new_r->userspace_addr != old_r->guest_user_addr)
793                         return true;
794         }
795
796         return false;
797 }
798
799 static int
800 vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *msg)
801 {
802         struct virtio_net *dev = *pdev;
803         struct VhostUserMemory memory = msg->payload.memory;
804         struct rte_vhost_mem_region *reg;
805         void *mmap_addr;
806         uint64_t mmap_size;
807         uint64_t mmap_offset;
808         uint64_t alignment;
809         uint32_t i;
810         int populate;
811         int fd;
812
813         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
814                 RTE_LOG(ERR, VHOST_CONFIG,
815                         "too many memory regions (%u)\n", memory.nregions);
816                 return VH_RESULT_ERR;
817         }
818
819         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
820                 RTE_LOG(INFO, VHOST_CONFIG,
821                         "(%d) memory regions not changed\n", dev->vid);
822
823                 for (i = 0; i < memory.nregions; i++)
824                         close(msg->fds[i]);
825
826                 return VH_RESULT_OK;
827         }
828
829         if (dev->mem) {
830                 free_mem_region(dev);
831                 rte_free(dev->mem);
832                 dev->mem = NULL;
833         }
834
835         /* Flush IOTLB cache as previous HVAs are now invalid */
836         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))
837                 for (i = 0; i < dev->nr_vring; i++)
838                         vhost_user_iotlb_flush_all(dev->virtqueue[i]);
839
840         dev->nr_guest_pages = 0;
841         if (!dev->guest_pages) {
842                 dev->max_guest_pages = 8;
843                 dev->guest_pages = malloc(dev->max_guest_pages *
844                                                 sizeof(struct guest_page));
845                 if (dev->guest_pages == NULL) {
846                         RTE_LOG(ERR, VHOST_CONFIG,
847                                 "(%d) failed to allocate memory "
848                                 "for dev->guest_pages\n",
849                                 dev->vid);
850                         return VH_RESULT_ERR;
851                 }
852         }
853
854         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
855                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
856         if (dev->mem == NULL) {
857                 RTE_LOG(ERR, VHOST_CONFIG,
858                         "(%d) failed to allocate memory for dev->mem\n",
859                         dev->vid);
860                 return VH_RESULT_ERR;
861         }
862         dev->mem->nregions = memory.nregions;
863
864         for (i = 0; i < memory.nregions; i++) {
865                 fd  = msg->fds[i];
866                 reg = &dev->mem->regions[i];
867
868                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
869                 reg->guest_user_addr = memory.regions[i].userspace_addr;
870                 reg->size            = memory.regions[i].memory_size;
871                 reg->fd              = fd;
872
873                 mmap_offset = memory.regions[i].mmap_offset;
874
875                 /* Check for memory_size + mmap_offset overflow */
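                    /*
                     * Both values are uint64_t, so -reg->size wraps to
                     * (2^64 - reg->size); the check therefore catches any
                     * mmap_offset for which mmap_offset + reg->size would
                     * wrap past 2^64.
                     */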
876                 if (mmap_offset >= -reg->size) {
877                         RTE_LOG(ERR, VHOST_CONFIG,
878                                 "mmap_offset (%#"PRIx64") and memory_size "
879                                 "(%#"PRIx64") overflow\n",
880                                 mmap_offset, reg->size);
881                         goto err_mmap;
882                 }
883
884                 mmap_size = reg->size + mmap_offset;
885
886                 /* On older long-term Linux kernels (e.g. 2.6.32 and 3.2.72),
887                  * mmap() without the MAP_ANONYMOUS flag must be called with a
888                  * length argument aligned to the hugepage size, or mmap()
889                  * will fail with EINVAL.
890                  *
891                  * To avoid that failure, make sure the length passed here is
892                  * kept aligned.
893                  */
894                 alignment = get_blk_size(fd);
895                 if (alignment == (uint64_t)-1) {
896                         RTE_LOG(ERR, VHOST_CONFIG,
897                                 "couldn't get hugepage size through fstat\n");
898                         goto err_mmap;
899                 }
900                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
901
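                    /*
                     * With dequeue zero-copy, MAP_POPULATE pre-faults the guest
                     * memory so that the guest-physical to host-physical mapping
                     * built by add_guest_pages() below refers to resident pages.
                     */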
902                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
903                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
904                                  MAP_SHARED | populate, fd, 0);
905
906                 if (mmap_addr == MAP_FAILED) {
907                         RTE_LOG(ERR, VHOST_CONFIG,
908                                 "mmap region %u failed.\n", i);
909                         goto err_mmap;
910                 }
911
912                 reg->mmap_addr = mmap_addr;
913                 reg->mmap_size = mmap_size;
914                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
915                                       mmap_offset;
916
917                 if (dev->dequeue_zero_copy)
918                         if (add_guest_pages(dev, reg, alignment) < 0) {
919                                 RTE_LOG(ERR, VHOST_CONFIG,
920                                         "adding guest pages to region %u failed.\n",
921                                         i);
922                                 goto err_mmap;
923                         }
924
925                 RTE_LOG(INFO, VHOST_CONFIG,
926                         "guest memory region %u, size: 0x%" PRIx64 "\n"
927                         "\t guest physical addr: 0x%" PRIx64 "\n"
928                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
929                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
930                         "\t mmap addr : 0x%" PRIx64 "\n"
931                         "\t mmap size : 0x%" PRIx64 "\n"
932                         "\t mmap align: 0x%" PRIx64 "\n"
933                         "\t mmap off  : 0x%" PRIx64 "\n",
934                         i, reg->size,
935                         reg->guest_phys_addr,
936                         reg->guest_user_addr,
937                         reg->host_user_addr,
938                         (uint64_t)(uintptr_t)mmap_addr,
939                         mmap_size,
940                         alignment,
941                         mmap_offset);
942         }
943
944         for (i = 0; i < dev->nr_vring; i++) {
945                 struct vhost_virtqueue *vq = dev->virtqueue[i];
946
947                 if (vq->desc || vq->avail || vq->used) {
948                         /*
949                          * If the memory table got updated, the ring addresses
950                          * need to be translated again as virtual addresses have
951                          * changed.
952                          */
953                         vring_invalidate(dev, vq);
954
955                         dev = translate_ring_addresses(dev, i);
956                         if (!dev)
957                                 return VH_RESULT_ERR;
958
959                         *pdev = dev;
960                 }
961         }
962
963         dump_guest_pages(dev);
964
965         return VH_RESULT_OK;
966
967 err_mmap:
968         free_mem_region(dev);
969         rte_free(dev->mem);
970         dev->mem = NULL;
971         return VH_RESULT_ERR;
972 }
973
974 static bool
975 vq_is_ready(struct virtio_net *dev, struct vhost_virtqueue *vq)
976 {
977         bool rings_ok;
978
979         if (!vq)
980                 return false;
981
982         if (vq_is_packed(dev))
983                 rings_ok = !!vq->desc_packed;
984         else
985                 rings_ok = vq->desc && vq->avail && vq->used;
986
987         return rings_ok &&
988                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
989                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
990 }
991
992 static int
993 virtio_is_ready(struct virtio_net *dev)
994 {
995         struct vhost_virtqueue *vq;
996         uint32_t i;
997
998         if (dev->nr_vring == 0)
999                 return 0;
1000
1001         for (i = 0; i < dev->nr_vring; i++) {
1002                 vq = dev->virtqueue[i];
1003
1004                 if (!vq_is_ready(dev, vq))
1005                         return 0;
1006         }
1007
1008         RTE_LOG(INFO, VHOST_CONFIG,
1009                 "virtio is now ready for processing.\n");
1010         return 1;
1011 }
1012
1013 static int
1014 vhost_user_set_vring_call(struct virtio_net **pdev, struct VhostUserMsg *msg)
1015 {
1016         struct virtio_net *dev = *pdev;
1017         struct vhost_vring_file file;
1018         struct vhost_virtqueue *vq;
1019
1020         file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1021         if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
1022                 file.fd = VIRTIO_INVALID_EVENTFD;
1023         else
1024                 file.fd = msg->fds[0];
1025         RTE_LOG(INFO, VHOST_CONFIG,
1026                 "vring call idx:%d file:%d\n", file.index, file.fd);
1027
1028         vq = dev->virtqueue[file.index];
1029         if (vq->callfd >= 0)
1030                 close(vq->callfd);
1031
1032         vq->callfd = file.fd;
1033
1034         return VH_RESULT_OK;
1035 }
1036
1037 static int vhost_user_set_vring_err(struct virtio_net **pdev __rte_unused,
1038                         struct VhostUserMsg *msg)
1039 {
1040         if (!(msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1041                 close(msg->fds[0]);
1042         RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1043
1044         return VH_RESULT_OK;
1045 }
1046
1047 static int
1048 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *msg)
1049 {
1050         struct virtio_net *dev = *pdev;
1051         struct vhost_vring_file file;
1052         struct vhost_virtqueue *vq;
1053
1054         file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1055         if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
1056                 file.fd = VIRTIO_INVALID_EVENTFD;
1057         else
1058                 file.fd = msg->fds[0];
1059         RTE_LOG(INFO, VHOST_CONFIG,
1060                 "vring kick idx:%d file:%d\n", file.index, file.fd);
1061
1062         /* Interpret ring addresses only when ring is started. */
1063         dev = translate_ring_addresses(dev, file.index);
1064         if (!dev)
1065                 return VH_RESULT_ERR;
1066
1067         *pdev = dev;
1068
1069         vq = dev->virtqueue[file.index];
1070
1071         /*
1072          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
1073          * the ring starts already enabled. Otherwise, it is enabled via
1074          * the SET_VRING_ENABLE message.
1075          */
1076         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
1077                 vq->enabled = 1;
1078
1079         if (vq->kickfd >= 0)
1080                 close(vq->kickfd);
1081         vq->kickfd = file.fd;
1082
1083         return VH_RESULT_OK;
1084 }
1085
1086 static void
1087 free_zmbufs(struct vhost_virtqueue *vq)
1088 {
1089         struct zcopy_mbuf *zmbuf, *next;
1090
1091         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
1092              zmbuf != NULL; zmbuf = next) {
1093                 next = TAILQ_NEXT(zmbuf, next);
1094
1095                 rte_pktmbuf_free(zmbuf->mbuf);
1096                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
1097         }
1098
1099         rte_free(vq->zmbufs);
1100 }
1101
1102 /*
1103  * when virtio is stopped, qemu will send us the GET_VRING_BASE message.
1104  */
1105 static int
1106 vhost_user_get_vring_base(struct virtio_net **pdev,
1107                           struct VhostUserMsg *msg)
1108 {
1109         struct virtio_net *dev = *pdev;
1110         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
1111
1112         /* We have to stop the queue (virtio) if it is running. */
1113         vhost_destroy_device_notify(dev);
1114
1115         dev->flags &= ~VIRTIO_DEV_READY;
1116         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
1117
1118         /* Here we are safe to get the last avail index */
1119         msg->payload.state.num = vq->last_avail_idx;
1120
1121         RTE_LOG(INFO, VHOST_CONFIG,
1122                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1123                 msg->payload.state.num);
1124         /*
1125          * In the current QEMU vhost-user implementation, this message is
1126          * sent only from vhost_vring_stop.
1127          * TODO: clean up the vring; it isn't usable from this point on.
1128          */
1129         if (vq->kickfd >= 0)
1130                 close(vq->kickfd);
1131
1132         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1133
1134         if (vq->callfd >= 0)
1135                 close(vq->callfd);
1136
1137         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1138
1139         if (dev->dequeue_zero_copy)
1140                 free_zmbufs(vq);
1141         if (vq_is_packed(dev)) {
1142                 rte_free(vq->shadow_used_packed);
1143                 vq->shadow_used_packed = NULL;
1144         } else {
1145                 rte_free(vq->shadow_used_split);
1146                 vq->shadow_used_split = NULL;
1147         }
1148
1149         rte_free(vq->batch_copy_elems);
1150         vq->batch_copy_elems = NULL;
1151
1152         msg->size = sizeof(msg->payload.state);
1153
1154         return VH_RESULT_REPLY;
1155 }
1156
1157 /*
1158  * When the virtio queues are ready to work, QEMU sends us a message to
1159  * enable the virtio queue pair.
1160  */
1161 static int
1162 vhost_user_set_vring_enable(struct virtio_net **pdev,
1163                             struct VhostUserMsg *msg)
1164 {
1165         struct virtio_net *dev = *pdev;
1166         int enable = (int)msg->payload.state.num;
1167         int index = (int)msg->payload.state.index;
1168         struct rte_vdpa_device *vdpa_dev;
1169         int did = -1;
1170
1171         RTE_LOG(INFO, VHOST_CONFIG,
1172                 "set queue enable: %d to qp idx: %d\n",
1173                 enable, index);
1174
1175         did = dev->vdpa_dev_id;
1176         vdpa_dev = rte_vdpa_get_device(did);
1177         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1178                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1179
1180         if (dev->notify_ops->vring_state_changed)
1181                 dev->notify_ops->vring_state_changed(dev->vid,
1182                                 index, enable);
1183
1184         dev->virtqueue[index]->enabled = enable;
1185
1186         return VH_RESULT_OK;
1187 }
1188
1189 static int
1190 vhost_user_get_protocol_features(struct virtio_net **pdev,
1191                                  struct VhostUserMsg *msg)
1192 {
1193         struct virtio_net *dev = *pdev;
1194         uint64_t features, protocol_features;
1195
1196         rte_vhost_driver_get_features(dev->ifname, &features);
1197         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1198
1199         /*
1200          * For now, the REPLY_ACK protocol feature is mandatory only for
1201          * the IOMMU feature. If IOMMU is explicitly disabled by the
1202          * application, also disable the REPLY_ACK feature to cope with
1203          * older buggy QEMU versions (from v2.7.0 to v2.9.0).
1204          */
1205         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1206                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1207
1208         msg->payload.u64 = protocol_features;
1209         msg->size = sizeof(msg->payload.u64);
1210
1211         return VH_RESULT_REPLY;
1212 }
1213
1214 static int
1215 vhost_user_set_protocol_features(struct virtio_net **pdev,
1216                                  struct VhostUserMsg *msg)
1217 {
1218         struct virtio_net *dev = *pdev;
1219         uint64_t protocol_features = msg->payload.u64;
1220         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES) {
1221                 RTE_LOG(ERR, VHOST_CONFIG,
1222                         "(%d) received invalid protocol features.\n",
1223                         dev->vid);
1224                 return VH_RESULT_ERR;
1225         }
1226
1227         dev->protocol_features = protocol_features;
1228
1229         return VH_RESULT_OK;
1230 }
1231
1232 static int
1233 vhost_user_set_log_base(struct virtio_net **pdev, struct VhostUserMsg *msg)
1234 {
1235         struct virtio_net *dev = *pdev;
1236         int fd = msg->fds[0];
1237         uint64_t size, off;
1238         void *addr;
1239
1240         if (fd < 0) {
1241                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1242                 return VH_RESULT_ERR;
1243         }
1244
1245         if (msg->size != sizeof(VhostUserLog)) {
1246                 RTE_LOG(ERR, VHOST_CONFIG,
1247                         "invalid log base msg size: %"PRId32" != %d\n",
1248                         msg->size, (int)sizeof(VhostUserLog));
1249                 return VH_RESULT_ERR;
1250         }
1251
1252         size = msg->payload.log.mmap_size;
1253         off  = msg->payload.log.mmap_offset;
1254
1255         /* Don't allow mmap_offset to point outside the mmap region */
1256         if (off > size) {
1257                 RTE_LOG(ERR, VHOST_CONFIG,
1258                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1259                         off, size);
1260                 return VH_RESULT_ERR;
1261         }
1262
1263         RTE_LOG(INFO, VHOST_CONFIG,
1264                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1265                 size, off);
1266
1267         /*
1268          * mmap from offset 0 to work around a hugepage mmap bug: mmap will
1269          * fail when the offset is not page-size aligned.
1270          */
1271         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1272         close(fd);
1273         if (addr == MAP_FAILED) {
1274                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1275                 return VH_RESULT_ERR;
1276         }
1277
1278         /*
1279          * Free any previously mapped log memory, in case multiple
1280          * VHOST_USER_SET_LOG_BASE messages are received.
1281          */
1282         if (dev->log_addr) {
1283                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1284         }
1285         dev->log_addr = (uint64_t)(uintptr_t)addr;
1286         dev->log_base = dev->log_addr + off;
1287         dev->log_size = size;
1288
1289         msg->size = sizeof(msg->payload.u64);
1290
1291         return VH_RESULT_REPLY;
1292 }
1293
1294 static int vhost_user_set_log_fd(struct virtio_net **pdev __rte_unused,
1295                         struct VhostUserMsg *msg)
1296 {
1297         close(msg->fds[0]);
1298         RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1299
1300         return VH_RESULT_OK;
1301 }
1302
1303 /*
1304  * A RARP packet is constructed and broadcast to notify switches about
1305  * the new location of the migrated VM, so that packets from outside will
1306  * not be lost after migration.
1307  *
1308  * However, we don't actually "send" a RARP packet here; instead, we set
1309  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1310  */
1311 static int
1312 vhost_user_send_rarp(struct virtio_net **pdev, struct VhostUserMsg *msg)
1313 {
1314         struct virtio_net *dev = *pdev;
1315         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1316         struct rte_vdpa_device *vdpa_dev;
1317         int did = -1;
1318
1319         RTE_LOG(DEBUG, VHOST_CONFIG,
1320                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1321                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1322         memcpy(dev->mac.addr_bytes, mac, 6);
1323
1324         /*
1325          * Set the flag to inject a RARP broadcast packet at
1326          * rte_vhost_dequeue_burst().
1327          *
1328          * rte_smp_wmb() is for making sure the mac is copied
1329          * before the flag is set.
1330          */
1331         rte_smp_wmb();
1332         rte_atomic16_set(&dev->broadcast_rarp, 1);
1333         did = dev->vdpa_dev_id;
1334         vdpa_dev = rte_vdpa_get_device(did);
1335         if (vdpa_dev && vdpa_dev->ops->migration_done)
1336                 vdpa_dev->ops->migration_done(dev->vid);
1337
1338         return VH_RESULT_OK;
1339 }
1340
1341 static int
1342 vhost_user_net_set_mtu(struct virtio_net **pdev, struct VhostUserMsg *msg)
1343 {
1344         struct virtio_net *dev = *pdev;
1345         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1346                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1347                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1348                                 msg->payload.u64);
1349
1350                 return VH_RESULT_ERR;
1351         }
1352
1353         dev->mtu = msg->payload.u64;
1354
1355         return VH_RESULT_OK;
1356 }
1357
1358 static int
1359 vhost_user_set_req_fd(struct virtio_net **pdev, struct VhostUserMsg *msg)
1360 {
1361         struct virtio_net *dev = *pdev;
1362         int fd = msg->fds[0];
1363
1364         if (fd < 0) {
1365                 RTE_LOG(ERR, VHOST_CONFIG,
1366                                 "Invalid file descriptor for slave channel (%d)\n",
1367                                 fd);
1368                 return VH_RESULT_ERR;
1369         }
1370
1371         dev->slave_req_fd = fd;
1372
1373         return VH_RESULT_OK;
1374 }
1375
1376 static int
1377 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1378 {
1379         struct vhost_vring_addr *ra;
1380         uint64_t start, end;
1381
1382         start = imsg->iova;
1383         end = start + imsg->size;
1384
1385         ra = &vq->ring_addrs;
1386         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1387                 return 1;
1388         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1389                 return 1;
1390         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1391                 return 1;
1392
1393         return 0;
1394 }
1395
1396 static int
1397 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1398                                 struct vhost_iotlb_msg *imsg)
1399 {
1400         uint64_t istart, iend, vstart, vend;
1401
1402         istart = imsg->iova;
1403         iend = istart + imsg->size - 1;
1404
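             /*
              * Two inclusive ranges [istart, iend] and [vstart, vend] overlap
              * if, and only if, each one starts no later than the other ends.
              */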
1405         vstart = (uintptr_t)vq->desc;
1406         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1407         if (vstart <= iend && istart <= vend)
1408                 return 1;
1409
1410         vstart = (uintptr_t)vq->avail;
1411         vend = vstart + sizeof(struct vring_avail);
1412         vend += sizeof(uint16_t) * vq->size - 1;
1413         if (vstart <= iend && istart <= vend)
1414                 return 1;
1415
1416         vstart = (uintptr_t)vq->used;
1417         vend = vstart + sizeof(struct vring_used);
1418         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1419         if (vstart <= iend && istart <= vend)
1420                 return 1;
1421
1422         return 0;
1423 }
1424
1425 static int
1426 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1427 {
1428         struct virtio_net *dev = *pdev;
1429         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1430         uint16_t i;
1431         uint64_t vva, len;
1432
1433         switch (imsg->type) {
1434         case VHOST_IOTLB_UPDATE:
1435                 len = imsg->size;
1436                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1437                 if (!vva)
1438                         return VH_RESULT_ERR;
1439
1440                 for (i = 0; i < dev->nr_vring; i++) {
1441                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1442
1443                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1444                                         len, imsg->perm);
1445
1446                         if (is_vring_iotlb_update(vq, imsg))
1447                                 *pdev = dev = translate_ring_addresses(dev, i);
1448                 }
1449                 break;
1450         case VHOST_IOTLB_INVALIDATE:
1451                 for (i = 0; i < dev->nr_vring; i++) {
1452                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1453
1454                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1455                                         imsg->size);
1456
1457                         if (is_vring_iotlb_invalidate(vq, imsg))
1458                                 vring_invalidate(dev, vq);
1459                 }
1460                 break;
1461         default:
1462                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1463                                 imsg->type);
1464                 return VH_RESULT_ERR;
1465         }
1466
1467         return VH_RESULT_OK;
1468 }
1469
1470 typedef int (*vhost_message_handler_t)(struct virtio_net **pdev,
1471                                         struct VhostUserMsg *msg);
1472 static vhost_message_handler_t vhost_message_handlers[VHOST_USER_MAX] = {
1473         [VHOST_USER_NONE] = NULL,
1474         [VHOST_USER_GET_FEATURES] = vhost_user_get_features,
1475         [VHOST_USER_SET_FEATURES] = vhost_user_set_features,
1476         [VHOST_USER_SET_OWNER] = vhost_user_set_owner,
1477         [VHOST_USER_RESET_OWNER] = vhost_user_reset_owner,
1478         [VHOST_USER_SET_MEM_TABLE] = vhost_user_set_mem_table,
1479         [VHOST_USER_SET_LOG_BASE] = vhost_user_set_log_base,
1480         [VHOST_USER_SET_LOG_FD] = vhost_user_set_log_fd,
1481         [VHOST_USER_SET_VRING_NUM] = vhost_user_set_vring_num,
1482         [VHOST_USER_SET_VRING_ADDR] = vhost_user_set_vring_addr,
1483         [VHOST_USER_SET_VRING_BASE] = vhost_user_set_vring_base,
1484         [VHOST_USER_GET_VRING_BASE] = vhost_user_get_vring_base,
1485         [VHOST_USER_SET_VRING_KICK] = vhost_user_set_vring_kick,
1486         [VHOST_USER_SET_VRING_CALL] = vhost_user_set_vring_call,
1487         [VHOST_USER_SET_VRING_ERR] = vhost_user_set_vring_err,
1488         [VHOST_USER_GET_PROTOCOL_FEATURES] = vhost_user_get_protocol_features,
1489         [VHOST_USER_SET_PROTOCOL_FEATURES] = vhost_user_set_protocol_features,
1490         [VHOST_USER_GET_QUEUE_NUM] = vhost_user_get_queue_num,
1491         [VHOST_USER_SET_VRING_ENABLE] = vhost_user_set_vring_enable,
1492         [VHOST_USER_SEND_RARP] = vhost_user_send_rarp,
1493         [VHOST_USER_NET_SET_MTU] = vhost_user_net_set_mtu,
1494         [VHOST_USER_SET_SLAVE_REQ_FD] = vhost_user_set_req_fd,
1495         [VHOST_USER_IOTLB_MSG] = vhost_user_iotlb_msg,
1496 };
1497
1498
1499 /* Return the number of bytes read on success, 0 on peer close, or a negative value on failure. */
1500 static int
1501 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1502 {
1503         int ret;
1504
1505         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1506                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1507         if (ret <= 0)
1508                 return ret;
1509
1510         if (msg && msg->size) {
1511                 if (msg->size > sizeof(msg->payload)) {
1512                         RTE_LOG(ERR, VHOST_CONFIG,
1513                                 "invalid msg size: %u\n", msg->size);
1514                         return -1;
1515                 }
1516                 ret = read(sockfd, &msg->payload, msg->size);
1517                 if (ret <= 0)
1518                         return ret;
1519                 if (ret != (int)msg->size) {
1520                         RTE_LOG(ERR, VHOST_CONFIG,
1521                                 "read control message failed\n");
1522                         return -1;
1523                 }
1524         }
1525
1526         return ret;
1527 }
1528
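/* Send a raw vhost-user message (header plus payload) with optional fds. */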
1529 static int
1530 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1531 {
1532         if (!msg)
1533                 return 0;
1534
1535         return send_fd_message(sockfd, (char *)msg,
1536                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1537 }
1538
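/*
 * Turn a received message into a reply: clear the version bits and the
 * need-reply flag, set the current protocol version and the reply flag,
 * then send it back without any file descriptors attached.
 */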
1539 static int
1540 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1541 {
1542         if (!msg)
1543                 return 0;
1544
1545         msg->flags &= ~VHOST_USER_VERSION_MASK;
1546         msg->flags &= ~VHOST_USER_NEED_REPLY;
1547         msg->flags |= VHOST_USER_VERSION;
1548         msg->flags |= VHOST_USER_REPLY_MASK;
1549
1550         return send_vhost_message(sockfd, msg, NULL, 0);
1551 }
1552
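/*
 * Send a request on the slave channel. When a reply is expected, the
 * slave_req_lock is taken here and released by process_slave_message_reply(),
 * so outstanding slave requests are serialized.
 */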
1553 static int
1554 send_vhost_slave_message(struct virtio_net *dev, struct VhostUserMsg *msg,
1555                          int *fds, int fd_num)
1556 {
1557         int ret;
1558
1559         if (msg->flags & VHOST_USER_NEED_REPLY)
1560                 rte_spinlock_lock(&dev->slave_req_lock);
1561
1562         ret = send_vhost_message(dev->slave_req_fd, msg, fds, fd_num);
1563         if (ret < 0 && (msg->flags & VHOST_USER_NEED_REPLY))
1564                 rte_spinlock_unlock(&dev->slave_req_lock);
1565
1566         return ret;
1567 }
1568
1569 /*
1570  * Allocate the virtqueue referenced by the message if it hasn't been allocated yet
1571  */
1572 static int
1573 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev,
1574                         struct VhostUserMsg *msg)
1575 {
1576         uint16_t vring_idx;
1577
1578         switch (msg->request.master) {
1579         case VHOST_USER_SET_VRING_KICK:
1580         case VHOST_USER_SET_VRING_CALL:
1581         case VHOST_USER_SET_VRING_ERR:
1582                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1583                 break;
1584         case VHOST_USER_SET_VRING_NUM:
1585         case VHOST_USER_SET_VRING_BASE:
1586         case VHOST_USER_SET_VRING_ENABLE:
1587                 vring_idx = msg->payload.state.index;
1588                 break;
1589         case VHOST_USER_SET_VRING_ADDR:
1590                 vring_idx = msg->payload.addr.index;
1591                 break;
1592         default:
1593                 return 0;
1594         }
1595
1596         if (vring_idx >= VHOST_MAX_VRING) {
1597                 RTE_LOG(ERR, VHOST_CONFIG,
1598                         "invalid vring index: %u\n", vring_idx);
1599                 return -1;
1600         }
1601
1602         if (dev->virtqueue[vring_idx])
1603                 return 0;
1604
1605         return alloc_vring_queue(dev, vring_idx);
1606 }
1607
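/*
 * Take the access_lock of every allocated virtqueue so the datapath cannot
 * run while a state-changing message is being processed. Released by
 * vhost_user_unlock_all_queue_pairs().
 */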
1608 static void
1609 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1610 {
1611         unsigned int i = 0;
1612         unsigned int vq_num = 0;
1613
1614         while (vq_num < dev->nr_vring) {
1615                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1616
1617                 if (vq) {
1618                         rte_spinlock_lock(&vq->access_lock);
1619                         vq_num++;
1620                 }
1621                 i++;
1622         }
1623 }
1624
1625 static void
1626 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1627 {
1628         unsigned int i = 0;
1629         unsigned int vq_num = 0;
1630
1631         while (vq_num < dev->nr_vring) {
1632                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1633
1634                 if (vq) {
1635                         rte_spinlock_unlock(&vq->access_lock);
1636                         vq_num++;
1637                 }
1638                 i++;
1639         }
1640 }
1641
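/*
 * Read one message from the master socket and dispatch it. Returns 0 on
 * success and -1 on failure. A reply is sent when the handler returns
 * VH_RESULT_REPLY or when the master set VHOST_USER_NEED_REPLY.
 */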
1642 int
1643 vhost_user_msg_handler(int vid, int fd)
1644 {
1645         struct virtio_net *dev;
1646         struct VhostUserMsg msg;
1647         struct rte_vdpa_device *vdpa_dev;
1648         int did = -1;
1649         int ret;
1650         int unlock_required = 0;
1651         uint32_t skip_master = 0;
1652         int request;
1653
1654         dev = get_device(vid);
1655         if (dev == NULL)
1656                 return -1;
1657
1658         if (!dev->notify_ops) {
1659                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1660                 if (!dev->notify_ops) {
1661                         RTE_LOG(ERR, VHOST_CONFIG,
1662                                 "failed to get callback ops for driver %s\n",
1663                                 dev->ifname);
1664                         return -1;
1665                 }
1666         }
1667
1668         ret = read_vhost_message(fd, &msg);
1669         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1670                 if (ret < 0)
1671                         RTE_LOG(ERR, VHOST_CONFIG,
1672                                 "vhost read message failed\n");
1673                 else if (ret == 0)
1674                         RTE_LOG(INFO, VHOST_CONFIG,
1675                                 "vhost peer closed\n");
1676                 else
1677                         RTE_LOG(ERR, VHOST_CONFIG,
1678                                 "vhost read incorrect message\n");
1679
1680                 return -1;
1681         }
1682
1683         ret = 0;
1684         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1685                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1686                         vhost_message_str[msg.request.master]);
1687         else
1688                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1689                         vhost_message_str[msg.request.master]);
1690
1691         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1692         if (ret < 0) {
1693                 RTE_LOG(ERR, VHOST_CONFIG,
1694                         "failed to alloc queue\n");
1695                 return -1;
1696         }
1697
1698         /*
1699          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1700          * and VHOST_USER_RESET_OWNER, since they are sent when virtio stops
1701          * and the device is destroyed. destroy_device() waits for the queues
1702          * to become inactive, so it is safe. Otherwise taking the access_lock
1703          * would cause a deadlock.
1704          */
1705         switch (msg.request.master) {
1706         case VHOST_USER_SET_FEATURES:
1707         case VHOST_USER_SET_PROTOCOL_FEATURES:
1708         case VHOST_USER_SET_OWNER:
1709         case VHOST_USER_SET_MEM_TABLE:
1710         case VHOST_USER_SET_LOG_BASE:
1711         case VHOST_USER_SET_LOG_FD:
1712         case VHOST_USER_SET_VRING_NUM:
1713         case VHOST_USER_SET_VRING_ADDR:
1714         case VHOST_USER_SET_VRING_BASE:
1715         case VHOST_USER_SET_VRING_KICK:
1716         case VHOST_USER_SET_VRING_CALL:
1717         case VHOST_USER_SET_VRING_ERR:
1718         case VHOST_USER_SET_VRING_ENABLE:
1719         case VHOST_USER_SEND_RARP:
1720         case VHOST_USER_NET_SET_MTU:
1721         case VHOST_USER_SET_SLAVE_REQ_FD:
1722                 vhost_user_lock_all_queue_pairs(dev);
1723                 unlock_required = 1;
1724                 break;
1725         default:
1726                 break;
1727
1728         }
1729
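        /*
         * Give externally registered backends a first chance at the message;
         * they may handle it entirely and ask for the built-in handling to
         * be skipped.
         */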
1730         if (dev->extern_ops.pre_msg_handle) {
1731                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1732                                 (void *)&msg, &skip_master);
1733                 if (ret == VH_RESULT_ERR)
1734                         goto skip_to_reply;
1735                 else if (ret == VH_RESULT_REPLY)
1736                         send_vhost_reply(fd, &msg);
1737
1738                 if (skip_master)
1739                         goto skip_to_post_handle;
1740         }
1741
1742         request = msg.request.master;
1743         if (request > VHOST_USER_NONE && request < VHOST_USER_MAX) {
1744                 if (!vhost_message_handlers[request])
1745                         goto skip_to_post_handle;
1746                 ret = vhost_message_handlers[request](&dev, &msg);
1747
1748                 switch (ret) {
1749                 case VH_RESULT_ERR:
1750                         RTE_LOG(ERR, VHOST_CONFIG,
1751                                 "Processing %s failed.\n",
1752                                 vhost_message_str[request]);
1753                         break;
1754                 case VH_RESULT_OK:
1755                         RTE_LOG(DEBUG, VHOST_CONFIG,
1756                                 "Processing %s succeeded.\n",
1757                                 vhost_message_str[request]);
1758                         break;
1759                 case VH_RESULT_REPLY:
1760                         RTE_LOG(DEBUG, VHOST_CONFIG,
1761                                 "Processing %s succeeded and needs reply.\n",
1762                                 vhost_message_str[request]);
1763                         send_vhost_reply(fd, &msg);
1764                         break;
1765                 }
1766         } else {
1767                 RTE_LOG(ERR, VHOST_CONFIG,
1768                         "Requested invalid message type %d.\n", request);
1769                 ret = VH_RESULT_ERR;
1770         }
1771
1772 skip_to_post_handle:
1773         if (ret != VH_RESULT_ERR && dev->extern_ops.post_msg_handle) {
1774                 ret = (*dev->extern_ops.post_msg_handle)(
1775                                 dev->vid, (void *)&msg);
1776                 if (ret == VH_RESULT_ERR)
1777                         goto skip_to_reply;
1778                 else if (ret == VH_RESULT_REPLY)
1779                         send_vhost_reply(fd, &msg);
1780         }
1781
1782 skip_to_reply:
1783         if (unlock_required)
1784                 vhost_user_unlock_all_queue_pairs(dev);
1785
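        /*
         * The master requested an explicit ack (VHOST_USER_NEED_REPLY):
         * answer with a zero u64 payload on success and a non-zero one on
         * error, as the reply-ack protocol feature expects.
         */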
1786         if (msg.flags & VHOST_USER_NEED_REPLY) {
1787                 msg.payload.u64 = ret == VH_RESULT_ERR;
1788                 msg.size = sizeof(msg.payload.u64);
1789                 send_vhost_reply(fd, &msg);
1790         } else if (ret == VH_RESULT_ERR) {
1791                 RTE_LOG(ERR, VHOST_CONFIG,
1792                         "vhost message handling failed.\n");
1793                 return -1;
1794         }
1795
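        /*
         * Once all rings are configured, mark the device ready and, if it is
         * not already running, notify the application via the new_device()
         * callback.
         */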
1796         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1797                 dev->flags |= VIRTIO_DEV_READY;
1798
1799                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1800                         if (dev->dequeue_zero_copy) {
1801                                 RTE_LOG(INFO, VHOST_CONFIG,
1802                                                 "dequeue zero copy is enabled\n");
1803                         }
1804
1805                         if (dev->notify_ops->new_device(dev->vid) == 0)
1806                                 dev->flags |= VIRTIO_DEV_RUNNING;
1807                 }
1808         }
1809
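        /*
         * For vDPA devices, run the device configuration on the first
         * VHOST_USER_SET_VRING_ENABLE once the rings are ready, then try to
         * install host notifiers; on failure the software relay is used.
         */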
1810         did = dev->vdpa_dev_id;
1811         vdpa_dev = rte_vdpa_get_device(did);
1812         if (vdpa_dev && virtio_is_ready(dev) &&
1813                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1814                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1815                 if (vdpa_dev->ops->dev_conf)
1816                         vdpa_dev->ops->dev_conf(dev->vid);
1817                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1818                 if (vhost_user_host_notifier_ctrl(dev->vid, true) != 0) {
1819                         RTE_LOG(INFO, VHOST_CONFIG,
1820                                 "(%d) software relay is used for vDPA, performance may be low.\n",
1821                                 dev->vid);
1822                 }
1823         }
1824
1825         return 0;
1826 }
1827
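/*
 * If the slave request carried VHOST_USER_NEED_REPLY, read the reply from
 * the slave channel, check that it matches the request type and turn its
 * u64 payload into a return code. Releases the slave_req_lock taken by
 * send_vhost_slave_message().
 */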
1828 static int process_slave_message_reply(struct virtio_net *dev,
1829                                        const struct VhostUserMsg *msg)
1830 {
1831         struct VhostUserMsg msg_reply;
1832         int ret;
1833
1834         if ((msg->flags & VHOST_USER_NEED_REPLY) == 0)
1835                 return 0;
1836
1837         if (read_vhost_message(dev->slave_req_fd, &msg_reply) < 0) {
1838                 ret = -1;
1839                 goto out;
1840         }
1841
1842         if (msg_reply.request.slave != msg->request.slave) {
1843                 RTE_LOG(ERR, VHOST_CONFIG,
1844                         "Received unexpected msg type (%u), expected %u\n",
1845                         msg_reply.request.slave, msg->request.slave);
1846                 ret = -1;
1847                 goto out;
1848         }
1849
1850         ret = msg_reply.payload.u64 ? -1 : 0;
1851
1852 out:
1853         rte_spinlock_unlock(&dev->slave_req_lock);
1854         return ret;
1855 }
1856
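/*
 * Ask the master to resolve a missing IOTLB entry by sending a
 * VHOST_USER_SLAVE_IOTLB_MSG of type VHOST_IOTLB_MISS for the given IOVA
 * and access permission.
 */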
1857 int
1858 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1859 {
1860         int ret;
1861         struct VhostUserMsg msg = {
1862                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1863                 .flags = VHOST_USER_VERSION,
1864                 .size = sizeof(msg.payload.iotlb),
1865                 .payload.iotlb = {
1866                         .iova = iova,
1867                         .perm = perm,
1868                         .type = VHOST_IOTLB_MISS,
1869                 },
1870         };
1871
1872         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1873         if (ret < 0) {
1874                 RTE_LOG(ERR, VHOST_CONFIG,
1875                                 "Failed to send IOTLB miss message (%d)\n",
1876                                 ret);
1877                 return ret;
1878         }
1879
1880         return 0;
1881 }
1882
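/*
 * Forward the host notifier area of one vring to the master. A negative fd
 * clears the notifier: VHOST_USER_VRING_NOFD_MASK is set and no descriptor
 * is attached to the message.
 */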
1883 static int vhost_user_slave_set_vring_host_notifier(struct virtio_net *dev,
1884                                                     int index, int fd,
1885                                                     uint64_t offset,
1886                                                     uint64_t size)
1887 {
1888         int *fdp = NULL;
1889         size_t fd_num = 0;
1890         int ret;
1891         struct VhostUserMsg msg = {
1892                 .request.slave = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1893                 .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY,
1894                 .size = sizeof(msg.payload.area),
1895                 .payload.area = {
1896                         .u64 = index & VHOST_USER_VRING_IDX_MASK,
1897                         .size = size,
1898                         .offset = offset,
1899                 },
1900         };
1901
1902         if (fd < 0)
1903                 msg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1904         else {
1905                 fdp = &fd;
1906                 fd_num = 1;
1907         }
1908
1909         ret = send_vhost_slave_message(dev, &msg, fdp, fd_num);
1910         if (ret < 0) {
1911                 RTE_LOG(ERR, VHOST_CONFIG,
1912                         "Failed to set host notifier (%d)\n", ret);
1913                 return ret;
1914         }
1915
1916         return process_slave_message_reply(dev, &msg);
1917 }
1918
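/*
 * Enable or disable host notifiers for every vring of a vDPA device. All
 * the required protocol features must have been negotiated; if setting up
 * any notifier fails, those configured so far are rolled back through the
 * disable path.
 */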
1919 int vhost_user_host_notifier_ctrl(int vid, bool enable)
1920 {
1921         struct virtio_net *dev;
1922         struct rte_vdpa_device *vdpa_dev;
1923         int vfio_device_fd, did, ret = 0;
1924         uint64_t offset, size;
1925         unsigned int i;
1926
1927         dev = get_device(vid);
1928         if (!dev)
1929                 return -ENODEV;
1930
1931         did = dev->vdpa_dev_id;
1932         if (did < 0)
1933                 return -EINVAL;
1934
1935         if (!(dev->features & (1ULL << VIRTIO_F_VERSION_1)) ||
1936             !(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)) ||
1937             !(dev->protocol_features &
1938                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) ||
1939             !(dev->protocol_features &
1940                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) ||
1941             !(dev->protocol_features &
1942                         (1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER)))
1943                 return -ENOTSUP;
1944
1945         vdpa_dev = rte_vdpa_get_device(did);
1946         if (!vdpa_dev)
1947                 return -ENODEV;
1948
1949         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_vfio_device_fd, -ENOTSUP);
1950         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_notify_area, -ENOTSUP);
1951
1952         vfio_device_fd = vdpa_dev->ops->get_vfio_device_fd(vid);
1953         if (vfio_device_fd < 0)
1954                 return -ENOTSUP;
1955
1956         if (enable) {
1957                 for (i = 0; i < dev->nr_vring; i++) {
1958                         if (vdpa_dev->ops->get_notify_area(vid, i, &offset,
1959                                         &size) < 0) {
1960                                 ret = -ENOTSUP;
1961                                 goto disable;
1962                         }
1963
1964                         if (vhost_user_slave_set_vring_host_notifier(dev, i,
1965                                         vfio_device_fd, offset, size) < 0) {
1966                                 ret = -EFAULT;
1967                                 goto disable;
1968                         }
1969                 }
1970         } else {
1971 disable:
1972                 for (i = 0; i < dev->nr_vring; i++) {
1973                         vhost_user_slave_set_vring_host_notifier(dev, i, -1,
1974                                         0, 0);
1975                 }
1976         }
1977
1978         return ret;
1979 }