vhost/crypto: handle more invalid parameter cases
diff --git a/lib/librte_vhost/vhost_crypto.c b/lib/librte_vhost/vhost_crypto.c
index dd01afc..0edf12d 100644
--- a/lib/librte_vhost/vhost_crypto.c
+++ b/lib/librte_vhost/vhost_crypto.c
@@ -198,6 +198,7 @@ struct vhost_crypto {
        struct rte_hash *session_map;
        struct rte_mempool *mbuf_pool;
        struct rte_mempool *sess_pool;
+       struct rte_mempool *sess_priv_pool;
        struct rte_mempool *wb_pool;
 
        /** DPDK cryptodev ID */
@@ -369,7 +370,7 @@ vhost_crypto_create_sess(struct vhost_crypto *vcrypto,
        }
 
        if (rte_cryptodev_sym_session_init(vcrypto->cid, session, &xform1,
-                       vcrypto->sess_pool) < 0) {
+                       vcrypto->sess_priv_pool) < 0) {
                VC_LOG_ERR("Failed to initialize session");
                sess_param->session_id = -VIRTIO_CRYPTO_ERR;
                return;
@@ -433,45 +434,56 @@ vhost_crypto_close_sess(struct vhost_crypto *vcrypto, uint64_t session_id)
        return 0;
 }
 
-static enum vh_result
+static enum rte_vhost_msg_result
 vhost_crypto_msg_post_handler(int vid, void *msg)
 {
        struct virtio_net *dev = get_device(vid);
        struct vhost_crypto *vcrypto;
        VhostUserMsg *vmsg = msg;
-       enum vh_result ret = VH_RESULT_OK;
+       enum rte_vhost_msg_result ret = RTE_VHOST_MSG_RESULT_OK;
 
        if (dev == NULL) {
                VC_LOG_ERR("Invalid vid %i", vid);
-               return VH_RESULT_ERR;
+               return RTE_VHOST_MSG_RESULT_ERR;
        }
 
        vcrypto = dev->extern_data;
        if (vcrypto == NULL) {
                VC_LOG_ERR("Cannot find required data, is it initialized?");
-               return VH_RESULT_ERR;
+               return RTE_VHOST_MSG_RESULT_ERR;
        }
 
-       if (vmsg->request.master == VHOST_USER_CRYPTO_CREATE_SESS) {
+       switch (vmsg->request.master) {
+       case VHOST_USER_CRYPTO_CREATE_SESS:
                vhost_crypto_create_sess(vcrypto,
                                &vmsg->payload.crypto_session);
                vmsg->fd_num = 0;
-               ret = VH_RESULT_REPLY;
-       } else if (vmsg->request.master == VHOST_USER_CRYPTO_CLOSE_SESS) {
+               ret = RTE_VHOST_MSG_RESULT_REPLY;
+               break;
+       case VHOST_USER_CRYPTO_CLOSE_SESS:
                if (vhost_crypto_close_sess(vcrypto, vmsg->payload.u64))
-                       ret = VH_RESULT_ERR;
+                       ret = RTE_VHOST_MSG_RESULT_ERR;
+               break;
+       default:
+               ret = RTE_VHOST_MSG_RESULT_NOT_HANDLED;
+               break;
        }
 
        return ret;
 }
 
 static __rte_always_inline struct vring_desc *
-find_write_desc(struct vring_desc *head, struct vring_desc *desc)
+find_write_desc(struct vring_desc *head, struct vring_desc *desc,
+               uint32_t *nb_descs, uint32_t vq_size)
 {
        if (desc->flags & VRING_DESC_F_WRITE)
                return desc;
 
        while (desc->flags & VRING_DESC_F_NEXT) {
+               if (unlikely(*nb_descs == 0 || desc->next >= vq_size))
+                       return NULL;
+               (*nb_descs)--;
+
                desc = &head[desc->next];
                if (desc->flags & VRING_DESC_F_WRITE)
                        return desc;
@@ -481,13 +493,18 @@ find_write_desc(struct vring_desc *head, struct vring_desc *desc)
 }
 
 static struct virtio_crypto_inhdr *
-reach_inhdr(struct vhost_crypto_data_req *vc_req, struct vring_desc *desc)
+reach_inhdr(struct vhost_crypto_data_req *vc_req, struct vring_desc *desc,
+               uint32_t *nb_descs, uint32_t vq_size)
 {
        uint64_t dlen;
        struct virtio_crypto_inhdr *inhdr;
 
-       while (desc->flags & VRING_DESC_F_NEXT)
+       while (desc->flags & VRING_DESC_F_NEXT) {
+               if (unlikely(*nb_descs == 0 || desc->next >= vq_size))
+                       return NULL;
+               (*nb_descs)--;
                desc = &vc_req->head[desc->next];
+       }
 
        dlen = desc->len;
        inhdr = IOVA_TO_VVA(struct virtio_crypto_inhdr *, vc_req, desc->addr,
@@ -500,15 +517,16 @@ reach_inhdr(struct vhost_crypto_data_req *vc_req, struct vring_desc *desc)
 
 static __rte_always_inline int
 move_desc(struct vring_desc *head, struct vring_desc **cur_desc,
-               uint32_t size)
+               uint32_t size, uint32_t *nb_descs, uint32_t vq_size)
 {
        struct vring_desc *desc = *cur_desc;
-       int left = size;
-
-       rte_prefetch0(&head[desc->next]);
-       left -= desc->len;
+       int left = size - desc->len;
 
        while ((desc->flags & VRING_DESC_F_NEXT) && left > 0) {
+               if (unlikely(*nb_descs == 0 || desc->next >= vq_size))
+                       return -1;
+               (*nb_descs)--;
+
                desc = &head[desc->next];
                rte_prefetch0(&head[desc->next]);
                left -= desc->len;
@@ -517,7 +535,14 @@ move_desc(struct vring_desc *head, struct vring_desc **cur_desc,
        if (unlikely(left > 0))
                return -1;
 
-       *cur_desc = &head[desc->next];
+       if (unlikely(*nb_descs == 0))
+               *cur_desc = NULL;
+       else {
+               if (unlikely(desc->next >= vq_size))
+                       return -1;
+               *cur_desc = &head[desc->next];
+       }
+
        return 0;
 }
 
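Note the new contract that move_desc() (and copy_data() below) establishes: when the budget is exactly used up, *cur_desc is set to NULL rather than left pointing one entry past the chain. A hypothetical caller therefore has to check before dereferencing:

    if (move_desc(head, &desc, size, &nb_descs, vq_size) < 0)
            return VIRTIO_CRYPTO_ERR;
    if (desc == NULL) {
            /* chain fully consumed; nothing may follow */
    }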
@@ -539,7 +564,8 @@ get_data_ptr(struct vhost_crypto_data_req *vc_req, struct vring_desc *cur_desc,
 
 static int
 copy_data(void *dst_data, struct vhost_crypto_data_req *vc_req,
-               struct vring_desc **cur_desc, uint32_t size)
+               struct vring_desc **cur_desc, uint32_t size,
+               uint32_t *nb_descs, uint32_t vq_size)
 {
        struct vring_desc *desc = *cur_desc;
        uint64_t remain, addr, dlen, len;
@@ -548,7 +574,6 @@ copy_data(void *dst_data, struct vhost_crypto_data_req *vc_req,
        uint8_t *src;
        int left = size;
 
-       rte_prefetch0(&vc_req->head[desc->next]);
        to_copy = RTE_MIN(desc->len, (uint32_t)left);
        dlen = to_copy;
        src = IOVA_TO_VVA(uint8_t *, vc_req, desc->addr, &dlen,
@@ -582,6 +607,12 @@ copy_data(void *dst_data, struct vhost_crypto_data_req *vc_req,
        left -= to_copy;
 
        while ((desc->flags & VRING_DESC_F_NEXT) && left > 0) {
+               if (unlikely(*nb_descs == 0 || desc->next >= vq_size)) {
+                       VC_LOG_ERR("Invalid descriptors");
+                       return -1;
+               }
+               (*nb_descs)--;
+
                desc = &vc_req->head[desc->next];
                rte_prefetch0(&vc_req->head[desc->next]);
                to_copy = RTE_MIN(desc->len, (uint32_t)left);
@@ -624,7 +655,13 @@ copy_data(void *dst_data, struct vhost_crypto_data_req *vc_req,
                return -1;
        }
 
-       *cur_desc = &vc_req->head[desc->next];
+       if (unlikely(*nb_descs == 0))
+               *cur_desc = NULL;
+       else {
+               if (unlikely(desc->next >= vq_size))
+                       return -1;
+               *cur_desc = &vc_req->head[desc->next];
+       }
 
        return 0;
 }
@@ -635,7 +672,6 @@ write_back_data(struct vhost_crypto_data_req *vc_req)
        struct vhost_crypto_writeback_data *wb_data = vc_req->wb, *wb_last;
 
        while (wb_data) {
-               rte_prefetch0(wb_data->next);
                rte_memcpy(wb_data->dst, wb_data->src, wb_data->len);
                wb_last = wb_data;
                wb_data = wb_data->next;
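For reference, the list walked here is the singly linked write-back chain defined earlier in this file (layout recalled from the surrounding source, not shown by this diff); removing the prefetch only affects performance, not correctness, since the loop reads wb_data->next on every iteration anyway:

    struct vhost_crypto_writeback_data {
            uint8_t *src;
            uint8_t *dst;
            uint64_t len;
            struct vhost_crypto_writeback_data *next;
    };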
@@ -684,7 +720,8 @@ prepare_write_back_data(struct vhost_crypto_data_req *vc_req,
                struct vhost_crypto_writeback_data **end_wb_data,
                uint8_t *src,
                uint32_t offset,
-               uint64_t write_back_len)
+               uint64_t write_back_len,
+               uint32_t *nb_descs, uint32_t vq_size)
 {
        struct vhost_crypto_writeback_data *wb_data, *head;
        struct vring_desc *desc = *cur_desc;
@@ -731,6 +768,12 @@ prepare_write_back_data(struct vhost_crypto_data_req *vc_req,
                offset -= desc->len;
 
        while (write_back_len) {
+               if (unlikely(*nb_descs == 0 || desc->next >= vq_size)) {
+                       VC_LOG_ERR("Invalid descriptors");
+                       goto error_exit;
+               }
+               (*nb_descs)--;
+
                desc = &vc_req->head[desc->next];
                if (unlikely(!(desc->flags & VRING_DESC_F_WRITE))) {
                        VC_LOG_ERR("incorrect descriptor");
@@ -770,7 +813,13 @@ prepare_write_back_data(struct vhost_crypto_data_req *vc_req,
                        wb_data->next = NULL;
        }
 
-       *cur_desc = &vc_req->head[desc->next];
+       if (unlikely(*nb_descs == 0))
+               *cur_desc = NULL;
+       else {
+               if (unlikely(desc->next >= vq_size))
+                       goto error_exit;
+               *cur_desc = &vc_req->head[desc->next];
+       }
 
        *end_wb_data = wb_data;
 
@@ -787,7 +836,8 @@ static uint8_t
 prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                struct vhost_crypto_data_req *vc_req,
                struct virtio_crypto_cipher_data_req *cipher,
-               struct vring_desc *cur_desc)
+               struct vring_desc *cur_desc,
+               uint32_t *nb_descs, uint32_t vq_size)
 {
        struct vring_desc *desc = cur_desc;
        struct vhost_crypto_writeback_data *ewb = NULL;
@@ -797,8 +847,8 @@ prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
 
        /* prepare */
        /* iv */
-       if (unlikely(copy_data(iv_data, vc_req, &desc,
-                       cipher->para.iv_len) < 0)) {
+       if (unlikely(copy_data(iv_data, vc_req, &desc, cipher->para.iv_len,
+                       nb_descs, vq_size) < 0)) {
                ret = VIRTIO_CRYPTO_BADMSG;
                goto error_exit;
        }
@@ -818,7 +868,8 @@ prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                }
 
                if (unlikely(move_desc(vc_req->head, &desc,
-                               cipher->para.src_data_len) < 0)) {
+                               cipher->para.src_data_len, nb_descs,
+                               vq_size) < 0)) {
                        VC_LOG_ERR("Incorrect descriptor");
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -835,8 +886,8 @@ prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                        goto error_exit;
                }
                if (unlikely(copy_data(rte_pktmbuf_mtod(m_src, uint8_t *),
-                               vc_req, &desc, cipher->para.src_data_len)
-                               < 0)) {
+                               vc_req, &desc, cipher->para.src_data_len,
+                               nb_descs, vq_size) < 0)) {
                        ret = VIRTIO_CRYPTO_BADMSG;
                        goto error_exit;
                }
@@ -847,7 +898,7 @@ prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
        }
 
        /* dst */
-       desc = find_write_desc(vc_req->head, desc);
+       desc = find_write_desc(vc_req->head, desc, nb_descs, vq_size);
        if (unlikely(!desc)) {
                VC_LOG_ERR("Cannot find write location");
                ret = VIRTIO_CRYPTO_BADMSG;
@@ -866,7 +917,8 @@ prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                }
 
                if (unlikely(move_desc(vc_req->head, &desc,
-                               cipher->para.dst_data_len) < 0)) {
+                               cipher->para.dst_data_len,
+                               nb_descs, vq_size) < 0)) {
                        VC_LOG_ERR("Incorrect descriptor");
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -877,7 +929,7 @@ prepare_sym_cipher_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
        case RTE_VHOST_CRYPTO_ZERO_COPY_DISABLE:
                vc_req->wb = prepare_write_back_data(vc_req, &desc, &ewb,
                                rte_pktmbuf_mtod(m_src, uint8_t *), 0,
-                               cipher->para.dst_data_len);
+                               cipher->para.dst_data_len, nb_descs, vq_size);
                if (unlikely(vc_req->wb == NULL)) {
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -919,7 +971,8 @@ static uint8_t
 prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                struct vhost_crypto_data_req *vc_req,
                struct virtio_crypto_alg_chain_data_req *chain,
-               struct vring_desc *cur_desc)
+               struct vring_desc *cur_desc,
+               uint32_t *nb_descs, uint32_t vq_size)
 {
        struct vring_desc *desc = cur_desc, *digest_desc;
        struct vhost_crypto_writeback_data *ewb = NULL, *ewb2 = NULL;
@@ -932,7 +985,7 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
        /* prepare */
        /* iv */
        if (unlikely(copy_data(iv_data, vc_req, &desc,
-                       chain->para.iv_len) < 0)) {
+                       chain->para.iv_len, nb_descs, vq_size) < 0)) {
                ret = VIRTIO_CRYPTO_BADMSG;
                goto error_exit;
        }
@@ -953,7 +1006,8 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                }
 
                if (unlikely(move_desc(vc_req->head, &desc,
-                               chain->para.src_data_len) < 0)) {
+                               chain->para.src_data_len,
+                               nb_descs, vq_size) < 0)) {
                        VC_LOG_ERR("Incorrect descriptor");
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -969,7 +1023,8 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                        goto error_exit;
                }
                if (unlikely(copy_data(rte_pktmbuf_mtod(m_src, uint8_t *),
-                               vc_req, &desc, chain->para.src_data_len)) < 0) {
+                               vc_req, &desc, chain->para.src_data_len,
+                               nb_descs, vq_size) < 0)) {
                        ret = VIRTIO_CRYPTO_BADMSG;
                        goto error_exit;
                }
@@ -981,7 +1036,7 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
        }
 
        /* dst */
-       desc = find_write_desc(vc_req->head, desc);
+       desc = find_write_desc(vc_req->head, desc, nb_descs, vq_size);
        if (unlikely(!desc)) {
                VC_LOG_ERR("Cannot find write location");
                ret = VIRTIO_CRYPTO_BADMSG;
@@ -1000,7 +1055,8 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                }
 
                if (unlikely(move_desc(vc_req->head, &desc,
-                               chain->para.dst_data_len) < 0)) {
+                               chain->para.dst_data_len,
+                               nb_descs, vq_size) < 0)) {
                        VC_LOG_ERR("Incorrect descriptor");
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -1017,7 +1073,8 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                }
 
                if (unlikely(move_desc(vc_req->head, &desc,
-                               chain->para.hash_result_len) < 0)) {
+                               chain->para.hash_result_len,
+                               nb_descs, vq_size) < 0)) {
                        VC_LOG_ERR("Incorrect descriptor");
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -1029,7 +1086,8 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
                                rte_pktmbuf_mtod(m_src, uint8_t *),
                                chain->para.cipher_start_src_offset,
                                chain->para.dst_data_len -
-                               chain->para.cipher_start_src_offset);
+                               chain->para.cipher_start_src_offset,
+                               nb_descs, vq_size);
                if (unlikely(vc_req->wb == NULL)) {
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
@@ -1042,14 +1100,16 @@ prepare_sym_chain_op(struct vhost_crypto *vcrypto, struct rte_crypto_op *op,
 
                /** create a wb_data for digest */
                ewb->next = prepare_write_back_data(vc_req, &desc, &ewb2,
-                               digest_addr, 0, chain->para.hash_result_len);
+                               digest_addr, 0, chain->para.hash_result_len,
+                               nb_descs, vq_size);
                if (unlikely(ewb->next == NULL)) {
                        ret = VIRTIO_CRYPTO_ERR;
                        goto error_exit;
                }
 
                if (unlikely(copy_data(digest_addr, vc_req, &digest_desc,
-                               chain->para.hash_result_len)) < 0) {
+                               chain->para.hash_result_len,
+                               nb_descs, vq_size) < 0)) {
                        ret = VIRTIO_CRYPTO_BADMSG;
                        goto error_exit;
                }
@@ -1108,6 +1168,7 @@ vhost_crypto_process_one_req(struct vhost_crypto *vcrypto,
        struct vring_desc *desc = NULL;
        uint64_t session_id;
        uint64_t dlen;
+       uint32_t nb_descs = vq->size;
        int err = 0;
 
        vc_req->desc_idx = desc_idx;
@@ -1116,6 +1177,10 @@ vhost_crypto_process_one_req(struct vhost_crypto *vcrypto,
 
        if (likely(head->flags & VRING_DESC_F_INDIRECT)) {
                dlen = head->len;
+               nb_descs = dlen / sizeof(struct vring_desc);
+       /* reject a request whose indirect table exceeds the ring size */
+               if (unlikely(nb_descs > vq->size))
+                       return -1;
                desc = IOVA_TO_VVA(struct vring_desc *, vc_req, head->addr,
                                &dlen, VHOST_ACCESS_RO);
                if (unlikely(!desc || dlen != head->len))
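A vring descriptor is 16 bytes (u64 addr, u32 len, u16 flags, u16 next), and the virtio spec forbids a chain longer than the queue size, so bounding the indirect table by vq->size is exact. A worked example with assumed values:

    /* vq->size == 256:
     *   head->len == 4096  ->  nb_descs = 4096 / 16 = 256  (accepted)
     *   head->len == 4112  ->  nb_descs = 4112 / 16 = 257  (rejected, -1)
     */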
@@ -1138,8 +1203,8 @@ vhost_crypto_process_one_req(struct vhost_crypto *vcrypto,
                        goto error_exit;
                case RTE_VHOST_CRYPTO_ZERO_COPY_DISABLE:
                        req = &tmp_req;
-                       if (unlikely(copy_data(req, vc_req, &desc, sizeof(*req))
-                                       < 0)) {
+                       if (unlikely(copy_data(req, vc_req, &desc, sizeof(*req),
+                                       &nb_descs, vq->size) < 0)) {
                                err = VIRTIO_CRYPTO_BADMSG;
                                VC_LOG_ERR("Invalid descriptor");
                                goto error_exit;
@@ -1152,7 +1217,7 @@ vhost_crypto_process_one_req(struct vhost_crypto *vcrypto,
                }
        } else {
                if (unlikely(move_desc(vc_req->head, &desc,
-                               sizeof(*req)) < 0)) {
+                               sizeof(*req), &nb_descs, vq->size) < 0)) {
                        VC_LOG_ERR("Incorrect descriptor");
                        goto error_exit;
                }
@@ -1193,11 +1258,13 @@ vhost_crypto_process_one_req(struct vhost_crypto *vcrypto,
                        break;
                case VIRTIO_CRYPTO_SYM_OP_CIPHER:
                        err = prepare_sym_cipher_op(vcrypto, op, vc_req,
-                                       &req->u.sym_req.u.cipher, desc);
+                                       &req->u.sym_req.u.cipher, desc,
+                                       &nb_descs, vq->size);
                        break;
                case VIRTIO_CRYPTO_SYM_OP_ALGORITHM_CHAINING:
                        err = prepare_sym_chain_op(vcrypto, op, vc_req,
-                                       &req->u.sym_req.u.chain, desc);
+                                       &req->u.sym_req.u.chain, desc,
+                                       &nb_descs, vq->size);
                        break;
                }
                if (unlikely(err != 0)) {
@@ -1215,7 +1282,7 @@ vhost_crypto_process_one_req(struct vhost_crypto *vcrypto,
 
 error_exit:
 
-       inhdr = reach_inhdr(vc_req, desc);
+       inhdr = reach_inhdr(vc_req, desc, &nb_descs, vq->size);
        if (likely(inhdr != NULL))
                inhdr->status = (uint8_t)err;
 
@@ -1293,7 +1360,9 @@ vhost_crypto_complete_one_vm_requests(struct rte_crypto_op **ops,
 
 int __rte_experimental
 rte_vhost_crypto_create(int vid, uint8_t cryptodev_id,
-               struct rte_mempool *sess_pool, int socket_id)
+               struct rte_mempool *sess_pool,
+               struct rte_mempool *sess_priv_pool,
+               int socket_id)
 {
        struct virtio_net *dev = get_device(vid);
        struct rte_hash_parameters params = {0};
@@ -1321,6 +1390,7 @@ rte_vhost_crypto_create(int vid, uint8_t cryptodev_id,
        }
 
        vcrypto->sess_pool = sess_pool;
+       vcrypto->sess_priv_pool = sess_priv_pool;
        vcrypto->cid = cryptodev_id;
        vcrypto->cache_session_id = UINT64_MAX;
        vcrypto->last_session_id = 1;
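With the two-pool scheme, callers of this experimental API now provide one mempool for session header objects and a second one for driver-private session data. A sketch of the updated setup; pool names and sizes are assumptions, while the calls themselves are standard cryptodev API:

    #include <rte_cryptodev.h>
    #include <rte_mempool.h>
    #include <rte_vhost_crypto.h>

    /* session header objects; elt_size 0 lets the API size them */
    struct rte_mempool *sess_pool =
            rte_cryptodev_sym_session_pool_create("vhost_c_sess",
                    128, 0, 0, 0, socket_id);

    /* PMD-private session data, sized for this particular device */
    struct rte_mempool *sess_priv_pool =
            rte_mempool_create("vhost_c_sess_priv", 128,
                    rte_cryptodev_sym_get_private_session_size(cryptodev_id),
                    0, 0, NULL, NULL, NULL, NULL, socket_id, 0);

    if (sess_pool != NULL && sess_priv_pool != NULL)
            ret = rte_vhost_crypto_create(vid, cryptodev_id, sess_pool,
                            sess_priv_pool, socket_id);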
@@ -1557,7 +1627,7 @@ rte_vhost_crypto_fetch_requests(int vid, uint32_t qid,
                        op->sym->m_src->data_off = 0;
 
                        if (unlikely(vhost_crypto_process_one_req(vcrypto, vq,
-                                       op, head, desc_idx)) < 0)
+                                       op, head, desc_idx) < 0))
                                break;
                }