vhost: fix possible denial of service by leaking FDs
[dpdk.git] / lib / librte_vhost / vhost_user.c
index 2a9fa7c..90ecee1 100644 (file)
@@ -92,6 +92,36 @@ static const char *vhost_message_str[VHOST_USER_MAX] = {
 static int send_vhost_reply(int sockfd, struct VhostUserMsg *msg);
 static int read_vhost_message(int sockfd, struct VhostUserMsg *msg);
 
+/*
+ * Close the msg->fd_num file descriptors stored in msg->fds[],
+ * walking the array from index 0 upwards.
+ */
+static void
+close_msg_fds(struct VhostUserMsg *msg)
+{
+       int idx = 0;
+
+       while (idx < msg->fd_num) {
+               close(msg->fds[idx]);
+               idx++;
+       }
+}
+
+/*
+ * Ensure the expected number of FDs is received,
+ * close all FDs and return an error if this is not the case.
+ *
+ * @msg: message whose fd_num is checked; on mismatch every FD in
+ *       msg->fds[] is closed through close_msg_fds().
+ * @expected_fds: number of FDs the request is supposed to carry.
+ *
+ * Returns 0 when msg->fd_num equals expected_fds, -1 otherwise.
+ */
+static int
+validate_msg_fds(struct VhostUserMsg *msg, int expected_fds)
+{
+       const char *req_name = NULL;
+       unsigned int req;
+
+       if (msg->fd_num == expected_fds)
+               return 0;
+
+       /*
+        * The request ID may come straight from the peer, and a caller can
+        * validate the FD count before it has checked that ID. Never index
+        * the name table with an out-of-range value, and never hand a NULL
+        * table entry to "%s".
+        */
+       req = (unsigned int)msg->request.master;
+       if (req < VHOST_USER_MAX)
+               req_name = vhost_message_str[req];
+       if (req_name == NULL)
+               req_name = "unknown";
+
+       RTE_LOG(ERR, VHOST_CONFIG,
+               " Expect %d FDs for request %s, received %d\n",
+               expected_fds,
+               req_name,
+               msg->fd_num);
+
+       close_msg_fds(msg);
+
+       return -1;
+}
+
 static uint64_t
 get_blk_size(int fd)
 {
@@ -204,18 +234,25 @@ vhost_backend_cleanup(struct virtio_net *dev)
  */
 static int
 vhost_user_set_owner(struct virtio_net **pdev __rte_unused,
-                       struct VhostUserMsg *msg __rte_unused,
+                       struct VhostUserMsg *msg,
                        int main_fd __rte_unused)
 {
+       /* No FD is expected here; on mismatch the received FDs are closed. */
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        return RTE_VHOST_MSG_RESULT_OK;
 }
 
 static int
 vhost_user_reset_owner(struct virtio_net **pdev,
-                       struct VhostUserMsg *msg __rte_unused,
+                       struct VhostUserMsg *msg,
                        int main_fd __rte_unused)
 {
        struct virtio_net *dev = *pdev;
+
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        vhost_destroy_device_notify(dev);
 
        cleanup_device(dev, 0);
@@ -233,6 +270,9 @@ vhost_user_get_features(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct virtio_net *dev = *pdev;
        uint64_t features = 0;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        rte_vhost_driver_get_features(dev->ifname, &features);
 
        msg->payload.u64 = features;
@@ -252,6 +292,9 @@ vhost_user_get_queue_num(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct virtio_net *dev = *pdev;
        uint32_t queue_num = 0;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        rte_vhost_driver_get_queue_num(dev->ifname, &queue_num);
 
        msg->payload.u64 = (uint64_t)queue_num;
@@ -274,6 +317,9 @@ vhost_user_set_features(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct rte_vdpa_device *vdpa_dev;
        int did = -1;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        rte_vhost_driver_get_features(dev->ifname, &vhost_features);
        if (features & ~vhost_features) {
                RTE_LOG(ERR, VHOST_CONFIG,
@@ -357,14 +403,29 @@ vhost_user_set_vring_num(struct virtio_net **pdev,
        struct virtio_net *dev = *pdev;
        struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        vq->size = msg->payload.state.num;
 
        /* VIRTIO 1.0, 2.4 Virtqueues says:
         *
         *   Queue Size value is always a power of 2. The maximum Queue Size
         *   value is 32768.
+        *
+        * VIRTIO 1.1 2.7 Virtqueues says:
+        *
+        *   Packed virtqueues support up to 2^15 entries each.
         */
-       if ((vq->size & (vq->size - 1)) || vq->size > 32768) {
+       if (!vq_is_packed(dev)) {
+               if (vq->size & (vq->size - 1)) {
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "invalid virtqueue size %u\n", vq->size);
+                       return RTE_VHOST_MSG_RESULT_ERR;
+               }
+       }
+
+       if (vq->size > 32768) {
                RTE_LOG(ERR, VHOST_CONFIG,
                        "invalid virtqueue size %u\n", vq->size);
                return RTE_VHOST_MSG_RESULT_ERR;
@@ -374,6 +435,8 @@ vhost_user_set_vring_num(struct virtio_net **pdev,
                vq->nr_zmbuf = 0;
                vq->last_zmbuf_idx = 0;
                vq->zmbuf_size = vq->size;
+               if (vq->zmbufs)
+                       rte_free(vq->zmbufs);
                vq->zmbufs = rte_zmalloc(NULL, vq->zmbuf_size *
                                         sizeof(struct zcopy_mbuf), 0);
                if (vq->zmbufs == NULL) {
@@ -386,6 +449,8 @@ vhost_user_set_vring_num(struct virtio_net **pdev,
        }
 
        if (vq_is_packed(dev)) {
+               if (vq->shadow_used_packed)
+                       rte_free(vq->shadow_used_packed);
                vq->shadow_used_packed = rte_malloc(NULL,
                                vq->size *
                                sizeof(struct vring_used_elem_packed),
@@ -397,6 +462,8 @@ vhost_user_set_vring_num(struct virtio_net **pdev,
                }
 
        } else {
+               if (vq->shadow_used_split)
+                       rte_free(vq->shadow_used_split);
                vq->shadow_used_split = rte_malloc(NULL,
                                vq->size * sizeof(struct vring_used_elem),
                                RTE_CACHE_LINE_SIZE);
@@ -407,6 +474,8 @@ vhost_user_set_vring_num(struct virtio_net **pdev,
                }
        }
 
+       if (vq->batch_copy_elems)
+               rte_free(vq->batch_copy_elems);
        vq->batch_copy_elems = rte_malloc(NULL,
                                vq->size * sizeof(struct batch_copy_elem),
                                RTE_CACHE_LINE_SIZE);
@@ -641,11 +710,21 @@ translate_ring_addresses(struct virtio_net *dev, int vq_index)
        struct vhost_vring_addr *addr = &vq->ring_addrs;
        uint64_t len, expected_len;
 
+       if (addr->flags & (1 << VHOST_VRING_F_LOG)) {
+               vq->log_guest_addr =
+                       translate_log_addr(dev, vq, addr->log_guest_addr);
+               if (vq->log_guest_addr == 0) {
+                       RTE_LOG(DEBUG, VHOST_CONFIG,
+                               "(%d) failed to map log_guest_addr.\n",
+                               dev->vid);
+                       return dev;
+               }
+       }
+
        if (vq_is_packed(dev)) {
                len = sizeof(struct vring_packed_desc) * vq->size;
                vq->desc_packed = (struct vring_packed_desc *)(uintptr_t)
                        ring_addr_to_vva(dev, vq, addr->desc_user_addr, &len);
-               vq->log_guest_addr = 0;
                if (vq->desc_packed == NULL ||
                                len != sizeof(struct vring_packed_desc) *
                                vq->size) {
@@ -741,14 +820,6 @@ translate_ring_addresses(struct virtio_net *dev, int vq_index)
                vq->last_avail_idx = vq->used->idx;
        }
 
-       vq->log_guest_addr =
-               translate_log_addr(dev, vq, addr->log_guest_addr);
-       if (vq->log_guest_addr == 0) {
-               RTE_LOG(DEBUG, VHOST_CONFIG,
-                       "(%d) failed to map log_guest_addr .\n",
-                       dev->vid);
-               return dev;
-       }
        vq->access_ok = 1;
 
        VHOST_LOG_DEBUG(VHOST_CONFIG, "(%d) mapped address desc: %p\n",
@@ -776,6 +847,9 @@ vhost_user_set_vring_addr(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct vhost_vring_addr *addr = &msg->payload.addr;
        bool access_ok;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (dev->mem == NULL)
                return RTE_VHOST_MSG_RESULT_ERR;
 
@@ -817,6 +891,9 @@ vhost_user_set_vring_base(struct virtio_net **pdev,
        struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
        uint64_t val = msg->payload.state.num;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (vq_is_packed(dev)) {
                /*
                 * Bit[0:14]: avail index
@@ -978,6 +1055,9 @@ vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *msg,
        int populate;
        int fd;
 
+       if (validate_msg_fds(msg, memory->nregions) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (memory->nregions > VHOST_MEMORY_MAX_NREGIONS) {
                RTE_LOG(ERR, VHOST_CONFIG,
                        "too many memory regions (%u)\n", memory->nregions);
@@ -988,8 +1068,7 @@ vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *msg,
                RTE_LOG(INFO, VHOST_CONFIG,
                        "(%d) memory regions not changed\n", dev->vid);
 
-               for (i = 0; i < memory->nregions; i++)
-                       close(msg->fds[i]);
+               close_msg_fds(msg);
 
                return RTE_VHOST_MSG_RESULT_OK;
        }
@@ -1132,6 +1211,10 @@ vhost_user_set_mem_table(struct virtio_net **pdev, struct VhostUserMsg *msg,
                                "Failed to read qemu ack on postcopy set-mem-table\n");
                        goto err_mmap;
                }
+
+               if (validate_msg_fds(&ack_msg, 0) != 0)
+                       goto err_mmap;
+
                if (ack_msg.request.master != VHOST_USER_SET_MEM_TABLE) {
                        RTE_LOG(ERR, VHOST_CONFIG,
                                "Bad qemu ack on postcopy set-mem-table (%d)\n",
@@ -1481,6 +1564,9 @@ vhost_user_set_vring_call(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct vhost_vring_file file;
        struct vhost_virtqueue *vq;
 
+       if (validate_msg_fds(msg, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
        if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
                file.fd = VIRTIO_INVALID_EVENTFD;
@@ -1502,6 +1588,9 @@ static int vhost_user_set_vring_err(struct virtio_net **pdev __rte_unused,
                        struct VhostUserMsg *msg,
                        int main_fd __rte_unused)
 {
+       if (validate_msg_fds(msg, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (!(msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK))
                close(msg->fds[0]);
        RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
@@ -1702,6 +1791,9 @@ vhost_user_set_vring_kick(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct vhost_vring_file file;
        struct vhost_virtqueue *vq;
 
+       if (validate_msg_fds(msg, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        file.index = msg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
        if (msg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)
                file.fd = VIRTIO_INVALID_EVENTFD;
@@ -1772,6 +1864,9 @@ vhost_user_get_vring_base(struct virtio_net **pdev,
        struct vhost_virtqueue *vq = dev->virtqueue[msg->payload.state.index];
        uint64_t val;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        /* We have to stop the queue (virtio) if it is running. */
        vhost_destroy_device_notify(dev);
 
@@ -1847,6 +1942,9 @@ vhost_user_set_vring_enable(struct virtio_net **pdev,
        struct rte_vdpa_device *vdpa_dev;
        int did = -1;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        RTE_LOG(INFO, VHOST_CONFIG,
                "set queue enable: %d to qp idx: %d\n",
                enable, index);
@@ -1877,6 +1975,9 @@ vhost_user_get_protocol_features(struct virtio_net **pdev,
        struct virtio_net *dev = *pdev;
        uint64_t features, protocol_features;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        rte_vhost_driver_get_features(dev->ifname, &features);
        rte_vhost_driver_get_protocol_features(dev->ifname, &protocol_features);
 
@@ -1905,6 +2006,9 @@ vhost_user_set_protocol_features(struct virtio_net **pdev,
        uint64_t protocol_features = msg->payload.u64;
        uint64_t slave_protocol_features = 0;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        rte_vhost_driver_get_protocol_features(dev->ifname,
                        &slave_protocol_features);
        if (protocol_features & ~slave_protocol_features) {
@@ -1931,6 +2035,9 @@ vhost_user_set_log_base(struct virtio_net **pdev, struct VhostUserMsg *msg,
        uint64_t size, off;
        void *addr;
 
+       if (validate_msg_fds(msg, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (fd < 0) {
                RTE_LOG(ERR, VHOST_CONFIG, "invalid log fd: %d\n", fd);
                return RTE_VHOST_MSG_RESULT_ERR;
@@ -1994,6 +2101,9 @@ static int vhost_user_set_log_fd(struct virtio_net **pdev __rte_unused,
                        struct VhostUserMsg *msg,
                        int main_fd __rte_unused)
 {
+       if (validate_msg_fds(msg, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        close(msg->fds[0]);
        RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
 
@@ -2017,6 +2127,9 @@ vhost_user_send_rarp(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct rte_vdpa_device *vdpa_dev;
        int did = -1;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        RTE_LOG(DEBUG, VHOST_CONFIG,
                ":: mac: %02x:%02x:%02x:%02x:%02x:%02x\n",
                mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
@@ -2044,6 +2157,10 @@ vhost_user_net_set_mtu(struct virtio_net **pdev, struct VhostUserMsg *msg,
                        int main_fd __rte_unused)
 {
        struct virtio_net *dev = *pdev;
+
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (msg->payload.u64 < VIRTIO_MIN_MTU ||
                        msg->payload.u64 > VIRTIO_MAX_MTU) {
                RTE_LOG(ERR, VHOST_CONFIG, "Invalid MTU size (%"PRIu64")\n",
@@ -2064,6 +2181,9 @@ vhost_user_set_req_fd(struct virtio_net **pdev, struct VhostUserMsg *msg,
        struct virtio_net *dev = *pdev;
        int fd = msg->fds[0];
 
+       if (validate_msg_fds(msg, 1) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (fd < 0) {
                RTE_LOG(ERR, VHOST_CONFIG,
                                "Invalid file descriptor for slave channel (%d)\n",
@@ -2149,6 +2269,9 @@ vhost_user_iotlb_msg(struct virtio_net **pdev, struct VhostUserMsg *msg,
        uint16_t i;
        uint64_t vva, len;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        switch (imsg->type) {
        case VHOST_IOTLB_UPDATE:
                len = imsg->size;
@@ -2195,6 +2318,9 @@ vhost_user_set_postcopy_advise(struct virtio_net **pdev,
 #ifdef RTE_LIBRTE_VHOST_POSTCOPY
        struct uffdio_api api_struct;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
 
        if (dev->postcopy_ufd == -1) {
@@ -2230,6 +2356,9 @@ vhost_user_set_postcopy_listen(struct virtio_net **pdev,
 {
        struct virtio_net *dev = *pdev;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        if (dev->mem && dev->mem->nregions) {
                RTE_LOG(ERR, VHOST_CONFIG,
                        "Regions already registered at postcopy-listen\n");
@@ -2246,6 +2375,9 @@ vhost_user_postcopy_end(struct virtio_net **pdev, struct VhostUserMsg *msg,
 {
        struct virtio_net *dev = *pdev;
 
+       if (validate_msg_fds(msg, 0) != 0)
+               return RTE_VHOST_MSG_RESULT_ERR;
+
        dev->postcopy_listening = 0;
        if (dev->postcopy_ufd >= 0) {
                close(dev->postcopy_ufd);
@@ -2599,6 +2731,7 @@ skip_to_post_handle:
        if (!handled) {
                RTE_LOG(ERR, VHOST_CONFIG,
                        "vhost message (req: %d) was not handled.\n", request);
+               close_msg_fds(&msg);
                ret = RTE_VHOST_MSG_RESULT_ERR;
        }