net/mlx5: use SPDX tags in 6WIND copyrighted files
[dpdk.git] / drivers / net / mlx5 / mlx5_socket.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright 2016 6WIND S.A.
3  */
4
#define _GNU_SOURCE

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include "mlx5.h"
#include "mlx5_utils.h"
17
18 /**
19  * Initialise the socket to communicate with the secondary process
20  *
21  * @param[in] priv
22  *   Pointer to private structure.
23  *
24  * @return
25  *   0 on success, errno value on failure.
26  */
27 int
28 priv_socket_init(struct priv *priv)
29 {
30         struct sockaddr_un sun = {
31                 .sun_family = AF_UNIX,
32         };
33         int ret;
34         int flags;
35         struct stat file_stat;
36
37         /*
38          * Initialise the socket to communicate with the secondary
39          * process.
40          */
41         ret = socket(AF_UNIX, SOCK_STREAM, 0);
42         if (ret < 0) {
43                 WARN("secondary process not supported: %s", strerror(errno));
44                 return ret;
45         }
46         priv->primary_socket = ret;
47         flags = fcntl(priv->primary_socket, F_GETFL, 0);
48         if (flags == -1)
49                 goto out;
50         ret = fcntl(priv->primary_socket, F_SETFL, flags | O_NONBLOCK);
51         if (ret < 0)
52                 goto out;
53         snprintf(sun.sun_path, sizeof(sun.sun_path), "/var/tmp/%s_%d",
54                  MLX5_DRIVER_NAME, priv->primary_socket);
55         ret = stat(sun.sun_path, &file_stat);
56         if (!ret)
57                 claim_zero(remove(sun.sun_path));
58         ret = bind(priv->primary_socket, (const struct sockaddr *)&sun,
59                    sizeof(sun));
60         if (ret < 0) {
61                 WARN("cannot bind socket, secondary process not supported: %s",
62                      strerror(errno));
63                 goto close;
64         }
65         ret = listen(priv->primary_socket, 0);
66         if (ret < 0) {
67                 WARN("Secondary process not supported: %s", strerror(errno));
68                 goto close;
69         }
70         return ret;
71 close:
72         remove(sun.sun_path);
73 out:
74         claim_zero(close(priv->primary_socket));
75         priv->primary_socket = 0;
76         return -(ret);
77 }
78
79 /**
80  * Un-Initialise the socket to communicate with the secondary process
81  *
82  * @param[in] priv
83  *   Pointer to private structure.
84  *
85  * @return
86  *   0 on success, errno value on failure.
87  */
88 int
89 priv_socket_uninit(struct priv *priv)
90 {
91         MKSTR(path, "/var/tmp/%s_%d", MLX5_DRIVER_NAME, priv->primary_socket);
92         claim_zero(close(priv->primary_socket));
93         priv->primary_socket = 0;
94         claim_zero(remove(path));
95         return 0;
96 }
97
98 /**
99  * Handle socket interrupts.
100  *
101  * @param priv
102  *   Pointer to private structure.
103  */
104 void
105 priv_socket_handle(struct priv *priv)
106 {
107         int conn_sock;
108         int ret = 0;
109         struct cmsghdr *cmsg = NULL;
110         struct ucred *cred = NULL;
111         char buf[CMSG_SPACE(sizeof(struct ucred))] = { 0 };
112         char vbuf[1024] = { 0 };
113         struct iovec io = {
114                 .iov_base = vbuf,
115                 .iov_len = sizeof(*vbuf),
116         };
117         struct msghdr msg = {
118                 .msg_iov = &io,
119                 .msg_iovlen = 1,
120                 .msg_control = buf,
121                 .msg_controllen = sizeof(buf),
122         };
123         int *fd;
124
125         /* Accept the connection from the client. */
126         conn_sock = accept(priv->primary_socket, NULL, NULL);
127         if (conn_sock < 0) {
128                 WARN("connection failed: %s", strerror(errno));
129                 return;
130         }
131         ret = setsockopt(conn_sock, SOL_SOCKET, SO_PASSCRED, &(int){1},
132                                          sizeof(int));
133         if (ret < 0) {
134                 WARN("cannot change socket options");
135                 goto out;
136         }
137         ret = recvmsg(conn_sock, &msg, MSG_WAITALL);
138         if (ret < 0) {
139                 WARN("received an empty message: %s", strerror(errno));
140                 goto out;
141         }
142         /* Expect to receive credentials only. */
143         cmsg = CMSG_FIRSTHDR(&msg);
144         if (cmsg == NULL) {
145                 WARN("no message");
146                 goto out;
147         }
148         if ((cmsg->cmsg_type == SCM_CREDENTIALS) &&
149                 (cmsg->cmsg_len >= sizeof(*cred))) {
150                 cred = (struct ucred *)CMSG_DATA(cmsg);
151                 assert(cred != NULL);
152         }
153         cmsg = CMSG_NXTHDR(&msg, cmsg);
154         if (cmsg != NULL) {
155                 WARN("Message wrongly formatted");
156                 goto out;
157         }
158         /* Make sure all the ancillary data was received and valid. */
159         if ((cred == NULL) || (cred->uid != getuid()) ||
160             (cred->gid != getgid())) {
161                 WARN("wrong credentials");
162                 goto out;
163         }
164         /* Set-up the ancillary data. */
165         cmsg = CMSG_FIRSTHDR(&msg);
166         assert(cmsg != NULL);
167         cmsg->cmsg_level = SOL_SOCKET;
168         cmsg->cmsg_type = SCM_RIGHTS;
169         cmsg->cmsg_len = CMSG_LEN(sizeof(priv->ctx->cmd_fd));
170         fd = (int *)CMSG_DATA(cmsg);
171         *fd = priv->ctx->cmd_fd;
172         ret = sendmsg(conn_sock, &msg, 0);
173         if (ret < 0)
174                 WARN("cannot send response");
175 out:
176         close(conn_sock);
177 }
178
179 /**
180  * Connect to the primary process.
181  *
182  * @param[in] priv
183  *   Pointer to private structure.
184  *
185  * @return
186  *   fd on success, negative errno value on failure.
187  */
188 int
189 priv_socket_connect(struct priv *priv)
190 {
191         struct sockaddr_un sun = {
192                 .sun_family = AF_UNIX,
193         };
194         int socket_fd;
195         int *fd = NULL;
196         int ret;
197         struct ucred *cred;
198         char buf[CMSG_SPACE(sizeof(*cred))] = { 0 };
199         char vbuf[1024] = { 0 };
200         struct iovec io = {
201                 .iov_base = vbuf,
202                 .iov_len = sizeof(*vbuf),
203         };
204         struct msghdr msg = {
205                 .msg_control = buf,
206                 .msg_controllen = sizeof(buf),
207                 .msg_iov = &io,
208                 .msg_iovlen = 1,
209         };
210         struct cmsghdr *cmsg;
211
212         ret = socket(AF_UNIX, SOCK_STREAM, 0);
213         if (ret < 0) {
214                 WARN("cannot connect to primary");
215                 return ret;
216         }
217         socket_fd = ret;
218         snprintf(sun.sun_path, sizeof(sun.sun_path), "/var/tmp/%s_%d",
219                  MLX5_DRIVER_NAME, priv->primary_socket);
220         ret = connect(socket_fd, (const struct sockaddr *)&sun, sizeof(sun));
221         if (ret < 0) {
222                 WARN("cannot connect to primary");
223                 goto out;
224         }
225         cmsg = CMSG_FIRSTHDR(&msg);
226         if (cmsg == NULL) {
227                 DEBUG("cannot get first message");
228                 goto out;
229         }
230         cmsg->cmsg_level = SOL_SOCKET;
231         cmsg->cmsg_type = SCM_CREDENTIALS;
232         cmsg->cmsg_len = CMSG_LEN(sizeof(*cred));
233         cred = (struct ucred *)CMSG_DATA(cmsg);
234         if (cred == NULL) {
235                 DEBUG("no credentials received");
236                 goto out;
237         }
238         cred->pid = getpid();
239         cred->uid = getuid();
240         cred->gid = getgid();
241         ret = sendmsg(socket_fd, &msg, MSG_DONTWAIT);
242         if (ret < 0) {
243                 WARN("cannot send credentials to primary: %s",
244                      strerror(errno));
245                 goto out;
246         }
247         ret = recvmsg(socket_fd, &msg, MSG_WAITALL);
248         if (ret <= 0) {
249                 WARN("no message from primary: %s", strerror(errno));
250                 goto out;
251         }
252         cmsg = CMSG_FIRSTHDR(&msg);
253         if (cmsg == NULL) {
254                 WARN("No file descriptor received");
255                 goto out;
256         }
257         fd = (int *)CMSG_DATA(cmsg);
258         if (*fd <= 0) {
259                 WARN("no file descriptor received: %s", strerror(errno));
260                 ret = *fd;
261                 goto out;
262         }
263         ret = *fd;
264 out:
265         close(socket_fd);
266         return ret;
267 }