vhost: fix FD leak with inflight messages
[dpdk.git] / lib / vhost / vhost_user.c
index 8ee9c3e..1d39067 100644 (file)
@@ -1602,6 +1602,9 @@ vhost_user_get_inflight_fd(struct virtio_net **pdev,
        int numa_node = SOCKET_ID_ANY;
        void *addr;
 
+       if (validate_msg_fds(dev, ctx, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (ctx->msg.size != sizeof(ctx->msg.payload.inflight)) {
                VHOST_LOG_CONFIG(ERR, "(%s) invalid get_inflight_fd message size is %d\n",
                        dev->ifname, ctx->msg.size);
@@ -1699,6 +1702,9 @@ vhost_user_set_inflight_fd(struct virtio_net **pdev,
        int fd, i;
        int numa_node = SOCKET_ID_ANY;
 
+       if (validate_msg_fds(dev, ctx, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        fd = ctx->fds[0];
        if (ctx->msg.size != sizeof(ctx->msg.payload.inflight) || fd < 0) {
                VHOST_LOG_CONFIG(ERR, "(%s) invalid set_inflight_fd message size is %d,fd is %d\n",
@@ -2564,8 +2570,11 @@ vhost_user_iotlb_msg(struct virtio_net **pdev,
                        vhost_user_iotlb_cache_insert(dev, vq, imsg->iova, vva,
                                        len, imsg->perm);
 
-                       if (is_vring_iotlb(dev, vq, imsg))
+                       if (is_vring_iotlb(dev, vq, imsg)) {
+                               rte_spinlock_lock(&vq->access_lock);
                                *pdev = dev = translate_ring_addresses(dev, i);
+                               rte_spinlock_unlock(&vq->access_lock);
+                       }
                }
                break;
        case VHOST_IOTLB_INVALIDATE:
@@ -2578,8 +2587,11 @@ vhost_user_iotlb_msg(struct virtio_net **pdev,
                        vhost_user_iotlb_cache_remove(vq, imsg->iova,
                                        imsg->size);
 
-                       if (is_vring_iotlb(dev, vq, imsg))
+                       if (is_vring_iotlb(dev, vq, imsg)) {
+                               rte_spinlock_lock(&vq->access_lock);
                                vring_invalidate(dev, vq);
+                               rte_spinlock_unlock(&vq->access_lock);
+                       }
                }
                break;
        default:
@@ -2877,6 +2889,9 @@ vhost_user_check_and_alloc_queue_pair(struct virtio_net *dev,
        case VHOST_USER_SET_VRING_ADDR:
                vring_idx = ctx->msg.payload.addr.index;
                break;
+       case VHOST_USER_SET_INFLIGHT_FD:
+               vring_idx = ctx->msg.payload.inflight.num_queues - 1;
+               break;
        default:
                return 0;
        }
@@ -3017,8 +3032,8 @@ vhost_user_msg_handler(int vid, int fd)
 
        handled = false;
        if (dev->extern_ops.pre_msg_handle) {
-               ret = (*dev->extern_ops.pre_msg_handle)(dev->vid,
-                               (void *)&ctx.msg);
+               RTE_BUILD_BUG_ON(offsetof(struct vhu_msg_context, msg) != 0);
+               ret = (*dev->extern_ops.pre_msg_handle)(dev->vid, &ctx);
                switch (ret) {
                case RTE_VHOST_MSG_RESULT_REPLY:
                        send_vhost_reply(dev, fd, &ctx);
@@ -3063,8 +3078,8 @@ vhost_user_msg_handler(int vid, int fd)
 skip_to_post_handle:
        if (ret != RTE_VHOST_MSG_RESULT_ERR &&
                        dev->extern_ops.post_msg_handle) {
-               ret = (*dev->extern_ops.post_msg_handle)(dev->vid,
-                               (void *)&ctx.msg);
+               RTE_BUILD_BUG_ON(offsetof(struct vhu_msg_context, msg) != 0);
+               ret = (*dev->extern_ops.post_msg_handle)(dev->vid, &ctx);
                switch (ret) {
                case RTE_VHOST_MSG_RESULT_REPLY:
                        send_vhost_reply(dev, fd, &ctx);