vhost: fix payload size of reply
lib/librte_vhost/vhost_user.c (dpdk.git)
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2010-2018 Intel Corporation
3  */
4
5 /* Security model
6  * --------------
7  * The vhost-user protocol connection is an external interface, so it must be
8  * robust against invalid inputs.
9  *
10  * This is important because the vhost-user master is only one step removed
11  * from the guest.  A malicious guest that has escaped can then launch
12  * further attacks from the vhost-user master.
13  *
14  * Even in deployments where guests are trusted, a bug in the vhost-user master
15  * can still cause invalid messages to be sent.  Such messages must not
16  * compromise the stability of the DPDK application by causing crashes, memory
17  * corruption, or other problematic behavior.
18  *
19  * Do not assume received VhostUserMsg fields contain sensible values!
20  */
21
22 #include <stdint.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <sys/mman.h>
28 #include <sys/types.h>
29 #include <sys/stat.h>
30 #include <assert.h>
31 #ifdef RTE_LIBRTE_VHOST_NUMA
32 #include <numaif.h>
33 #endif
34
35 #include <rte_common.h>
36 #include <rte_malloc.h>
37 #include <rte_log.h>
38
39 #include "iotlb.h"
40 #include "vhost.h"
41 #include "vhost_user.h"
42
43 #define VIRTIO_MIN_MTU 68
44 #define VIRTIO_MAX_MTU 65535
45
46 static const char *vhost_message_str[VHOST_USER_MAX] = {
47         [VHOST_USER_NONE] = "VHOST_USER_NONE",
48         [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
49         [VHOST_USER_SET_FEATURES] = "VHOST_USER_SET_FEATURES",
50         [VHOST_USER_SET_OWNER] = "VHOST_USER_SET_OWNER",
51         [VHOST_USER_RESET_OWNER] = "VHOST_USER_RESET_OWNER",
52         [VHOST_USER_SET_MEM_TABLE] = "VHOST_USER_SET_MEM_TABLE",
53         [VHOST_USER_SET_LOG_BASE] = "VHOST_USER_SET_LOG_BASE",
54         [VHOST_USER_SET_LOG_FD] = "VHOST_USER_SET_LOG_FD",
55         [VHOST_USER_SET_VRING_NUM] = "VHOST_USER_SET_VRING_NUM",
56         [VHOST_USER_SET_VRING_ADDR] = "VHOST_USER_SET_VRING_ADDR",
57         [VHOST_USER_SET_VRING_BASE] = "VHOST_USER_SET_VRING_BASE",
58         [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
59         [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
60         [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
61         [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
62         [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
63         [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
64         [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
65         [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
66         [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
67         [VHOST_USER_NET_SET_MTU]  = "VHOST_USER_NET_SET_MTU",
68         [VHOST_USER_SET_SLAVE_REQ_FD]  = "VHOST_USER_SET_SLAVE_REQ_FD",
69         [VHOST_USER_IOTLB_MSG]  = "VHOST_USER_IOTLB_MSG",
70         [VHOST_USER_CRYPTO_CREATE_SESS] = "VHOST_USER_CRYPTO_CREATE_SESS",
71         [VHOST_USER_CRYPTO_CLOSE_SESS] = "VHOST_USER_CRYPTO_CLOSE_SESS",
72 };
73
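/*
 * Return the block size reported by fstat() for the given fd, or
 * (uint64_t)-1 on failure. The result is used as the hugepage size when
 * aligning mmap() lengths for guest memory regions.
 */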
74 static uint64_t
75 get_blk_size(int fd)
76 {
77         struct stat stat;
78         int ret;
79
80         ret = fstat(fd, &stat);
81         return ret == -1 ? (uint64_t)-1 : (uint64_t)stat.st_blksize;
82 }
83
84 static void
85 free_mem_region(struct virtio_net *dev)
86 {
87         uint32_t i;
88         struct rte_vhost_mem_region *reg;
89
90         if (!dev || !dev->mem)
91                 return;
92
93         for (i = 0; i < dev->mem->nregions; i++) {
94                 reg = &dev->mem->regions[i];
95                 if (reg->host_user_addr) {
96                         munmap(reg->mmap_addr, reg->mmap_size);
97                         close(reg->fd);
98                 }
99         }
100 }
101
102 void
103 vhost_backend_cleanup(struct virtio_net *dev)
104 {
105         if (dev->mem) {
106                 free_mem_region(dev);
107                 rte_free(dev->mem);
108                 dev->mem = NULL;
109         }
110
111         free(dev->guest_pages);
112         dev->guest_pages = NULL;
113
114         if (dev->log_addr) {
115                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
116                 dev->log_addr = 0;
117         }
118
119         if (dev->slave_req_fd >= 0) {
120                 close(dev->slave_req_fd);
121                 dev->slave_req_fd = -1;
122         }
123 }
124
125 /*
126  * At the moment this function does nothing and simply
127  * returns success.
128  */
129 static int
130 vhost_user_set_owner(struct virtio_net **pdev __rte_unused,
131                 struct VhostUserMsg *msg __rte_unused)
132 {
133         return VH_RESULT_OK;
134 }
135
136 static int
137 vhost_user_reset_owner(struct virtio_net **pdev,
138                 struct VhostUserMsg *msg __rte_unused)
139 {
140         struct virtio_net *dev = *pdev;
141         vhost_destroy_device_notify(dev);
142
143         cleanup_device(dev, 0);
144         reset_device(dev);
145         return VH_RESULT_OK;
146 }
147
148 /*
149  * The master requests the features we support.
150  */
151 static int
152 vhost_user_get_features(struct virtio_net **pdev, struct VhostUserMsg *msg)
153 {
154         struct virtio_net *dev = *pdev;
155         uint64_t features = 0;
156
157         rte_vhost_driver_get_features(dev->ifname, &features);
158
159         msg->payload.u64 = features;
160         msg->size = sizeof(msg->payload.u64);
161
162         return VH_RESULT_REPLY;
163 }
164
165 /*
166  * The number of queues we support is requested.
167  */
168 static int
169 vhost_user_get_queue_num(struct virtio_net **pdev, struct VhostUserMsg *msg)
170 {
171         struct virtio_net *dev = *pdev;
172         uint32_t queue_num = 0;
173
174         rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
175
176         msg->payload.u64 = (uint64_t)queue_num;
177         msg->size = sizeof(msg->payload.u64);
178
179         return VH_RESULT_REPLY;
180 }
181
182 /*
183  * We receive the negotiated features supported by us and the virtio device.
184  */
185 static int
186 vhost_user_set_features(struct virtio_net **pdev, struct VhostUserMsg *msg)
187 {
188         struct virtio_net *dev = *pdev;
189         uint64_t features = msg->payload.u64;
190         uint64_t vhost_features = 0;
191         struct rte_vdpa_device *vdpa_dev;
192         int did = -1;
193
194         rte_vhost_driver_get_features(dev->ifname, &vhost_features);
195         if (features & ~vhost_features) {
196                 RTE_LOG(ERR, VHOST_CONFIG,
197                         "(%d) received invalid negotiated features.\n",
198                         dev->vid);
199                 return VH_RESULT_ERR;
200         }
201
202         if (dev->flags & VIRTIO_DEV_RUNNING) {
203                 if (dev->features == features)
204                         return VH_RESULT_OK;
205
206                 /*
207                  * Error out if master tries to change features while device is
208                  * in running state. The exception being VHOST_F_LOG_ALL, which
209                  * is enabled when the live-migration starts.
210                  */
211                 if ((dev->features ^ features) & ~(1ULL << VHOST_F_LOG_ALL)) {
212                         RTE_LOG(ERR, VHOST_CONFIG,
213                                 "(%d) features changed while device is running.\n",
214                                 dev->vid);
215                         return VH_RESULT_ERR;
216                 }
217
218                 if (dev->notify_ops->features_changed)
219                         dev->notify_ops->features_changed(dev->vid, features);
220         }
221
222         dev->features = features;
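        /*
         * The virtio-net header carries the extra num_buffers field when
         * either mergeable RX buffers or virtio 1.0 has been negotiated.
         */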
223         if (dev->features &
224                 ((1 << VIRTIO_NET_F_MRG_RXBUF) | (1ULL << VIRTIO_F_VERSION_1))) {
225                 dev->vhost_hlen = sizeof(struct virtio_net_hdr_mrg_rxbuf);
226         } else {
227                 dev->vhost_hlen = sizeof(struct virtio_net_hdr);
228         }
229         VHOST_LOG_DEBUG(VHOST_CONFIG,
230                 "(%d) mergeable RX buffers %s, virtio 1 %s\n",
231                 dev->vid,
232                 (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ? "on" : "off",
233                 (dev->features & (1ULL << VIRTIO_F_VERSION_1)) ? "on" : "off");
234
235         if ((dev->flags & VIRTIO_DEV_BUILTIN_VIRTIO_NET) &&
236             !(dev->features & (1ULL << VIRTIO_NET_F_MQ))) {
237                 /*
238                  * Remove all but first queue pair if MQ hasn't been
239                  * negotiated. This is safe because the device is not
240                  * running at this stage.
241                  */
242                 while (dev->nr_vring > 2) {
243                         struct vhost_virtqueue *vq;
244
245                         vq = dev->virtqueue[--dev->nr_vring];
246                         if (!vq)
247                                 continue;
248
249                         dev->virtqueue[dev->nr_vring] = NULL;
250                         cleanup_vq(vq, 1);
251                         free_vq(dev, vq);
252                 }
253         }
254
255         did = dev->vdpa_dev_id;
256         vdpa_dev = rte_vdpa_get_device(did);
257         if (vdpa_dev && vdpa_dev->ops->set_features)
258                 vdpa_dev->ops->set_features(dev->vid);
259
260         return VH_RESULT_OK;
261 }
262
263 /*
264  * The virtio device sends us the size of the descriptor ring.
265  */
266 static int
267 vhost_user_set_vring_num(struct virtio_net **pdev,
268                          struct VhostUserMsg *msg)
269 {
270         struct virtio_net *dev = *pdev;
271         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
272
273         vq->size = msg->payload.state.num;
274
275         /* VIRTIO 1.0, 2.4 Virtqueues says:
276          *
277          *   Queue Size value is always a power of 2. The maximum Queue Size
278          *   value is 32768.
279          */
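        /* A power-of-two size has no bits in common with (size - 1). */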
280         if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
281                 RTE_LOG(ERR, VHOST_CONFIG,
282                         "invalid virtqueue size %u\n", vq->size);
283                 return VH_RESULT_ERR;
284         }
285
286         if (dev->dequeue_zero_copy) {
287                 vq->nr_zmbuf = 0;
288                 vq->last_zmbuf_idx = 0;
289                 vq->zmbuf_size = vq->size;
290                 vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
291                                          sizeof(struct zcopy_mbuf), 0);
292                 if (vq->zmbufs == NULL) {
293                         RTE_LOG(WARNING, VHOST_CONFIG,
294                                 "failed to allocate mem for zero copy; "
295                                 "zero copy is force disabled\n");
296                         dev->dequeue_zero_copy = 0;
297                 }
298                 TAILQ_INIT(&vq->zmbuf_list);
299         }
300
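        /*
         * The shadow used ring batches used-ring updates locally before
         * they are flushed to the guest-visible used ring.
         */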
301         if (vq_is_packed(dev)) {
302                 vq->shadow_used_packed = rte_malloc(NULL,
303                                 vq->size *
304                                 sizeof(struct vring_used_elem_packed),
305                                 RTE_CACHE_LINE_SIZE);
306                 if (!vq->shadow_used_packed) {
307                         RTE_LOG(ERR, VHOST_CONFIG,
308                                         "failed to allocate memory for shadow used ring.\n");
309                         return VH_RESULT_ERR;
310                 }
311
312         } else {
313                 vq->shadow_used_split = rte_malloc(NULL,
314                                 vq->size * sizeof(struct vring_used_elem),
315                                 RTE_CACHE_LINE_SIZE);
316                 if (!vq->shadow_used_split) {
317                         RTE_LOG(ERR, VHOST_CONFIG,
318                                         "failed to allocate memory for shadow used ring.\n");
319                         return VH_RESULT_ERR;
320                 }
321         }
322
323         vq->batch_copy_elems = rte_malloc(NULL,
324                                 vq->size * sizeof(struct batch_copy_elem),
325                                 RTE_CACHE_LINE_SIZE);
326         if (!vq->batch_copy_elems) {
327                 RTE_LOG(ERR, VHOST_CONFIG,
328                         "failed to allocate memory for batching copy.\n");
329                 return VH_RESULT_ERR;
330         }
331
332         return VH_RESULT_OK;
333 }
334
335 /*
336  * Reallocate the virtio_net device and vhost_virtqueue data structures so
337  * that they reside on the same NUMA node as the vring descriptor memory.
338  */
339 #ifdef RTE_LIBRTE_VHOST_NUMA
340 static struct virtio_net*
341 numa_realloc(struct virtio_net *dev, int index)
342 {
343         int oldnode, newnode;
344         struct virtio_net *old_dev;
345         struct vhost_virtqueue *old_vq, *vq;
346         struct zcopy_mbuf *new_zmbuf;
347         struct vring_used_elem *new_shadow_used_split;
348         struct vring_used_elem_packed *new_shadow_used_packed;
349         struct batch_copy_elem *new_batch_copy_elems;
350         int ret;
351
352         old_dev = dev;
353         vq = old_vq = dev->virtqueue[index];
354
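        /*
         * Query the NUMA node backing the descriptor ring (newnode) and the
         * node the virtqueue structure currently lives on (oldnode).
         */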
355         ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
356                             MPOL_F_NODE | MPOL_F_ADDR);
357
358         /* check if we need to reallocate vq */
359         ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
360                              MPOL_F_NODE | MPOL_F_ADDR);
361         if (ret) {
362                 RTE_LOG(ERR, VHOST_CONFIG,
363                         "Unable to get vq numa information.\n");
364                 return dev;
365         }
366         if (oldnode != newnode) {
367                 RTE_LOG(INFO, VHOST_CONFIG,
368                         "reallocate vq from %d to %d node\n", oldnode, newnode);
369                 vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
370                 if (!vq)
371                         return dev;
372
373                 memcpy(vq, old_vq, sizeof(*vq));
374                 TAILQ_INIT(&vq->zmbuf_list);
375
376                 if (dev->dequeue_zero_copy) {
377                         new_zmbuf = rte_malloc_socket(NULL, vq->zmbuf_size *
378                                         sizeof(struct zcopy_mbuf), 0, newnode);
379                         if (new_zmbuf) {
380                                 rte_free(vq->zmbufs);
381                                 vq->zmbufs = new_zmbuf;
382                         }
383                 }
384
385                 if (vq_is_packed(dev)) {
386                         new_shadow_used_packed = rte_malloc_socket(NULL,
387                                         vq->size *
388                                         sizeof(struct vring_used_elem_packed),
389                                         RTE_CACHE_LINE_SIZE,
390                                         newnode);
391                         if (new_shadow_used_packed) {
392                                 rte_free(vq->shadow_used_packed);
393                                 vq->shadow_used_packed = new_shadow_used_packed;
394                         }
395                 } else {
396                         new_shadow_used_split = rte_malloc_socket(NULL,
397                                         vq->size *
398                                         sizeof(struct vring_used_elem),
399                                         RTE_CACHE_LINE_SIZE,
400                                         newnode);
401                         if (new_shadow_used_split) {
402                                 rte_free(vq->shadow_used_split);
403                                 vq->shadow_used_split = new_shadow_used_split;
404                         }
405                 }
406
407                 new_batch_copy_elems = rte_malloc_socket(NULL,
408                         vq->size * sizeof(struct batch_copy_elem),
409                         RTE_CACHE_LINE_SIZE,
410                         newnode);
411                 if (new_batch_copy_elems) {
412                         rte_free(vq->batch_copy_elems);
413                         vq->batch_copy_elems = new_batch_copy_elems;
414                 }
415
416                 rte_free(old_vq);
417         }
418
419         /* check if we need to reallocate dev */
420         ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
421                             MPOL_F_NODE | MPOL_F_ADDR);
422         if (ret) {
423                 RTE_LOG(ERR, VHOST_CONFIG,
424                         "Unable to get dev numa information.\n");
425                 goto out;
426         }
427         if (oldnode != newnode) {
428                 RTE_LOG(INFO, VHOST_CONFIG,
429                         "reallocate dev from %d to %d node\n",
430                         oldnode, newnode);
431                 dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
432                 if (!dev) {
433                         dev = old_dev;
434                         goto out;
435                 }
436
437                 memcpy(dev, old_dev, sizeof(*dev));
438                 rte_free(old_dev);
439         }
440
441 out:
442         dev->virtqueue[index] = vq;
443         vhost_devices[dev->vid] = dev;
444
445         if (old_vq != vq)
446                 vhost_user_iotlb_init(dev, index);
447
448         return dev;
449 }
450 #else
451 static struct virtio_net*
452 numa_realloc(struct virtio_net *dev, int index __rte_unused)
453 {
454         return dev;
455 }
456 #endif
457
458 /* Converts QEMU virtual address to Vhost virtual address. */
459 static uint64_t
460 qva_to_vva(struct virtio_net *dev, uint64_t qva, uint64_t *len)
461 {
462         struct rte_vhost_mem_region *r;
463         uint32_t i;
464
465         /* Find the region where the address lives. */
466         for (i = 0; i < dev->mem->nregions; i++) {
467                 r = &dev->mem->regions[i];
468
469                 if (qva >= r->guest_user_addr &&
470                     qva <  r->guest_user_addr + r->size) {
471
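                        /* Clamp *len so the translated range does not cross
                         * the end of the region. */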
472                         if (unlikely(*len > r->guest_user_addr + r->size - qva))
473                                 *len = r->guest_user_addr + r->size - qva;
474
475                         return qva - r->guest_user_addr +
476                                r->host_user_addr;
477                 }
478         }
479         *len = 0;
480
481         return 0;
482 }
483
484
485 /*
486  * Converts ring address to Vhost virtual address.
487  * If IOMMU is enabled, the ring address is a guest IO virtual address,
488  * else it is a QEMU virtual address.
489  */
490 static uint64_t
491 ring_addr_to_vva(struct virtio_net *dev, struct vhost_virtqueue *vq,
492                 uint64_t ra, uint64_t *size)
493 {
494         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)) {
495                 uint64_t vva;
496
497                 vva = vhost_user_iotlb_cache_find(vq, ra,
498                                         size, VHOST_ACCESS_RW);
499                 if (!vva)
500                         vhost_user_iotlb_miss(dev, ra, VHOST_ACCESS_RW);
501
502                 return vva;
503         }
504
505         return qva_to_vva(dev, ra, size);
506 }
507
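/*
 * Translate the ring addresses of the given virtqueue (desc/avail/used for
 * split rings, descriptor ring and event areas for packed rings) into host
 * virtual addresses, reallocating the device on the proper NUMA node when
 * needed. Always returns the (possibly reallocated) device; on a mapping
 * failure the remaining addresses are simply left untranslated.
 */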
508 static struct virtio_net *
509 translate_ring_addresses(struct virtio_net *dev, int vq_index)
510 {
511         struct vhost_virtqueue *vq = dev->virtqueue[vq_index];
512         struct vhost_vring_addr *addr = &vq->ring_addrs;
513         uint64_t len;
514
515         if (vq_is_packed(dev)) {
516                 len = sizeof(struct vring_packed_desc) * vq->size;
517                 vq->desc_packed = (struct vring_packed_desc *)(uintptr_t)
518                         ring_addr_to_vva(dev, vq, addr->desc_user_addr, &len);
519                 vq->log_guest_addr = 0;
520                 if (vq->desc_packed == NULL ||
521                                 len != sizeof(struct vring_packed_desc) *
522                                 vq->size) {
523                         RTE_LOG(DEBUG, VHOST_CONFIG,
524                                 "(%d) failed to map desc_packed ring.\n",
525                                 dev->vid);
526                         return dev;
527                 }
528
529                 dev = numa_realloc(dev, vq_index);
530                 vq = dev->virtqueue[vq_index];
531                 addr = &vq->ring_addrs;
532
533                 len = sizeof(struct vring_packed_desc_event);
534                 vq->driver_event = (struct vring_packed_desc_event *)
535                                         (uintptr_t)ring_addr_to_vva(dev,
536                                         vq, addr->avail_user_addr, &len);
537                 if (vq->driver_event == NULL ||
538                                 len != sizeof(struct vring_packed_desc_event)) {
539                         RTE_LOG(DEBUG, VHOST_CONFIG,
540                                 "(%d) failed to find driver area address.\n",
541                                 dev->vid);
542                         return dev;
543                 }
544
545                 len = sizeof(struct vring_packed_desc_event);
546                 vq->device_event = (struct vring_packed_desc_event *)
547                                         (uintptr_t)ring_addr_to_vva(dev,
548                                         vq, addr->used_user_addr, &len);
549                 if (vq->device_event == NULL ||
550                                 len != sizeof(struct vring_packed_desc_event)) {
551                         RTE_LOG(DEBUG, VHOST_CONFIG,
552                                 "(%d) failed to find device area address.\n",
553                                 dev->vid);
554                         return dev;
555                 }
556
557                 return dev;
558         }
559
560         /* The addresses are converted from QEMU virtual to Vhost virtual. */
561         if (vq->desc && vq->avail && vq->used)
562                 return dev;
563
564         len = sizeof(struct vring_desc) * vq->size;
565         vq->desc = (struct vring_desc *)(uintptr_t)ring_addr_to_vva(dev,
566                         vq, addr->desc_user_addr, &len);
567         if (vq->desc == 0 || len != sizeof(struct vring_desc) * vq->size) {
568                 RTE_LOG(DEBUG, VHOST_CONFIG,
569                         "(%d) failed to map desc ring.\n",
570                         dev->vid);
571                 return dev;
572         }
573
574         dev = numa_realloc(dev, vq_index);
575         vq = dev->virtqueue[vq_index];
576         addr = &vq->ring_addrs;
577
578         len = sizeof(struct vring_avail) + sizeof(uint16_t) * vq->size;
579         vq->avail = (struct vring_avail *)(uintptr_t)ring_addr_to_vva(dev,
580                         vq, addr->avail_user_addr, &len);
581         if (vq->avail == 0 ||
582                         len != sizeof(struct vring_avail) +
583                         sizeof(uint16_t) * vq->size) {
584                 RTE_LOG(DEBUG, VHOST_CONFIG,
585                         "(%d) failed to map avail ring.\n",
586                         dev->vid);
587                 return dev;
588         }
589
590         len = sizeof(struct vring_used) +
591                 sizeof(struct vring_used_elem) * vq->size;
592         vq->used = (struct vring_used *)(uintptr_t)ring_addr_to_vva(dev,
593                         vq, addr->used_user_addr, &len);
594         if (vq->used == 0 || len != sizeof(struct vring_used) +
595                         sizeof(struct vring_used_elem) * vq->size) {
596                 RTE_LOG(DEBUG, VHOST_CONFIG,
597                         "(%d) failed to map used ring.\n",
598                         dev->vid);
599                 return dev;
600         }
601
602         if (vq->last_used_idx != vq->used->idx) {
603                 RTE_LOG(WARNING, VHOST_CONFIG,
604                         "last_used_idx (%u) and vq->used->idx (%u) mismatch; "
605                         "some packets may be resent for Tx and dropped for Rx\n",
606                         vq->last_used_idx, vq->used->idx);
607                 vq->last_used_idx  = vq->used->idx;
608                 vq->last_avail_idx = vq->used->idx;
609         }
610
611         vq->log_guest_addr = addr->log_guest_addr;
612
613         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
614                         dev->vid, vq->desc);
615         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address avail: %p\n",
616                         dev->vid, vq->avail);
617         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address used: %p\n",
618                         dev->vid, vq->used);
619         VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) log_guest_addr: %" PRIx64 "\n",
620                         dev->vid, vq->log_guest_addr);
621
622         return dev;
623 }
624
625 /*
626  * The virtio device sends us the desc, used and avail ring addresses.
627  * This function then converts these to our address space.
628  */
629 static int
630 vhost_user_set_vring_addr(struct virtio_net **pdev, struct VhostUserMsg *msg)
631 {
632         struct virtio_net *dev = *pdev;
633         struct vhost_virtqueue *vq;
634         struct vhost_vring_addr *addr = &msg->payload.addr;
635
636         if (dev->mem == NULL)
637                 return VH_RESULT_ERR;
638
639         /* addr->index refers to the queue index. The txq is 1, rxq is 0. */
640         vq = dev->virtqueue[msg->payload.addr.index];
641
642         /*
643          * Ring addresses should not be interpreted as long as the ring has
644          * not been started and enabled.
645          */
646         memcpy(&vq->ring_addrs, addr, sizeof(*addr));
647
648         vring_invalidate(dev, vq);
649
650         if (vq->enabled && (dev->features &
651                                 (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
652                 dev = translate_ring_addresses(dev, msg->payload.addr.index);
653                 if (!dev)
654                         return VH_RESULT_ERR;
655
656                 *pdev = dev;
657         }
658
659         return VH_RESULT_OK;
660 }
661
662 /*
663  * The virtio device sends us the available ring last used index.
664  */
665 static int
666 vhost_user_set_vring_base(struct virtio_net **pdev,
667                           struct VhostUserMsg *msg)
668 {
669         struct virtio_net *dev = *pdev;
670         dev->virtqueue[msg->payload.state.index]->last_used_idx  =
671                         msg->payload.state.num;
672         dev->virtqueue[msg->payload.state.index]->last_avail_idx =
673                         msg->payload.state.num;
674
675         return VH_RESULT_OK;
676 }
677
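/*
 * Record one guest-physical to host-physical page mapping, merging it with
 * the previous entry when the host-physical ranges are contiguous.
 */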
678 static int
679 add_one_guest_page(struct virtio_net *dev, uint64_t guest_phys_addr,
680                    uint64_t host_phys_addr, uint64_t size)
681 {
682         struct guest_page *page, *last_page;
683
684         if (dev->nr_guest_pages == dev->max_guest_pages) {
685                 dev->max_guest_pages *= 2;
686                 dev->guest_pages = realloc(dev->guest_pages,
687                                         dev->max_guest_pages * sizeof(*page));
688                 if (!dev->guest_pages) {
689                         RTE_LOG(ERR, VHOST_CONFIG, "cannot realloc guest_pages\n");
690                         return -1;
691                 }
692         }
693
694         if (dev->nr_guest_pages > 0) {
695                 last_page = &dev->guest_pages[dev->nr_guest_pages - 1];
696                 /* merge if the two pages are contiguous */
697                 if (host_phys_addr == last_page->host_phys_addr +
698                                       last_page->size) {
699                         last_page->size += size;
700                         return 0;
701                 }
702         }
703
704         page = &dev->guest_pages[dev->nr_guest_pages++];
705         page->guest_phys_addr = guest_phys_addr;
706         page->host_phys_addr  = host_phys_addr;
707         page->size = size;
708
709         return 0;
710 }
711
712 static int
713 add_guest_pages(struct virtio_net *dev, struct rte_vhost_mem_region *reg,
714                 uint64_t page_size)
715 {
716         uint64_t reg_size = reg->size;
717         uint64_t host_user_addr  = reg->host_user_addr;
718         uint64_t guest_phys_addr = reg->guest_phys_addr;
719         uint64_t host_phys_addr;
720         uint64_t size;
721
722         host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)host_user_addr);
723         size = page_size - (guest_phys_addr & (page_size - 1));
724         size = RTE_MIN(size, reg_size);
725
726         if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr, size) < 0)
727                 return -1;
728
729         host_user_addr  += size;
730         guest_phys_addr += size;
731         reg_size -= size;
732
733         while (reg_size > 0) {
734                 size = RTE_MIN(reg_size, page_size);
735                 host_phys_addr = rte_mem_virt2iova((void *)(uintptr_t)
736                                                   host_user_addr);
737                 if (add_one_guest_page(dev, guest_phys_addr, host_phys_addr,
738                                 size) < 0)
739                         return -1;
740
741                 host_user_addr  += size;
742                 guest_phys_addr += size;
743                 reg_size -= size;
744         }
745
746         return 0;
747 }
748
749 #ifdef RTE_LIBRTE_VHOST_DEBUG
750 /* TODO: enable it only in debug mode? */
751 static void
752 dump_guest_pages(struct virtio_net *dev)
753 {
754         uint32_t i;
755         struct guest_page *page;
756
757         for (i = 0; i < dev->nr_guest_pages; i++) {
758                 page = &dev->guest_pages[i];
759
760                 RTE_LOG(INFO, VHOST_CONFIG,
761                         "guest physical page region %u\n"
762                         "\t guest_phys_addr: %" PRIx64 "\n"
763                         "\t host_phys_addr : %" PRIx64 "\n"
764                         "\t size           : %" PRIx64 "\n",
765                         i,
766                         page->guest_phys_addr,
767                         page->host_phys_addr,
768                         page->size);
769         }
770 }
771 #else
772 #define dump_guest_pages(dev)
773 #endif
774
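/*
 * Return true if the memory regions described by 'new' differ from those
 * already mapped in 'old', in which case the memory table must be remapped.
 */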
775 static bool
776 vhost_memory_changed(struct VhostUserMemory *new,
777                      struct rte_vhost_memory *old)
778 {
779         uint32_t i;
780
781         if (new->nregions != old->nregions)
782                 return true;
783
784         for (i = 0; i < new->nregions; ++i) {
785                 VhostUserMemoryRegion *new_r = &new->regions[i];
786                 struct rte_vhost_mem_region *old_r = &old->regions[i];
787
788                 if (new_r->guest_phys_addr != old_r->guest_phys_addr)
789                         return true;
790                 if (new_r->memory_size != old_r->size)
791                         return true;
792                 if (new_r->userspace_addr != old_r->guest_user_addr)
793                         return true;
794         }
795
796         return false;
797 }
798
799 static int
800 vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *msg)
801 {
802         struct virtio_net *dev = *pdev;
803         struct VhostUserMemory memory = msg->payload.memory;
804         struct rte_vhost_mem_region *reg;
805         void *mmap_addr;
806         uint64_t mmap_size;
807         uint64_t mmap_offset;
808         uint64_t alignment;
809         uint32_t i;
810         int populate;
811         int fd;
812
813         if (memory.nregions > VHOST_MEMORY_MAX_NREGIONS) {
814                 RTE_LOG(ERR, VHOST_CONFIG,
815                         "too many memory regions (%u)\n", memory.nregions);
816                 return VH_RESULT_ERR;
817         }
818
819         if (dev->mem && !vhost_memory_changed(&memory, dev->mem)) {
820                 RTE_LOG(INFO, VHOST_CONFIG,
821                         "(%d) memory regions not changed\n", dev->vid);
822
823                 for (i = 0; i < memory.nregions; i++)
824                         close(msg->fds[i]);
825
826                 return VH_RESULT_OK;
827         }
828
829         if (dev->mem) {
830                 free_mem_region(dev);
831                 rte_free(dev->mem);
832                 dev->mem = NULL;
833         }
834
835         /* Flush IOTLB cache as previous HVAs are now invalid */
836         if (dev->features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))
837                 for (i = 0; i < dev->nr_vring; i++)
838                         vhost_user_iotlb_flush_all(dev->virtqueue[i]);
839
840         dev->nr_guest_pages = 0;
841         if (!dev->guest_pages) {
842                 dev->max_guest_pages = 8;
843                 dev->guest_pages = malloc(dev->max_guest_pages *
844                                                 sizeof(struct guest_page));
845                 if (dev->guest_pages == NULL) {
846                         RTE_LOG(ERR, VHOST_CONFIG,
847                                 "(%d) failed to allocate memory "
848                                 "for dev->guest_pages\n",
849                                 dev->vid);
850                         return VH_RESULT_ERR;
851                 }
852         }
853
854         dev->mem = rte_zmalloc("vhost-mem-table", sizeof(struct rte_vhost_memory) +
855                 sizeof(struct rte_vhost_mem_region) * memory.nregions, 0);
856         if (dev->mem == NULL) {
857                 RTE_LOG(ERR, VHOST_CONFIG,
858                         "(%d) failed to allocate memory for dev->mem\n",
859                         dev->vid);
860                 return VH_RESULT_ERR;
861         }
862         dev->mem->nregions = memory.nregions;
863
864         for (i = 0; i < memory.nregions; i++) {
865                 fd  = msg->fds[i];
866                 reg = &dev->mem->regions[i];
867
868                 reg->guest_phys_addr = memory.regions[i].guest_phys_addr;
869                 reg->guest_user_addr = memory.regions[i].userspace_addr;
870                 reg->size            = memory.regions[i].memory_size;
871                 reg->fd              = fd;
872
873                 mmap_offset = memory.regions[i].mmap_offset;
874
875                 /* Check for memory_size + mmap_offset overflow */
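                /* In uint64_t arithmetic -reg->size wraps to
                 * (2^64 - reg->size), so for a non-zero region size the test
                 * below is true exactly when mmap_offset + reg->size would
                 * overflow 64 bits.
                 */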
876                 if (mmap_offset >= -reg->size) {
877                         RTE_LOG(ERR, VHOST_CONFIG,
878                                 "mmap_offset (%#"PRIx64") and memory_size "
879                                 "(%#"PRIx64") overflow\n",
880                                 mmap_offset, reg->size);
881                         goto err_mmap;
882                 }
883
884                 mmap_size = reg->size + mmap_offset;
885
886                 /* On older long-term Linux kernels (e.g. 2.6.32 and
887                  * 3.2.72), mmap() without the MAP_ANONYMOUS flag must be
888                  * called with a length that is aligned to the hugepage
889                  * size, or it will fail with EINVAL.
890                  *
891                  * To avoid that failure, make sure the length passed here
892                  * is kept aligned.
893                  */
894                 alignment = get_blk_size(fd);
895                 if (alignment == (uint64_t)-1) {
896                         RTE_LOG(ERR, VHOST_CONFIG,
897                                 "couldn't get hugepage size through fstat\n");
898                         goto err_mmap;
899                 }
900                 mmap_size = RTE_ALIGN_CEIL(mmap_size, alignment);
901
902                 populate = (dev->dequeue_zero_copy) ? MAP_POPULATE : 0;
903                 mmap_addr = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE,
904                                  MAP_SHARED | populate, fd, 0);
905
906                 if (mmap_addr == MAP_FAILED) {
907                         RTE_LOG(ERR, VHOST_CONFIG,
908                                 "mmap region %u failed.\n", i);
909                         goto err_mmap;
910                 }
911
912                 reg->mmap_addr = mmap_addr;
913                 reg->mmap_size = mmap_size;
914                 reg->host_user_addr = (uint64_t)(uintptr_t)mmap_addr +
915                                       mmap_offset;
916
917                 if (dev->dequeue_zero_copy)
918                         if (add_guest_pages(dev, reg, alignment) < 0) {
919                                 RTE_LOG(ERR, VHOST_CONFIG,
920                                         "adding guest pages to region %u failed.\n",
921                                         i);
922                                 goto err_mmap;
923                         }
924
925                 RTE_LOG(INFO, VHOST_CONFIG,
926                         "guest memory region %u, size: 0x%" PRIx64 "\n"
927                         "\t guest physical addr: 0x%" PRIx64 "\n"
928                         "\t guest virtual  addr: 0x%" PRIx64 "\n"
929                         "\t host  virtual  addr: 0x%" PRIx64 "\n"
930                         "\t mmap addr : 0x%" PRIx64 "\n"
931                         "\t mmap size : 0x%" PRIx64 "\n"
932                         "\t mmap align: 0x%" PRIx64 "\n"
933                         "\t mmap off  : 0x%" PRIx64 "\n",
934                         i, reg->size,
935                         reg->guest_phys_addr,
936                         reg->guest_user_addr,
937                         reg->host_user_addr,
938                         (uint64_t)(uintptr_t)mmap_addr,
939                         mmap_size,
940                         alignment,
941                         mmap_offset);
942         }
943
944         for (i = 0; i < dev->nr_vring; i++) {
945                 struct vhost_virtqueue *vq = dev->virtqueue[i];
946
947                 if (vq->desc || vq->avail || vq->used) {
948                         /*
949                          * If the memory table got updated, the ring addresses
950                          * need to be translated again as virtual addresses have
951                          * changed.
952                          */
953                         vring_invalidate(dev, vq);
954
955                         dev = translate_ring_addresses(dev, i);
956                         if (!dev)
957                                 return VH_RESULT_ERR;
958
959                         *pdev = dev;
960                 }
961         }
962
963         dump_guest_pages(dev);
964
965         return VH_RESULT_OK;
966
967 err_mmap:
968         free_mem_region(dev);
969         rte_free(dev->mem);
970         dev->mem = NULL;
971         return VH_RESULT_ERR;
972 }
973
974 static bool
975 vq_is_ready(struct virtio_net *dev, struct vhost_virtqueue *vq)
976 {
977         bool rings_ok;
978
979         if (!vq)
980                 return false;
981
982         if (vq_is_packed(dev))
983                 rings_ok = !!vq->desc_packed;
984         else
985                 rings_ok = vq->desc && vq->avail && vq->used;
986
987         return rings_ok &&
988                vq->kickfd != VIRTIO_UNINITIALIZED_EVENTFD &&
989                vq->callfd != VIRTIO_UNINITIALIZED_EVENTFD;
990 }
991
992 static int
993 virtio_is_ready(struct virtio_net *dev)
994 {
995         struct vhost_virtqueue *vq;
996         uint32_t i;
997
998         if (dev->nr_vring == 0)
999                 return 0;
1000
1001         for (i = 0; i < dev->nr_vring; i++) {
1002                 vq = dev->virtqueue[i];
1003
1004                 if (!vq_is_ready(dev, vq))
1005                         return 0;
1006         }
1007
1008         RTE_LOG(INFO, VHOST_CONFIG,
1009                 "virtio is now ready for processing.\n");
1010         return 1;
1011 }
1012
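/*
 * The vring index is carried in the low bits of payload.u64; the
 * VHOST_USER_VRING_NOFD_MASK bit tells whether an eventfd was passed
 * along with the message.
 */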
1013 static int
1014 vhost_user_set_vring_call(struct virtio_net **pdev, struct VhostUserMsg *msg)
1015 {
1016         struct virtio_net *dev = *pdev;
1017         struct vhost_vring_file file;
1018         struct vhost_virtqueue *vq;
1019
1020         file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1021         if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
1022                 file.fd = VIRTIO_INVALID_EVENTFD;
1023         else
1024                 file.fd = msg->fds[0];
1025         RTE_LOG(INFO, VHOST_CONFIG,
1026                 "vring call idx:%d file:%d\n", file.index, file.fd);
1027
1028         vq = dev->virtqueue[file.index];
1029         if (vq->callfd >= 0)
1030                 close(vq->callfd);
1031
1032         vq->callfd = file.fd;
1033
1034         return VH_RESULT_OK;
1035 }
1036
1037 static int vhost_user_set_vring_err(struct virtio_net **pdev __rte_unused,
1038                         struct VhostUserMsg *msg)
1039 {
1040         if (!(msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK))
1041                 close(msg->fds[0]);
1042         RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
1043
1044         return VH_RESULT_OK;
1045 }
1046
1047 static int
1048 vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *msg)
1049 {
1050         struct virtio_net *dev = *pdev;
1051         struct vhost_vring_file file;
1052         struct vhost_virtqueue *vq;
1053
1054         file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1055         if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
1056                 file.fd = VIRTIO_INVALID_EVENTFD;
1057         else
1058                 file.fd = msg->fds[0];
1059         RTE_LOG(INFO, VHOST_CONFIG,
1060                 "vring kick idx:%d file:%d\n", file.index, file.fd);
1061
1062         /* Interpret ring addresses only when ring is started. */
1063         dev = translate_ring_addresses(dev, file.index);
1064         if (!dev)
1065                 return VH_RESULT_ERR;
1066
1067         *pdev = dev;
1068
1069         vq = dev->virtqueue[file.index];
1070
1071         /*
1072          * When VHOST_USER_F_PROTOCOL_FEATURES is not negotiated,
1073          * the ring starts already enabled. Otherwise, it is enabled via
1074          * the SET_VRING_ENABLE message.
1075          */
1076         if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)))
1077                 vq->enabled = 1;
1078
1079         if (vq->kickfd >= 0)
1080                 close(vq->kickfd);
1081         vq->kickfd = file.fd;
1082
1083         return VH_RESULT_OK;
1084 }
1085
1086 static void
1087 free_zmbufs(struct vhost_virtqueue *vq)
1088 {
1089         struct zcopy_mbuf *zmbuf, *next;
1090
1091         for (zmbuf = TAILQ_FIRST(&vq->zmbuf_list);
1092              zmbuf != NULL; zmbuf = next) {
1093                 next = TAILQ_NEXT(zmbuf, next);
1094
1095                 rte_pktmbuf_free(zmbuf->mbuf);
1096                 TAILQ_REMOVE(&vq->zmbuf_list, zmbuf, next);
1097         }
1098
1099         rte_free(vq->zmbufs);
1100 }
1101
1102 /*
1103  * When virtio is stopped, QEMU will send us the GET_VRING_BASE message.
1104  */
1105 static int
1106 vhost_user_get_vring_base(struct virtio_net **pdev,
1107                           struct VhostUserMsg *msg)
1108 {
1109         struct virtio_net *dev = *pdev;
1110         struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
1111
1112         /* We have to stop the queue (virtio) if it is running. */
1113         vhost_destroy_device_notify(dev);
1114
1115         dev->flags &= ~VIRTIO_DEV_READY;
1116         dev->flags &= ~VIRTIO_DEV_VDPA_CONFIGURED;
1117
1118         /* Here we are safe to get the last avail index */
1119         msg->payload.state.num = vq->last_avail_idx;
1120
1121         RTE_LOG(INFO, VHOST_CONFIG,
1122                 "vring base idx:%d file:%d\n", msg->payload.state.index,
1123                 msg->payload.state.num);
1124         /*
1125          * In the current QEMU vhost-user implementation, this message is
1126          * sent only from vhost_vring_stop.
1127          * TODO: clean up the vring; it isn't usable from this point on.
1128          */
1129         if (vq->kickfd >= 0)
1130                 close(vq->kickfd);
1131
1132         vq->kickfd = VIRTIO_UNINITIALIZED_EVENTFD;
1133
1134         if (vq->callfd >= 0)
1135                 close(vq->callfd);
1136
1137         vq->callfd = VIRTIO_UNINITIALIZED_EVENTFD;
1138
1139         if (dev->dequeue_zero_copy)
1140                 free_zmbufs(vq);
1141         if (vq_is_packed(dev)) {
1142                 rte_free(vq->shadow_used_packed);
1143                 vq->shadow_used_packed = NULL;
1144         } else {
1145                 rte_free(vq->shadow_used_split);
1146                 vq->shadow_used_split = NULL;
1147         }
1148
1149         rte_free(vq->batch_copy_elems);
1150         vq->batch_copy_elems = NULL;
1151
1152         msg->size = sizeof(msg->payload.state);
1153
1154         return VH_RESULT_REPLY;
1155 }
1156
1157 /*
1158  * When the virtio queues are ready to work, QEMU will send us the
1159  * message to enable the virtio queue pair.
1160  */
1161 static int
1162 vhost_user_set_vring_enable(struct virtio_net **pdev,
1163                             struct VhostUserMsg *msg)
1164 {
1165         struct virtio_net *dev = *pdev;
1166         int enable = (int)msg->payload.state.num;
1167         int index = (int)msg->payload.state.index;
1168         struct rte_vdpa_device *vdpa_dev;
1169         int did = -1;
1170
1171         RTE_LOG(INFO, VHOST_CONFIG,
1172                 "set queue enable: %d to qp idx: %d\n",
1173                 enable, index);
1174
1175         did = dev->vdpa_dev_id;
1176         vdpa_dev = rte_vdpa_get_device(did);
1177         if (vdpa_dev && vdpa_dev->ops->set_vring_state)
1178                 vdpa_dev->ops->set_vring_state(dev->vid, index, enable);
1179
1180         if (dev->notify_ops->vring_state_changed)
1181                 dev->notify_ops->vring_state_changed(dev->vid,
1182                                 index, enable);
1183
1184         dev->virtqueue[index]->enabled = enable;
1185
1186         return VH_RESULT_OK;
1187 }
1188
1189 static int
1190 vhost_user_get_protocol_features(struct virtio_net **pdev,
1191                                  struct VhostUserMsg *msg)
1192 {
1193         struct virtio_net *dev = *pdev;
1194         uint64_t features, protocol_features;
1195
1196         rte_vhost_driver_get_features(dev->ifname, &features);
1197         rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
1198
1199         /*
1200          * For now, the REPLY_ACK protocol feature is only mandatory for
1201          * the IOMMU feature. If IOMMU is explicitly disabled by the
1202          * application, also disable the REPLY_ACK feature to cope with
1203          * older buggy QEMU versions (from v2.7.0 to v2.9.0).
1204          */
1205         if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
1206                 protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
1207
1208         msg->payload.u64 = protocol_features;
1209         msg->size = sizeof(msg->payload.u64);
1210
1211         return VH_RESULT_REPLY;
1212 }
1213
1214 static int
1215 vhost_user_set_protocol_features(struct virtio_net **pdev,
1216                                  struct VhostUserMsg *msg)
1217 {
1218         struct virtio_net *dev = *pdev;
1219         uint64_t protocol_features = msg->payload.u64;
1220         if (protocol_features & ~VHOST_USER_PROTOCOL_FEATURES) {
1221                 RTE_LOG(ERR, VHOST_CONFIG,
1222                         "(%d) received invalid protocol features.\n",
1223                         dev->vid);
1224                 return VH_RESULT_ERR;
1225         }
1226
1227         dev->protocol_features = protocol_features;
1228
1229         return VH_RESULT_OK;
1230 }
1231
1232 static int
1233 vhost_user_set_log_base(struct virtio_net **pdev, struct VhostUserMsg *msg)
1234 {
1235         struct virtio_net *dev = *pdev;
1236         int fd = msg->fds[0];
1237         uint64_t size, off;
1238         void *addr;
1239
1240         if (fd < 0) {
1241                 RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
1242                 return VH_RESULT_ERR;
1243         }
1244
1245         if (msg->size != sizeof(VhostUserLog)) {
1246                 RTE_LOG(ERR, VHOST_CONFIG,
1247                         "invalid log base msg size: %"PRId32" != %d\n",
1248                         msg->size, (int)sizeof(VhostUserLog));
1249                 return VH_RESULT_ERR;
1250         }
1251
1252         size = msg->payload.log.mmap_size;
1253         off  = msg->payload.log.mmap_offset;
1254
1255         /* Don't allow mmap_offset to point outside the mmap region */
1256         if (off > size) {
1257                 RTE_LOG(ERR, VHOST_CONFIG,
1258                         "log offset %#"PRIx64" exceeds log size %#"PRIx64"\n",
1259                         off, size);
1260                 return VH_RESULT_ERR;
1261         }
1262
1263         RTE_LOG(INFO, VHOST_CONFIG,
1264                 "log mmap size: %"PRId64", offset: %"PRId64"\n",
1265                 size, off);
1266
1267         /*
1268          * mmap from offset 0 to work around a hugepage mmap bug: mmap
1269          * will fail when the offset is not page-size aligned.
1270          */
1271         addr = mmap(0, size + off, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
1272         close(fd);
1273         if (addr == MAP_FAILED) {
1274                 RTE_LOG(ERR, VHOST_CONFIG, "mmap log base failed!\n");
1275                 return VH_RESULT_ERR;
1276         }
1277
1278         /*
1279          * Free any previously mapped log memory, in case multiple
1280          * VHOST_USER_SET_LOG_BASE messages are received.
1281          */
1282         if (dev->log_addr) {
1283                 munmap((void *)(uintptr_t)dev->log_addr, dev->log_size);
1284         }
1285         dev->log_addr = (uint64_t)(uintptr_t)addr;
1286         dev->log_base = dev->log_addr + off;
1287         dev->log_size = size;
1288
1289         /*
1290          * The spec is not clear about it (yet), but QEMU doesn't expect
1291          * any payload in the reply.
1292          */
1293         msg->size = 0;
1294
1295         return VH_RESULT_REPLY;
1296 }
1297
1298 static int vhost_user_set_log_fd(struct virtio_net **pdev __rte_unused,
1299                         struct VhostUserMsg *msg)
1300 {
1301         close(msg->fds[0]);
1302         RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
1303
1304         return VH_RESULT_OK;
1305 }
1306
1307 /*
1308  * A RARP packet is constructed and broadcast to notify switches about
1309  * the new location of the migrated VM, so that packets from outside will
1310  * not be lost after migration.
1311  *
1312  * However, we don't actually "send" a RARP packet here; instead, we set
1313  * the 'broadcast_rarp' flag to let rte_vhost_dequeue_burst() inject it.
1314  */
1315 static int
1316 vhost_user_send_rarp(struct virtio_net **pdev, struct VhostUserMsg *msg)
1317 {
1318         struct virtio_net *dev = *pdev;
1319         uint8_t *mac = (uint8_t *)&msg->payload.u64;
1320         struct rte_vdpa_device *vdpa_dev;
1321         int did = -1;
1322
1323         RTE_LOG(DEBUG, VHOST_CONFIG,
1324                 ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
1325                 mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
1326         memcpy(dev->mac.addr_bytes, mac, 6);
1327
1328         /*
1329          * Set the flag to inject a RARP broadcast packet at
1330          * rte_vhost_dequeue_burst().
1331          *
1332          * rte_smp_wmb() is for making sure the mac is copied
1333          * before the flag is set.
1334          */
1335         rte_smp_wmb();
1336         rte_atomic16_set(&dev->broadcast_rarp, 1);
1337         did = dev->vdpa_dev_id;
1338         vdpa_dev = rte_vdpa_get_device(did);
1339         if (vdpa_dev && vdpa_dev->ops->migration_done)
1340                 vdpa_dev->ops->migration_done(dev->vid);
1341
1342         return VH_RESULT_OK;
1343 }
1344
1345 static int
1346 vhost_user_net_set_mtu(struct virtio_net **pdev, struct VhostUserMsg *msg)
1347 {
1348         struct virtio_net *dev = *pdev;
1349         if (msg->payload.u64 < VIRTIO_MIN_MTU ||
1350                         msg->payload.u64 > VIRTIO_MAX_MTU) {
1351                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
1352                                 msg->payload.u64);
1353
1354                 return VH_RESULT_ERR;
1355         }
1356
1357         dev->mtu = msg->payload.u64;
1358
1359         return VH_RESULT_OK;
1360 }
1361
1362 static int
1363 vhost_user_set_req_fd(struct virtio_net **pdev, struct VhostUserMsg *msg)
1364 {
1365         struct virtio_net *dev = *pdev;
1366         int fd = msg->fds[0];
1367
1368         if (fd < 0) {
1369                 RTE_LOG(ERR, VHOST_CONFIG,
1370                                 "Invalid file descriptor for slave channel (%d)\n",
1371                                 fd);
1372                 return VH_RESULT_ERR;
1373         }
1374
1375         dev->slave_req_fd = fd;
1376
1377         return VH_RESULT_OK;
1378 }
1379
1380 static int
1381 is_vring_iotlb_update(struct vhost_virtqueue *vq, struct vhost_iotlb_msg *imsg)
1382 {
1383         struct vhost_vring_addr *ra;
1384         uint64_t start, end;
1385
1386         start = imsg->iova;
1387         end = start + imsg->size;
1388
1389         ra = &vq->ring_addrs;
1390         if (ra->desc_user_addr >= start && ra->desc_user_addr < end)
1391                 return 1;
1392         if (ra->avail_user_addr >= start && ra->avail_user_addr < end)
1393                 return 1;
1394         if (ra->used_user_addr >= start && ra->used_user_addr < end)
1395                 return 1;
1396
1397         return 0;
1398 }
1399
1400 static int
1401 is_vring_iotlb_invalidate(struct vhost_virtqueue *vq,
1402                                 struct vhost_iotlb_msg *imsg)
1403 {
1404         uint64_t istart, iend, vstart, vend;
1405
1406         istart = imsg->iova;
1407         iend = istart + imsg->size - 1;
1408
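        /*
         * Two inclusive ranges [istart, iend] and [vstart, vend] overlap
         * if and only if vstart <= iend && istart <= vend.
         */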
1409         vstart = (uintptr_t)vq->desc;
1410         vend = vstart + sizeof(struct vring_desc) * vq->size - 1;
1411         if (vstart <= iend && istart <= vend)
1412                 return 1;
1413
1414         vstart = (uintptr_t)vq->avail;
1415         vend = vstart + sizeof(struct vring_avail);
1416         vend += sizeof(uint16_t) * vq->size - 1;
1417         if (vstart <= iend && istart <= vend)
1418                 return 1;
1419
1420         vstart = (uintptr_t)vq->used;
1421         vend = vstart + sizeof(struct vring_used);
1422         vend += sizeof(struct vring_used_elem) * vq->size - 1;
1423         if (vstart <= iend && istart <= vend)
1424                 return 1;
1425
1426         return 0;
1427 }
1428
1429 static int
1430 vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg)
1431 {
1432         struct virtio_net *dev = *pdev;
1433         struct vhost_iotlb_msg *imsg = &msg->payload.iotlb;
1434         uint16_t i;
1435         uint64_t vva, len;
1436
1437         switch (imsg->type) {
1438         case VHOST_IOTLB_UPDATE:
1439                 len = imsg->size;
1440                 vva = qva_to_vva(dev, imsg->uaddr, &len);
1441                 if (!vva)
1442                         return VH_RESULT_ERR;
1443
1444                 for (i = 0; i < dev->nr_vring; i++) {
1445                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1446
1447                         vhost_user_iotlb_cache_insert(vq, imsg->iova, vva,
1448                                         len, imsg->perm);
1449
1450                         if (is_vring_iotlb_update(vq, imsg))
1451                                 *pdev = dev = translate_ring_addresses(dev, i);
1452                 }
1453                 break;
1454         case VHOST_IOTLB_INVALIDATE:
1455                 for (i = 0; i < dev->nr_vring; i++) {
1456                         struct vhost_virtqueue *vq = dev->virtqueue[i];
1457
1458                         vhost_user_iotlb_cache_remove(vq, imsg->iova,
1459                                         imsg->size);
1460
1461                         if (is_vring_iotlb_invalidate(vq, imsg))
1462                                 vring_invalidate(dev, vq);
1463                 }
1464                 break;
1465         default:
1466                 RTE_LOG(ERR, VHOST_CONFIG, "Invalid IOTLB message type (%d)\n",
1467                                 imsg->type);
1468                 return VH_RESULT_ERR;
1469         }
1470
1471         return VH_RESULT_OK;
1472 }
1473
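/*
 * Message handlers return VH_RESULT_OK on success, VH_RESULT_ERR on
 * failure, or VH_RESULT_REPLY when 'msg' has been turned into a reply to
 * send back to the master; in the latter case msg->size must be set to
 * the number of payload bytes actually used.
 */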
1474 typedef int (*vhost_message_handler_t)(struct virtio_net **pdev,
1475                                         struct VhostUserMsg *msg);
1476 static vhost_message_handler_t vhost_message_handlers[VHOST_USER_MAX] = {
1477         [VHOST_USER_NONE] = NULL,
1478         [VHOST_USER_GET_FEATURES] = vhost_user_get_features,
1479         [VHOST_USER_SET_FEATURES] = vhost_user_set_features,
1480         [VHOST_USER_SET_OWNER] = vhost_user_set_owner,
1481         [VHOST_USER_RESET_OWNER] = vhost_user_reset_owner,
1482         [VHOST_USER_SET_MEM_TABLE] = vhost_user_set_mem_table,
1483         [VHOST_USER_SET_LOG_BASE] = vhost_user_set_log_base,
1484         [VHOST_USER_SET_LOG_FD] = vhost_user_set_log_fd,
1485         [VHOST_USER_SET_VRING_NUM] = vhost_user_set_vring_num,
1486         [VHOST_USER_SET_VRING_ADDR] = vhost_user_set_vring_addr,
1487         [VHOST_USER_SET_VRING_BASE] = vhost_user_set_vring_base,
1488         [VHOST_USER_GET_VRING_BASE] = vhost_user_get_vring_base,
1489         [VHOST_USER_SET_VRING_KICK] = vhost_user_set_vring_kick,
1490         [VHOST_USER_SET_VRING_CALL] = vhost_user_set_vring_call,
1491         [VHOST_USER_SET_VRING_ERR] = vhost_user_set_vring_err,
1492         [VHOST_USER_GET_PROTOCOL_FEATURES] = vhost_user_get_protocol_features,
1493         [VHOST_USER_SET_PROTOCOL_FEATURES] = vhost_user_set_protocol_features,
1494         [VHOST_USER_GET_QUEUE_NUM] = vhost_user_get_queue_num,
1495         [VHOST_USER_SET_VRING_ENABLE] = vhost_user_set_vring_enable,
1496         [VHOST_USER_SEND_RARP] = vhost_user_send_rarp,
1497         [VHOST_USER_NET_SET_MTU] = vhost_user_net_set_mtu,
1498         [VHOST_USER_SET_SLAVE_REQ_FD] = vhost_user_set_req_fd,
1499         [VHOST_USER_IOTLB_MSG] = vhost_user_iotlb_msg,
1500 };
1501
1502
1503 /* Return the number of bytes read on success, 0 if the peer closed the connection, or a negative value on failure. */
1504 static int
1505 read_vhost_message(int sockfd, struct VhostUserMsg *msg)
1506 {
1507         int ret;
1508
1509         ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
1510                 msg->fds, VHOST_MEMORY_MAX_NREGIONS);
1511         if (ret <= 0)
1512                 return ret;
1513
1514         if (msg->size) {
1515                 if (msg->size > sizeof(msg->payload)) {
1516                         RTE_LOG(ERR, VHOST_CONFIG,
1517                                 "invalid msg size: %u\n", msg->size);
1518                         return -1;
1519                 }
1520                 ret = read(sockfd, &msg->payload, msg->size);
1521                 if (ret <= 0)
1522                         return ret;
1523                 if (ret != (int)msg->size) {
1524                         RTE_LOG(ERR, VHOST_CONFIG,
1525                                 "read control message failed\n");
1526                         return -1;
1527                 }
1528         }
1529
1530         return ret;
1531 }
1532
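/*
 * Send a vhost-user message (header and payload) on the given socket,
 * passing the supplied file descriptors as ancillary data.
 */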
1533 static int
1534 send_vhost_message(int sockfd, struct VhostUserMsg *msg, int *fds, int fd_num)
1535 {
1536         if (!msg)
1537                 return 0;
1538
1539         return send_fd_message(sockfd, (char *)msg,
1540                 VHOST_USER_HDR_SIZE + msg->size, fds, fd_num);
1541 }
1542
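/*
 * Send a reply to the master: set the version and reply flags, clear the
 * need-reply flag and send the message without file descriptors.
 */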
1543 static int
1544 send_vhost_reply(int sockfd, struct VhostUserMsg *msg)
1545 {
1546         if (!msg)
1547                 return 0;
1548
1549         msg->flags &= ~VHOST_USER_VERSION_MASK;
1550         msg->flags &= ~VHOST_USER_NEED_REPLY;
1551         msg->flags |= VHOST_USER_VERSION;
1552         msg->flags |= VHOST_USER_REPLY_MASK;
1553
1554         return send_vhost_message(sockfd, msg, NULL, 0);
1555 }
1556
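/*
 * Send a request on the slave channel. When a reply is requested, the
 * slave request lock is taken here and released by
 * process_slave_message_reply() once the reply has been read.
 */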
1557 static int
1558 send_vhost_slave_message(struct virtio_net *dev, struct VhostUserMsg *msg,
1559                          int *fds, int fd_num)
1560 {
1561         int ret;
1562
1563         if (msg->flags & VHOST_USER_NEED_REPLY)
1564                 rte_spinlock_lock(&dev->slave_req_lock);
1565
1566         ret = send_vhost_message(dev->slave_req_fd, msg, fds, fd_num);
1567         if (ret < 0 && (msg->flags & VHOST_USER_NEED_REPLY))
1568                 rte_spinlock_unlock(&dev->slave_req_lock);
1569
1570         return ret;
1571 }
1572
1573 /*
1574  * Allocate the vring referenced by the message if not already allocated.
1575  */
1576 static int
1577 vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev,
1578                         struct VhostUserMsg *msg)
1579 {
1580         uint16_t vring_idx;
1581
1582         switch (msg->request.master) {
1583         case VHOST_USER_SET_VRING_KICK:
1584         case VHOST_USER_SET_VRING_CALL:
1585         case VHOST_USER_SET_VRING_ERR:
1586                 vring_idx = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1587                 break;
1588         case VHOST_USER_SET_VRING_NUM:
1589         case VHOST_USER_SET_VRING_BASE:
1590         case VHOST_USER_SET_VRING_ENABLE:
1591                 vring_idx = msg->payload.state.index;
1592                 break;
1593         case VHOST_USER_SET_VRING_ADDR:
1594                 vring_idx = msg->payload.addr.index;
1595                 break;
1596         default:
1597                 return 0;
1598         }
1599
1600         if (vring_idx >= VHOST_MAX_VRING) {
1601                 RTE_LOG(ERR, VHOST_CONFIG,
1602                         "invalid vring index: %u\n", vring_idx);
1603                 return -1;
1604         }
1605
1606         if (dev->virtqueue[vring_idx])
1607                 return 0;
1608
1609         return alloc_vring_queue(dev, vring_idx);
1610 }
1611
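/* Take the access lock of every allocated virtqueue of the device. */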
1612 static void
1613 vhost_user_lock_all_queue_pairs(struct virtio_net *dev)
1614 {
1615         unsigned int i = 0;
1616         unsigned int vq_num = 0;
1617
1618         while (vq_num < dev->nr_vring) {
1619                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1620
1621                 if (vq) {
1622                         rte_spinlock_lock(&vq->access_lock);
1623                         vq_num++;
1624                 }
1625                 i++;
1626         }
1627 }
1628
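/* Release the access lock of every allocated virtqueue of the device. */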
1629 static void
1630 vhost_user_unlock_all_queue_pairs(struct virtio_net *dev)
1631 {
1632         unsigned int i = 0;
1633         unsigned int vq_num = 0;
1634
1635         while (vq_num < dev->nr_vring) {
1636                 struct vhost_virtqueue *vq = dev->virtqueue[i];
1637
1638                 if (vq) {
1639                         rte_spinlock_unlock(&vq->access_lock);
1640                         vq_num++;
1641                 }
1642                 i++;
1643         }
1644 }
1645
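/*
 * Read one message from the master on the given socket and dispatch it
 * to the matching handler, sending a reply when one is required.
 * Returns 0 on success, -1 on failure.
 */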
1646 int
1647 vhost_user_msg_handler(int vid, int fd)
1648 {
1649         struct virtio_net *dev;
1650         struct VhostUserMsg msg;
1651         struct rte_vdpa_device *vdpa_dev;
1652         int did = -1;
1653         int ret;
1654         int unlock_required = 0;
1655         uint32_t skip_master = 0;
1656         int request;
1657
1658         dev = get_device(vid);
1659         if (dev == NULL)
1660                 return -1;
1661
1662         if (!dev->notify_ops) {
1663                 dev->notify_ops = vhost_driver_callback_get(dev->ifname);
1664                 if (!dev->notify_ops) {
1665                         RTE_LOG(ERR, VHOST_CONFIG,
1666                                 "failed to get callback ops for driver %s\n",
1667                                 dev->ifname);
1668                         return -1;
1669                 }
1670         }
1671
1672         ret = read_vhost_message(fd, &msg);
1673         if (ret <= 0 || msg.request.master >= VHOST_USER_MAX) {
1674                 if (ret < 0)
1675                         RTE_LOG(ERR, VHOST_CONFIG,
1676                                 "vhost read message failed\n");
1677                 else if (ret == 0)
1678                         RTE_LOG(INFO, VHOST_CONFIG,
1679                                 "vhost peer closed\n");
1680                 else
1681                         RTE_LOG(ERR, VHOST_CONFIG,
1682                                 "vhost read incorrect message\n");
1683
1684                 return -1;
1685         }
1686
1687         ret = 0;
1688         if (msg.request.master != VHOST_USER_IOTLB_MSG)
1689                 RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
1690                         vhost_message_str[msg.request.master]);
1691         else
1692                 RTE_LOG(DEBUG, VHOST_CONFIG, "read message %s\n",
1693                         vhost_message_str[msg.request.master]);
1694
1695         ret = vhost_user_check_and_alloc_queue_pair(dev, &msg);
1696         if (ret < 0) {
1697                 RTE_LOG(ERR, VHOST_CONFIG,
1698                         "failed to alloc queue\n");
1699                 return -1;
1700         }
1701
1702         /*
1703          * Note: we don't lock all queues on VHOST_USER_GET_VRING_BASE
1704          * and VHOST_USER_RESET_OWNER, since they are sent when virtio
1705          * stops and the device is destroyed. destroy_device waits for
1706          * queues to become inactive, so this is safe. Otherwise taking
1707          * the access_lock would cause a deadlock.
1708          */
1709         switch (msg.request.master) {
1710         case VHOST_USER_SET_FEATURES:
1711         case VHOST_USER_SET_PROTOCOL_FEATURES:
1712         case VHOST_USER_SET_OWNER:
1713         case VHOST_USER_SET_MEM_TABLE:
1714         case VHOST_USER_SET_LOG_BASE:
1715         case VHOST_USER_SET_LOG_FD:
1716         case VHOST_USER_SET_VRING_NUM:
1717         case VHOST_USER_SET_VRING_ADDR:
1718         case VHOST_USER_SET_VRING_BASE:
1719         case VHOST_USER_SET_VRING_KICK:
1720         case VHOST_USER_SET_VRING_CALL:
1721         case VHOST_USER_SET_VRING_ERR:
1722         case VHOST_USER_SET_VRING_ENABLE:
1723         case VHOST_USER_SEND_RARP:
1724         case VHOST_USER_NET_SET_MTU:
1725         case VHOST_USER_SET_SLAVE_REQ_FD:
1726                 vhost_user_lock_all_queue_pairs(dev);
1727                 unlock_required = 1;
1728                 break;
1729         default:
1730                 break;
1731
1732         }
1733
1734         if (dev->extern_ops.pre_msg_handle) {
1735                 ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
1736                                 (void *)&msg, &skip_master);
1737                 if (ret == VH_RESULT_ERR)
1738                         goto skip_to_reply;
1739                 else if (ret == VH_RESULT_REPLY)
1740                         send_vhost_reply(fd, &msg);
1741
1742                 if (skip_master)
1743                         goto skip_to_post_handle;
1744         }
1745
1746         request = msg.request.master;
1747         if (request > VHOST_USER_NONE && request < VHOST_USER_MAX) {
1748                 if (!vhost_message_handlers[request])
1749                         goto skip_to_post_handle;
1750                 ret = vhost_message_handlers[request](&dev, &msg);
1751
1752                 switch (ret) {
1753                 case VH_RESULT_ERR:
1754                         RTE_LOG(ERR, VHOST_CONFIG,
1755                                 "Processing %s failed.\n",
1756                                 vhost_message_str[request]);
1757                         break;
1758                 case VH_RESULT_OK:
1759                         RTE_LOG(DEBUG, VHOST_CONFIG,
1760                                 "Processing %s succeeded.\n",
1761                                 vhost_message_str[request]);
1762                         break;
1763                 case VH_RESULT_REPLY:
1764                         RTE_LOG(DEBUG, VHOST_CONFIG,
1765                                 "Processing %s succeeded and needs reply.\n",
1766                                 vhost_message_str[request]);
1767                         send_vhost_reply(fd, &msg);
1768                         break;
1769                 }
1770         } else {
1771                 RTE_LOG(ERR, VHOST_CONFIG,
1772                         "Requested invalid message type %d.\n", request);
1773                 ret = VH_RESULT_ERR;
1774         }
1775
1776 skip_to_post_handle:
1777         if (ret != VH_RESULT_ERR && dev->extern_ops.post_msg_handle) {
1778                 ret = (*dev->extern_ops.post_msg_handle)(
1779                                 dev->vid, (void *)&msg);
1780                 if (ret == VH_RESULT_ERR)
1781                         goto skip_to_reply;
1782                 else if (ret == VH_RESULT_REPLY)
1783                         send_vhost_reply(fd, &msg);
1784         }
1785
1786 skip_to_reply:
1787         if (unlock_required)
1788                 vhost_user_unlock_all_queue_pairs(dev);
1789
1790         /*
1791          * If the request required a reply that was already sent,
1792          * this optional reply-ack won't be sent as the
1793          * VHOST_USER_NEED_REPLY was cleared in send_vhost_reply().
1794          */
1795         if (msg.flags & VHOST_USER_NEED_REPLY) {
1796                 msg.payload.u64 = ret == VH_RESULT_ERR;
1797                 msg.size = sizeof(msg.payload.u64);
1798                 send_vhost_reply(fd, &msg);
1799         } else if (ret == VH_RESULT_ERR) {
1800                 RTE_LOG(ERR, VHOST_CONFIG,
1801                         "vhost message handling failed.\n");
1802                 return -1;
1803         }
1804
1805         if (!(dev->flags & VIRTIO_DEV_RUNNING) && virtio_is_ready(dev)) {
1806                 dev->flags |= VIRTIO_DEV_READY;
1807
1808                 if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
1809                         if (dev->dequeue_zero_copy) {
1810                                 RTE_LOG(INFO, VHOST_CONFIG,
1811                                                 "dequeue zero copy is enabled\n");
1812                         }
1813
1814                         if (dev->notify_ops->new_device(dev->vid) == 0)
1815                                 dev->flags |= VIRTIO_DEV_RUNNING;
1816                 }
1817         }
1818
1819         did = dev->vdpa_dev_id;
1820         vdpa_dev = rte_vdpa_get_device(did);
1821         if (vdpa_dev && virtio_is_ready(dev) &&
1822                         !(dev->flags & VIRTIO_DEV_VDPA_CONFIGURED) &&
1823                         msg.request.master == VHOST_USER_SET_VRING_ENABLE) {
1824                 if (vdpa_dev->ops->dev_conf)
1825                         vdpa_dev->ops->dev_conf(dev->vid);
1826                 dev->flags |= VIRTIO_DEV_VDPA_CONFIGURED;
1827                 if (vhost_user_host_notifier_ctrl(dev->vid, true) != 0) {
1828                         RTE_LOG(INFO, VHOST_CONFIG,
1829                                 "(%d) software relay is used for vDPA, performance may be low.\n",
1830                                 dev->vid);
1831                 }
1832         }
1833
1834         return 0;
1835 }
1836
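/*
 * Wait for the slave's reply when the request was sent with the
 * need-reply flag set, check that the reply matches the request, and
 * release the slave request lock.
 */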
1837 static int process_slave_message_reply(struct virtio_net *dev,
1838                                        const struct VhostUserMsg *msg)
1839 {
1840         struct VhostUserMsg msg_reply;
1841         int ret;
1842
1843         if ((msg->flags & VHOST_USER_NEED_REPLY) == 0)
1844                 return 0;
1845
1846         if (read_vhost_message(dev->slave_req_fd, &msg_reply) <= 0) {
1847                 ret = -1;
1848                 goto out;
1849         }
1850
1851         if (msg_reply.request.slave != msg->request.slave) {
1852                 RTE_LOG(ERR, VHOST_CONFIG,
1853                         "Received unexpected msg type (%u), expected %u\n",
1854                         msg_reply.request.slave, msg->request.slave);
1855                 ret = -1;
1856                 goto out;
1857         }
1858
1859         ret = msg_reply.payload.u64 ? -1 : 0;
1860
1861 out:
1862         rte_spinlock_unlock(&dev->slave_req_lock);
1863         return ret;
1864 }
1865
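/*
 * Send an IOTLB miss message on the slave channel so that the master
 * provides the missing translation.
 */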
1866 int
1867 vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm)
1868 {
1869         int ret;
1870         struct VhostUserMsg msg = {
1871                 .request.slave = VHOST_USER_SLAVE_IOTLB_MSG,
1872                 .flags = VHOST_USER_VERSION,
1873                 .size = sizeof(msg.payload.iotlb),
1874                 .payload.iotlb = {
1875                         .iova = iova,
1876                         .perm = perm,
1877                         .type = VHOST_IOTLB_MISS,
1878                 },
1879         };
1880
1881         ret = send_vhost_message(dev->slave_req_fd, &msg, NULL, 0);
1882         if (ret < 0) {
1883                 RTE_LOG(ERR, VHOST_CONFIG,
1884                                 "Failed to send IOTLB miss message (%d)\n",
1885                                 ret);
1886                 return ret;
1887         }
1888
1889         return 0;
1890 }
1891
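/*
 * Request the master to map (fd >= 0) or unmap (fd < 0) the host notifier
 * area of the given vring, then wait for its reply.
 */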
1892 static int vhost_user_slave_set_vring_host_notifier(struct virtio_net *dev,
1893                                                     int index, int fd,
1894                                                     uint64_t offset,
1895                                                     uint64_t size)
1896 {
1897         int *fdp = NULL;
1898         size_t fd_num = 0;
1899         int ret;
1900         struct VhostUserMsg msg = {
1901                 .request.slave = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1902                 .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY,
1903                 .size = sizeof(msg.payload.area),
1904                 .payload.area = {
1905                         .u64 = index & VHOST_USER_VRING_IDX_MASK,
1906                         .size = size,
1907                         .offset = offset,
1908                 },
1909         };
1910
1911         if (fd < 0)
1912                 msg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1913         else {
1914                 fdp = &fd;
1915                 fd_num = 1;
1916         }
1917
1918         ret = send_vhost_slave_message(dev, &msg, fdp, fd_num);
1919         if (ret < 0) {
1920                 RTE_LOG(ERR, VHOST_CONFIG,
1921                         "Failed to set host notifier (%d)\n", ret);
1922                 return ret;
1923         }
1924
1925         return process_slave_message_reply(dev, &msg);
1926 }
1927
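/*
 * Enable or disable host notifiers for all vrings of a vDPA device,
 * provided the required features and protocol features were negotiated.
 */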
1928 int vhost_user_host_notifier_ctrl(int vid, bool enable)
1929 {
1930         struct virtio_net *dev;
1931         struct rte_vdpa_device *vdpa_dev;
1932         int vfio_device_fd, did, ret = 0;
1933         uint64_t offset, size;
1934         unsigned int i;
1935
1936         dev = get_device(vid);
1937         if (!dev)
1938                 return -ENODEV;
1939
1940         did = dev->vdpa_dev_id;
1941         if (did < 0)
1942                 return -EINVAL;
1943
1944         if (!(dev->features & (1ULL << VIRTIO_F_VERSION_1)) ||
1945             !(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES)) ||
1946             !(dev->protocol_features &
1947                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) ||
1948             !(dev->protocol_features &
1949                         (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) ||
1950             !(dev->protocol_features &
1951                         (1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER)))
1952                 return -ENOTSUP;
1953
1954         vdpa_dev = rte_vdpa_get_device(did);
1955         if (!vdpa_dev)
1956                 return -ENODEV;
1957
1958         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_vfio_device_fd, -ENOTSUP);
1959         RTE_FUNC_PTR_OR_ERR_RET(vdpa_dev->ops->get_notify_area, -ENOTSUP);
1960
1961         vfio_device_fd = vdpa_dev->ops->get_vfio_device_fd(vid);
1962         if (vfio_device_fd < 0)
1963                 return -ENOTSUP;
1964
1965         if (enable) {
1966                 for (i = 0; i < dev->nr_vring; i++) {
1967                         if (vdpa_dev->ops->get_notify_area(vid, i, &offset,
1968                                         &size) < 0) {
1969                                 ret = -ENOTSUP;
1970                                 goto disable;
1971                         }
1972
1973                         if (vhost_user_slave_set_vring_host_notifier(dev, i,
1974                                         vfio_device_fd, offset, size) < 0) {
1975                                 ret = -EFAULT;
1976                                 goto disable;
1977                         }
1978                 }
1979         } else {
1980 disable:
1981                 for (i = 0; i < dev->nr_vring; i++) {
1982                         vhost_user_slave_set_vring_host_notifier(dev, i, -1,
1983                                         0, 0);
1984                 }
1985         }
1986
1987         return ret;
1988 }