vhost: make notify ops per vhost driver
[dpdk.git] / lib / librte_vhost / vhost_user.c
index 640661b..2083640 100644 (file)
@@ -135,7 +135,7 @@ vhost_user_reset_owner(struct virtio_net *dev)
 {
        if (dev->flags & VIRTIO_DEV_RUNNING) {
                dev->flags &= ~VIRTIO_DEV_RUNNING;
-               notify_ops->destroy_device(dev->vid);
+               dev->notify_ops->destroy_device(dev->vid);
        }
 
        cleanup_device(dev, 0);
@@ -147,9 +147,12 @@ vhost_user_reset_owner(struct virtio_net *dev)
  * The features that we support are requested.
  */
 static uint64_t
-vhost_user_get_features(void)
+vhost_user_get_features(struct virtio_net *dev)
 {
-       return VHOST_FEATURES;
+       uint64_t features = 0;
+
+       rte_vhost_driver_get_features(dev->ifname, &features);
+       return features;
 }
 
 /*
@@ -158,7 +161,10 @@ vhost_user_get_features(void)
 static int
 vhost_user_set_features(struct virtio_net *dev, uint64_t features)
 {
-       if (features & ~VHOST_FEATURES)
+       uint64_t vhost_features = 0;
+
+       rte_vhost_driver_get_features(dev->ifname, &vhost_features);
+       if (features & ~vhost_features)
                return -1;
 
        dev->features = features;
@@ -503,7 +509,7 @@ vhost_user_set_mem_table(struct virtio_net *dev, struct VhostUserMsg *pmsg)
        /* Remove from the data plane. */
        if (dev->flags & VIRTIO_DEV_RUNNING) {
                dev->flags &= ~VIRTIO_DEV_RUNNING;
-               notify_ops->destroy_device(dev->vid);
+               dev->notify_ops->destroy_device(dev->vid);
        }
 
        if (dev->mem) {
@@ -678,14 +684,18 @@ vhost_user_set_vring_kick(struct virtio_net *dev, struct VhostUserMsg *pmsg)
                close(vq->kickfd);
        vq->kickfd = file.fd;
 
-       if (virtio_is_ready(dev) && !(dev->flags & VIRTIO_DEV_RUNNING)) {
-               if (dev->dequeue_zero_copy) {
-                       RTE_LOG(INFO, VHOST_CONFIG,
-                               "dequeue zero copy is enabled\n");
-               }
+       if (virtio_is_ready(dev)) {
+               dev->flags |= VIRTIO_DEV_READY;
+
+               if (!(dev->flags & VIRTIO_DEV_RUNNING)) {
+                       if (dev->dequeue_zero_copy) {
+                               RTE_LOG(INFO, VHOST_CONFIG,
+                                               "dequeue zero copy is enabled\n");
+                       }
 
-               if (notify_ops->new_device(dev->vid) == 0)
-                       dev->flags |= VIRTIO_DEV_RUNNING;
+                       if (dev->notify_ops->new_device(dev->vid) == 0)
+                               dev->flags |= VIRTIO_DEV_RUNNING;
+               }
        }
 }
 
@@ -717,9 +727,11 @@ vhost_user_get_vring_base(struct virtio_net *dev,
        /* We have to stop the queue (virtio) if it is running. */
        if (dev->flags & VIRTIO_DEV_RUNNING) {
                dev->flags &= ~VIRTIO_DEV_RUNNING;
-               notify_ops->destroy_device(dev->vid);
+               dev->notify_ops->destroy_device(dev->vid);
        }
 
+       dev->flags &= ~VIRTIO_DEV_READY;
+
        /* Here we are safe to get the last used index */
        state->num = vq->last_used_idx;
 
@@ -757,8 +769,8 @@ vhost_user_set_vring_enable(struct virtio_net *dev,
                "set queue enable: %d to qp idx: %d\n",
                enable, state->index);
 
-       if (notify_ops->vring_state_changed)
-               notify_ops->vring_state_changed(dev->vid, state->index, enable);
+       if (dev->notify_ops->vring_state_changed)
+               dev->notify_ops->vring_state_changed(dev->vid, state->index, enable);
 
        dev->virtqueue[state->index]->enabled = enable;
 
@@ -972,6 +984,16 @@ vhost_user_msg_handler(int vid, int fd)
        if (dev == NULL)
                return -1;
 
+       if (!dev->notify_ops) {
+               dev->notify_ops = vhost_driver_callback_get(dev->ifname);
+               if (!dev->notify_ops) {
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "failed to get callback ops for driver %s\n",
+                               dev->ifname);
+                       return -1;
+               }
+       }
+
        ret = read_vhost_message(fd, &msg);
        if (ret <= 0 || msg.request >= VHOST_USER_MAX) {
                if (ret < 0)
@@ -1000,7 +1022,7 @@ vhost_user_msg_handler(int vid, int fd)
 
        switch (msg.request) {
        case VHOST_USER_GET_FEATURES:
-               msg.payload.u64 = vhost_user_get_features();
+               msg.payload.u64 = vhost_user_get_features(dev);
                msg.size = sizeof(msg.payload.u64);
                send_vhost_message(fd, &msg);
                break;