vhost: check cmsg not null
[dpdk.git] / lib / librte_vhost / socket.c
index 83befdc..6ba60f5 100644 (file)
@@ -97,6 +97,7 @@ read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
        size_t fdsize = fd_num * sizeof(int);
        char control[CMSG_SPACE(fdsize)];
        struct cmsghdr *cmsg;
+       int got_fds = 0;
        int ret;
 
        memset(&msgh, 0, sizeof(msgh));
@@ -123,11 +124,16 @@ read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
                cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
                if ((cmsg->cmsg_level == SOL_SOCKET) &&
                        (cmsg->cmsg_type == SCM_RIGHTS)) {
-                       memcpy(fds, CMSG_DATA(cmsg), fdsize);
+                       got_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+                       memcpy(fds, CMSG_DATA(cmsg), got_fds * sizeof(int));
                        break;
                }
        }
 
+       /* Clear out unused file descriptors */
+       while (got_fds < fd_num)
+               fds[got_fds++] = -1;
+
        return ret;
 }
 
@@ -153,6 +159,11 @@ send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
                msgh.msg_control = control;
                msgh.msg_controllen = sizeof(control);
                cmsg = CMSG_FIRSTHDR(&msgh);
+               if (cmsg == NULL) {
+                       RTE_LOG(ERR, VHOST_CONFIG, "cmsg == NULL\n");
+                       errno = EINVAL;
+                       return -1;
+               }
                cmsg->cmsg_len = CMSG_LEN(fdsize);
                cmsg->cmsg_level = SOL_SOCKET;
                cmsg->cmsg_type = SCM_RIGHTS;