vhost: get device by device id only
[dpdk.git] / lib / librte_vhost / vhost_user / vhost-net-user.c
index 31f1215..68fc9b9 100644 (file)
 
 static void vserver_new_vq_conn(int fd, void *data, int *remove);
 static void vserver_message_handler(int fd, void *dat, int *remove);
-struct vhost_net_device_ops const *ops;
 
 struct connfd_ctx {
        struct vhost_server *vserver;
-       uint32_t fh;
+       int vid;
 };
 
 #define MAX_VHOST_SERVER 1024
 struct _vhost_server {
        struct vhost_server *server[MAX_VHOST_SERVER];
        struct fdset fdset;
+       int vserver_cnt;
+       pthread_mutex_t server_mutex;
 };
 
 static struct _vhost_server g_vhost_server = {
@@ -74,10 +75,10 @@ static struct _vhost_server g_vhost_server = {
                .fd_mutex = PTHREAD_MUTEX_INITIALIZER,
                .num = 0
        },
+       .vserver_cnt = 0,
+       .server_mutex = PTHREAD_MUTEX_INITIALIZER,
 };
 
-static int vserver_idx;
-
 static const char *vhost_message_str[VHOST_USER_MAX] = {
        [VHOST_USER_NONE] = "VHOST_USER_NONE",
        [VHOST_USER_GET_FEATURES] = "VHOST_USER_GET_FEATURES",
@@ -93,7 +94,12 @@ static const char *vhost_message_str[VHOST_USER_MAX] = {
        [VHOST_USER_GET_VRING_BASE] = "VHOST_USER_GET_VRING_BASE",
        [VHOST_USER_SET_VRING_KICK] = "VHOST_USER_SET_VRING_KICK",
        [VHOST_USER_SET_VRING_CALL] = "VHOST_USER_SET_VRING_CALL",
-       [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR"
+       [VHOST_USER_SET_VRING_ERR]  = "VHOST_USER_SET_VRING_ERR",
+       [VHOST_USER_GET_PROTOCOL_FEATURES]  = "VHOST_USER_GET_PROTOCOL_FEATURES",
+       [VHOST_USER_SET_PROTOCOL_FEATURES]  = "VHOST_USER_SET_PROTOCOL_FEATURES",
+       [VHOST_USER_GET_QUEUE_NUM]  = "VHOST_USER_GET_QUEUE_NUM",
+       [VHOST_USER_SET_VRING_ENABLE]  = "VHOST_USER_SET_VRING_ENABLE",
+       [VHOST_USER_SEND_RARP]  = "VHOST_USER_SEND_RARP",
 };
 
 /**
@@ -120,8 +126,11 @@ uds_socket(const char *path)
        un.sun_family = AF_UNIX;
        snprintf(un.sun_path, sizeof(un.sun_path), "%s", path);
        ret = bind(sockfd, (struct sockaddr *)&un, sizeof(un));
-       if (ret == -1)
+       if (ret == -1) {
+               RTE_LOG(ERR, VHOST_CONFIG, "fail to bind fd:%d, remove file:%s and try again.\n",
+                       sockfd, path);
                goto err;
+       }
        RTE_LOG(INFO, VHOST_CONFIG, "bind to %s\n", path);
 
        ret = listen(sockfd, MAX_VIRTIO_BACKLOG);
@@ -276,8 +285,7 @@ vserver_new_vq_conn(int fd, void *dat, __rte_unused int *remove)
        struct vhost_server *vserver = (struct vhost_server *)dat;
        int conn_fd;
        struct connfd_ctx *ctx;
-       int fh;
-       struct vhost_device_ctx vdev_ctx = { (pid_t)0, 0 };
+       int vid;
        unsigned int size;
 
        conn_fd = accept(fd, NULL, NULL);
@@ -292,22 +300,20 @@ vserver_new_vq_conn(int fd, void *dat, __rte_unused int *remove)
                return;
        }
 
-       fh = ops->new_device(vdev_ctx);
-       if (fh == -1) {
+       vid = vhost_new_device();
+       if (vid == -1) {
                free(ctx);
                close(conn_fd);
                return;
        }
 
-       vdev_ctx.fh = fh;
        size = strnlen(vserver->path, PATH_MAX);
-       ops->set_ifname(vdev_ctx, vserver->path,
-               size);
+       vhost_set_ifname(vid, vserver->path, size);
 
-       RTE_LOG(INFO, VHOST_CONFIG, "new device, handle is %d\n", fh);
+       RTE_LOG(INFO, VHOST_CONFIG, "new device, handle is %d\n", vid);
 
        ctx->vserver = vserver;
-       ctx->fh = fh;
+       ctx->vid = vid;
        fdset_add(&g_vhost_server.fdset,
                conn_fd, vserver_message_handler, NULL, ctx);
 }
@@ -316,46 +322,29 @@ vserver_new_vq_conn(int fd, void *dat, __rte_unused int *remove)
 static void
 vserver_message_handler(int connfd, void *dat, int *remove)
 {
-       struct vhost_device_ctx ctx;
+       int vid;
        struct connfd_ctx *cfd_ctx = (struct connfd_ctx *)dat;
        struct VhostUserMsg msg;
        uint64_t features;
        int ret;
 
-       ctx.fh = cfd_ctx->fh;
+       vid = cfd_ctx->vid;
        ret = read_vhost_message(connfd, &msg);
-       if (ret < 0) {
-               RTE_LOG(ERR, VHOST_CONFIG,
-                       "vhost read message failed\n");
-
-               close(connfd);
-               *remove = 1;
-               free(cfd_ctx);
-               user_destroy_device(ctx);
-               ops->destroy_device(ctx);
-
-               return;
-       } else if (ret == 0) {
-               RTE_LOG(INFO, VHOST_CONFIG,
-                       "vhost peer closed\n");
-
-               close(connfd);
-               *remove = 1;
-               free(cfd_ctx);
-               user_destroy_device(ctx);
-               ops->destroy_device(ctx);
-
-               return;
-       }
-       if (msg.request > VHOST_USER_MAX) {
-               RTE_LOG(ERR, VHOST_CONFIG,
-                       "vhost read incorrect message\n");
+       if (ret <= 0 || msg.request >= VHOST_USER_MAX) {
+               if (ret < 0)
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "vhost read message failed\n");
+               else if (ret == 0)
+                       RTE_LOG(INFO, VHOST_CONFIG,
+                               "vhost peer closed\n");
+               else
+                       RTE_LOG(ERR, VHOST_CONFIG,
+                               "vhost read incorrect message\n");
 
                close(connfd);
                *remove = 1;
                free(cfd_ctx);
-               user_destroy_device(ctx);
-               ops->destroy_device(ctx);
+               vhost_destroy_device(vid);
 
                return;
        }
@@ -364,55 +353,69 @@ vserver_message_handler(int connfd, void *dat, int *remove)
                vhost_message_str[msg.request]);
        switch (msg.request) {
        case VHOST_USER_GET_FEATURES:
-               ret = ops->get_features(ctx, &features);
+               ret = vhost_get_features(vid, &features);
                msg.payload.u64 = features;
                msg.size = sizeof(msg.payload.u64);
                send_vhost_message(connfd, &msg);
                break;
        case VHOST_USER_SET_FEATURES:
                features = msg.payload.u64;
-               ops->set_features(ctx, &features);
+               vhost_set_features(vid, &features);
+               break;
+
+       case VHOST_USER_GET_PROTOCOL_FEATURES:
+               msg.payload.u64 = VHOST_USER_PROTOCOL_FEATURES;
+               msg.size = sizeof(msg.payload.u64);
+               send_vhost_message(connfd, &msg);
+               break;
+       case VHOST_USER_SET_PROTOCOL_FEATURES:
+               user_set_protocol_features(vid, msg.payload.u64);
                break;
 
        case VHOST_USER_SET_OWNER:
-               ops->set_owner(ctx);
+               vhost_set_owner(vid);
                break;
        case VHOST_USER_RESET_OWNER:
-               ops->reset_owner(ctx);
+               vhost_reset_owner(vid);
                break;
 
        case VHOST_USER_SET_MEM_TABLE:
-               user_set_mem_table(ctx, &msg);
+               user_set_mem_table(vid, &msg);
                break;
 
        case VHOST_USER_SET_LOG_BASE:
-               RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
+               user_set_log_base(vid, &msg);
+
+               /* it needs a reply */
+               msg.size = sizeof(msg.payload.u64);
+               send_vhost_message(connfd, &msg);
+               break;
        case VHOST_USER_SET_LOG_FD:
                close(msg.fds[0]);
                RTE_LOG(INFO, VHOST_CONFIG, "not implemented.\n");
                break;
 
        case VHOST_USER_SET_VRING_NUM:
-               ops->set_vring_num(ctx, &msg.payload.state);
+               vhost_set_vring_num(vid, &msg.payload.state);
                break;
        case VHOST_USER_SET_VRING_ADDR:
-               ops->set_vring_addr(ctx, &msg.payload.addr);
+               vhost_set_vring_addr(vid, &msg.payload.addr);
                break;
        case VHOST_USER_SET_VRING_BASE:
-               ops->set_vring_base(ctx, &msg.payload.state);
+               vhost_set_vring_base(vid, &msg.payload.state);
                break;
 
        case VHOST_USER_GET_VRING_BASE:
-               ret = user_get_vring_base(ctx, &msg.payload.state);
+               ret = user_get_vring_base(vid, &msg.payload.state);
                msg.size = sizeof(msg.payload.state);
                send_vhost_message(connfd, &msg);
                break;
 
        case VHOST_USER_SET_VRING_KICK:
-               user_set_vring_kick(ctx, &msg);
+               user_set_vring_kick(vid, &msg);
                break;
        case VHOST_USER_SET_VRING_CALL:
-               user_set_vring_call(ctx, &msg);
+               user_set_vring_call(vid, &msg);
                break;
 
        case VHOST_USER_SET_VRING_ERR:
@@ -421,13 +424,25 @@ vserver_message_handler(int connfd, void *dat, int *remove)
                RTE_LOG(INFO, VHOST_CONFIG, "not implemented\n");
                break;
 
+       case VHOST_USER_GET_QUEUE_NUM:
+               msg.payload.u64 = VHOST_MAX_QUEUE_PAIRS;
+               msg.size = sizeof(msg.payload.u64);
+               send_vhost_message(connfd, &msg);
+               break;
+
+       case VHOST_USER_SET_VRING_ENABLE:
+               user_set_vring_enable(vid, &msg.payload.state);
+               break;
+       case VHOST_USER_SEND_RARP:
+               user_send_rarp(vid, &msg);
+               break;
+
        default:
                break;
 
        }
 }
 
-
 /**
  * Creates and initialise the vhost server.
  */
@@ -436,34 +451,75 @@ rte_vhost_driver_register(const char *path)
 {
        struct vhost_server *vserver;
 
-       if (vserver_idx == 0)
-               ops = get_virtio_net_callbacks();
-       if (vserver_idx == MAX_VHOST_SERVER)
+       pthread_mutex_lock(&g_vhost_server.server_mutex);
+
+       if (g_vhost_server.vserver_cnt == MAX_VHOST_SERVER) {
+               RTE_LOG(ERR, VHOST_CONFIG,
+                       "error: the number of servers reaches maximum\n");
+               pthread_mutex_unlock(&g_vhost_server.server_mutex);
                return -1;
+       }
 
        vserver = calloc(sizeof(struct vhost_server), 1);
-       if (vserver == NULL)
+       if (vserver == NULL) {
+               pthread_mutex_unlock(&g_vhost_server.server_mutex);
                return -1;
-
-       unlink(path);
+       }
 
        vserver->listenfd = uds_socket(path);
        if (vserver->listenfd < 0) {
                free(vserver);
+               pthread_mutex_unlock(&g_vhost_server.server_mutex);
                return -1;
        }
-       vserver->path = path;
+
+       vserver->path = strdup(path);
 
        fdset_add(&g_vhost_server.fdset, vserver->listenfd,
-               vserver_new_vq_conn, NULL,
-               vserver);
+               vserver_new_vq_conn, NULL, vserver);
 
-       g_vhost_server.server[vserver_idx++] = vserver;
+       g_vhost_server.server[g_vhost_server.vserver_cnt++] = vserver;
+       pthread_mutex_unlock(&g_vhost_server.server_mutex);
 
        return 0;
 }
 
 
+/**
+ * Unregister the specified vhost server
+ */
+int
+rte_vhost_driver_unregister(const char *path)
+{
+       int i;
+       int count;
+
+       pthread_mutex_lock(&g_vhost_server.server_mutex);
+
+       for (i = 0; i < g_vhost_server.vserver_cnt; i++) {
+               if (!strcmp(g_vhost_server.server[i]->path, path)) {
+                       fdset_del(&g_vhost_server.fdset,
+                               g_vhost_server.server[i]->listenfd);
+
+                       close(g_vhost_server.server[i]->listenfd);
+                       free(g_vhost_server.server[i]->path);
+                       free(g_vhost_server.server[i]);
+
+                       unlink(path);
+
+                       count = --g_vhost_server.vserver_cnt;
+                       g_vhost_server.server[i] = g_vhost_server.server[count];
+                       g_vhost_server.server[count] = NULL;
+                       pthread_mutex_unlock(&g_vhost_server.server_mutex);
+
+                       return 0;
+               }
+       }
+       pthread_mutex_unlock(&g_vhost_server.server_mutex);
+
+       return -1;
+}
+
 int
 rte_vhost_driver_session_start(void)
 {