net/virtio: fix incorrect cast of void *
[dpdk.git] / fib / trie_avx512.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2020 Intel Corporation
3  */
4
5 #include <rte_vect.h>
6 #include <rte_fib6.h>
7
8 #include "trie.h"
9 #include "trie_avx512.h"
10
11 static __rte_always_inline void
12 transpose_x16(uint8_t ips[16][RTE_FIB6_IPV6_ADDR_SIZE],
13         __m512i *first, __m512i *second, __m512i *third, __m512i *fourth)
14 {
15         __m512i tmp1, tmp2, tmp3, tmp4;
16         __m512i tmp5, tmp6, tmp7, tmp8;
17         const __rte_x86_zmm_t perm_idxes = {
18                 .u32 = { 0, 4, 8, 12, 2, 6, 10, 14,
19                         1, 5, 9, 13, 3, 7, 11, 15
20                 },
21         };
22
23         /* load all ip addresses */
24         tmp1 = _mm512_loadu_si512(&ips[0][0]);
25         tmp2 = _mm512_loadu_si512(&ips[4][0]);
26         tmp3 = _mm512_loadu_si512(&ips[8][0]);
27         tmp4 = _mm512_loadu_si512(&ips[12][0]);
28
29         /* transpose 4 byte chunks of 16 ips */
30         tmp5 = _mm512_unpacklo_epi32(tmp1, tmp2);
31         tmp7 = _mm512_unpackhi_epi32(tmp1, tmp2);
32         tmp6 = _mm512_unpacklo_epi32(tmp3, tmp4);
33         tmp8 = _mm512_unpackhi_epi32(tmp3, tmp4);
34
35         tmp1 = _mm512_unpacklo_epi32(tmp5, tmp6);
36         tmp3 = _mm512_unpackhi_epi32(tmp5, tmp6);
37         tmp2 = _mm512_unpacklo_epi32(tmp7, tmp8);
38         tmp4 = _mm512_unpackhi_epi32(tmp7, tmp8);
39
40         /* first 4-byte chunks of ips[] */
41         *first = _mm512_permutexvar_epi32(perm_idxes.z, tmp1);
42         /* second 4-byte chunks of ips[] */
43         *second = _mm512_permutexvar_epi32(perm_idxes.z, tmp3);
44         /* third 4-byte chunks of ips[] */
45         *third = _mm512_permutexvar_epi32(perm_idxes.z, tmp2);
46         /* fourth 4-byte chunks of ips[] */
47         *fourth = _mm512_permutexvar_epi32(perm_idxes.z, tmp4);
48 }
49
50 static __rte_always_inline void
51 transpose_x8(uint8_t ips[8][RTE_FIB6_IPV6_ADDR_SIZE],
52         __m512i *first, __m512i *second)
53 {
54         __m512i tmp1, tmp2, tmp3, tmp4;
55         const __rte_x86_zmm_t perm_idxes = {
56                 .u64 = { 0, 2, 4, 6, 1, 3, 5, 7
57                 },
58         };
59
60         tmp1 = _mm512_loadu_si512(&ips[0][0]);
61         tmp2 = _mm512_loadu_si512(&ips[4][0]);
62
63         tmp3 = _mm512_unpacklo_epi64(tmp1, tmp2);
64         *first = _mm512_permutexvar_epi64(perm_idxes.z, tmp3);
65         tmp4 = _mm512_unpackhi_epi64(tmp1, tmp2);
66         *second = _mm512_permutexvar_epi64(perm_idxes.z, tmp4);
67 }
68
69 static __rte_always_inline void
70 trie_vec_lookup_x16x2(void *p, uint8_t ips[32][RTE_FIB6_IPV6_ADDR_SIZE],
71         uint64_t *next_hops, int size)
72 {
73         struct rte_trie_tbl *dp = (struct rte_trie_tbl *)p;
74         const __m512i zero = _mm512_set1_epi32(0);
75         const __m512i lsb = _mm512_set1_epi32(1);
76         const __m512i two_lsb = _mm512_set1_epi32(3);
77         /* IPv6 four byte chunks */
78         __m512i first_1, second_1, third_1, fourth_1;
79         __m512i first_2, second_2, third_2, fourth_2;
80         __m512i idxes_1, res_1;
81         __m512i idxes_2, res_2;
82         __m512i shuf_idxes;
83         __m512i tmp_1, tmp2_1, bytes_1, byte_chunk_1;
84         __m512i tmp_2, tmp2_2, bytes_2, byte_chunk_2;
85         __m512i base_idxes;
86         /* used to mask gather values if size is 2 (16 bit next hops) */
87         const __m512i res_msk = _mm512_set1_epi32(UINT16_MAX);
88         const __rte_x86_zmm_t bswap = {
89                 .u8 = { 2, 1, 0, 255, 6, 5, 4, 255,
90                         10, 9, 8, 255, 14, 13, 12, 255,
91                         2, 1, 0, 255, 6, 5, 4, 255,
92                         10, 9, 8, 255, 14, 13, 12, 255,
93                         2, 1, 0, 255, 6, 5, 4, 255,
94                         10, 9, 8, 255, 14, 13, 12, 255,
95                         2, 1, 0, 255, 6, 5, 4, 255,
96                         10, 9, 8, 255, 14, 13, 12, 255
97                         },
98         };
99         const __mmask64 k = 0x1111111111111111;
100         int i = 3;
101         __mmask16 msk_ext_1, new_msk_1;
102         __mmask16 msk_ext_2, new_msk_2;
103         __mmask16 exp_msk = 0x5555;
104
105         transpose_x16(ips, &first_1, &second_1, &third_1, &fourth_1);
106         transpose_x16(ips + 16, &first_2, &second_2, &third_2, &fourth_2);
107
108         /* get_tbl24_idx() for every 4 byte chunk */
109         idxes_1 = _mm512_shuffle_epi8(first_1, bswap.z);
110         idxes_2 = _mm512_shuffle_epi8(first_2, bswap.z);
111
112         /**
113          * lookup in tbl24
114          * Put it inside branch to make compiller happy with -O0
115          */
116         if (size == sizeof(uint16_t)) {
117                 res_1 = _mm512_i32gather_epi32(idxes_1,
118                                 (const int *)dp->tbl24, 2);
119                 res_2 = _mm512_i32gather_epi32(idxes_2,
120                                 (const int *)dp->tbl24, 2);
121                 res_1 = _mm512_and_epi32(res_1, res_msk);
122                 res_2 = _mm512_and_epi32(res_2, res_msk);
123         } else {
124                 res_1 = _mm512_i32gather_epi32(idxes_1,
125                                 (const int *)dp->tbl24, 4);
126                 res_2 = _mm512_i32gather_epi32(idxes_2,
127                                 (const int *)dp->tbl24, 4);
128         }
129
130         /* get extended entries indexes */
131         msk_ext_1 = _mm512_test_epi32_mask(res_1, lsb);
132         msk_ext_2 = _mm512_test_epi32_mask(res_2, lsb);
133
134         tmp_1 = _mm512_srli_epi32(res_1, 1);
135         tmp_2 = _mm512_srli_epi32(res_2, 1);
136
137         /* idxes to retrieve bytes */
138         shuf_idxes = _mm512_setr_epi32(3, 7, 11, 15,
139                                 19, 23, 27, 31,
140                                 35, 39, 43, 47,
141                                 51, 55, 59, 63);
142
143         base_idxes = _mm512_setr_epi32(0, 4, 8, 12,
144                                 16, 20, 24, 28,
145                                 32, 36, 40, 44,
146                                 48, 52, 56, 60);
147
148         /* traverse down the trie */
149         while (msk_ext_1 || msk_ext_2) {
150                 idxes_1 = _mm512_maskz_slli_epi32(msk_ext_1, tmp_1, 8);
151                 idxes_2 = _mm512_maskz_slli_epi32(msk_ext_2, tmp_2, 8);
152                 byte_chunk_1 = (i < 8) ?
153                         ((i >= 4) ? second_1 : first_1) :
154                         ((i >= 12) ? fourth_1 : third_1);
155                 byte_chunk_2 = (i < 8) ?
156                         ((i >= 4) ? second_2 : first_2) :
157                         ((i >= 12) ? fourth_2 : third_2);
158                 bytes_1 = _mm512_maskz_shuffle_epi8(k, byte_chunk_1,
159                                 shuf_idxes);
160                 bytes_2 = _mm512_maskz_shuffle_epi8(k, byte_chunk_2,
161                                 shuf_idxes);
162                 idxes_1 = _mm512_maskz_add_epi32(msk_ext_1, idxes_1, bytes_1);
163                 idxes_2 = _mm512_maskz_add_epi32(msk_ext_2, idxes_2, bytes_2);
164                 if (size == sizeof(uint16_t)) {
165                         tmp_1 = _mm512_mask_i32gather_epi32(zero, msk_ext_1,
166                                 idxes_1, (const int *)dp->tbl8, 2);
167                         tmp_2 = _mm512_mask_i32gather_epi32(zero, msk_ext_2,
168                                 idxes_2, (const int *)dp->tbl8, 2);
169                         tmp_1 = _mm512_and_epi32(tmp_1, res_msk);
170                         tmp_2 = _mm512_and_epi32(tmp_2, res_msk);
171                 } else {
172                         tmp_1 = _mm512_mask_i32gather_epi32(zero, msk_ext_1,
173                                 idxes_1, (const int *)dp->tbl8, 4);
174                         tmp_2 = _mm512_mask_i32gather_epi32(zero, msk_ext_2,
175                                 idxes_2, (const int *)dp->tbl8, 4);
176                 }
177                 new_msk_1 = _mm512_test_epi32_mask(tmp_1, lsb);
178                 new_msk_2 = _mm512_test_epi32_mask(tmp_2, lsb);
179                 res_1 = _mm512_mask_blend_epi32(msk_ext_1 ^ new_msk_1, res_1,
180                                 tmp_1);
181                 res_2 = _mm512_mask_blend_epi32(msk_ext_2 ^ new_msk_2, res_2,
182                                 tmp_2);
183                 tmp_1 = _mm512_srli_epi32(tmp_1, 1);
184                 tmp_2 = _mm512_srli_epi32(tmp_2, 1);
185                 msk_ext_1 = new_msk_1;
186                 msk_ext_2 = new_msk_2;
187
188                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, lsb);
189                 shuf_idxes = _mm512_and_epi32(shuf_idxes, two_lsb);
190                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, base_idxes);
191                 i++;
192         }
193
194         /* get rid of 1 LSB, now we have HN in every epi32 */
195         res_1 = _mm512_srli_epi32(res_1, 1);
196         res_2 = _mm512_srli_epi32(res_2, 1);
197         /* extract first half of NH's each in epi64 chunk */
198         tmp_1 = _mm512_maskz_expand_epi32(exp_msk, res_1);
199         tmp_2 = _mm512_maskz_expand_epi32(exp_msk, res_2);
200         /* extract second half of NH's */
201         __m256i tmp256_1, tmp256_2;
202         tmp256_1 = _mm512_extracti32x8_epi32(res_1, 1);
203         tmp256_2 = _mm512_extracti32x8_epi32(res_2, 1);
204         tmp2_1 = _mm512_maskz_expand_epi32(exp_msk,
205                 _mm512_castsi256_si512(tmp256_1));
206         tmp2_2 = _mm512_maskz_expand_epi32(exp_msk,
207                 _mm512_castsi256_si512(tmp256_2));
208         /* return NH's from two sets of registers */
209         _mm512_storeu_si512(next_hops, tmp_1);
210         _mm512_storeu_si512(next_hops + 8, tmp2_1);
211         _mm512_storeu_si512(next_hops + 16, tmp_2);
212         _mm512_storeu_si512(next_hops + 24, tmp2_2);
213 }
214
215 static void
216 trie_vec_lookup_x8x2_8b(void *p, uint8_t ips[16][RTE_FIB6_IPV6_ADDR_SIZE],
217         uint64_t *next_hops)
218 {
219         struct rte_trie_tbl *dp = (struct rte_trie_tbl *)p;
220         const __m512i zero = _mm512_set1_epi32(0);
221         const __m512i lsb = _mm512_set1_epi32(1);
222         const __m512i three_lsb = _mm512_set1_epi32(7);
223         /* IPv6 eight byte chunks */
224         __m512i first_1, second_1;
225         __m512i first_2, second_2;
226         __m512i idxes_1, res_1;
227         __m512i idxes_2, res_2;
228         __m512i shuf_idxes, base_idxes;
229         __m512i tmp_1, bytes_1, byte_chunk_1;
230         __m512i tmp_2, bytes_2, byte_chunk_2;
231         const __rte_x86_zmm_t bswap = {
232                 .u8 = { 2, 1, 0, 255, 255, 255, 255, 255,
233                         10, 9, 8, 255, 255, 255, 255, 255,
234                         2, 1, 0, 255, 255, 255, 255, 255,
235                         10, 9, 8, 255, 255, 255, 255, 255,
236                         2, 1, 0, 255, 255, 255, 255, 255,
237                         10, 9, 8, 255, 255, 255, 255, 255,
238                         2, 1, 0, 255, 255, 255, 255, 255,
239                         10, 9, 8, 255, 255, 255, 255, 255
240                         },
241         };
242         const __mmask64 k = 0x101010101010101;
243         int i = 3;
244         __mmask8 msk_ext_1, new_msk_1;
245         __mmask8 msk_ext_2, new_msk_2;
246
247         transpose_x8(ips, &first_1, &second_1);
248         transpose_x8(ips + 8, &first_2, &second_2);
249
250         /* get_tbl24_idx() for every 4 byte chunk */
251         idxes_1 = _mm512_shuffle_epi8(first_1, bswap.z);
252         idxes_2 = _mm512_shuffle_epi8(first_2, bswap.z);
253
254         /* lookup in tbl24 */
255         res_1 = _mm512_i64gather_epi64(idxes_1, (const void *)dp->tbl24, 8);
256         res_2 = _mm512_i64gather_epi64(idxes_2, (const void *)dp->tbl24, 8);
257         /* get extended entries indexes */
258         msk_ext_1 = _mm512_test_epi64_mask(res_1, lsb);
259         msk_ext_2 = _mm512_test_epi64_mask(res_2, lsb);
260
261         tmp_1 = _mm512_srli_epi64(res_1, 1);
262         tmp_2 = _mm512_srli_epi64(res_2, 1);
263
264         /* idxes to retrieve bytes */
265         shuf_idxes = _mm512_setr_epi64(3, 11, 19, 27, 35, 43, 51, 59);
266
267         base_idxes = _mm512_setr_epi64(0, 8, 16, 24, 32, 40, 48, 56);
268
269         /* traverse down the trie */
270         while (msk_ext_1 || msk_ext_2) {
271                 idxes_1 = _mm512_maskz_slli_epi64(msk_ext_1, tmp_1, 8);
272                 idxes_2 = _mm512_maskz_slli_epi64(msk_ext_2, tmp_2, 8);
273                 byte_chunk_1 = (i < 8) ? first_1 : second_1;
274                 byte_chunk_2 = (i < 8) ? first_2 : second_2;
275                 bytes_1 = _mm512_maskz_shuffle_epi8(k, byte_chunk_1,
276                                 shuf_idxes);
277                 bytes_2 = _mm512_maskz_shuffle_epi8(k, byte_chunk_2,
278                                 shuf_idxes);
279                 idxes_1 = _mm512_maskz_add_epi64(msk_ext_1, idxes_1, bytes_1);
280                 idxes_2 = _mm512_maskz_add_epi64(msk_ext_2, idxes_2, bytes_2);
281                 tmp_1 = _mm512_mask_i64gather_epi64(zero, msk_ext_1,
282                                 idxes_1, (const void *)dp->tbl8, 8);
283                 tmp_2 = _mm512_mask_i64gather_epi64(zero, msk_ext_2,
284                                 idxes_2, (const void *)dp->tbl8, 8);
285                 new_msk_1 = _mm512_test_epi64_mask(tmp_1, lsb);
286                 new_msk_2 = _mm512_test_epi64_mask(tmp_2, lsb);
287                 res_1 = _mm512_mask_blend_epi64(msk_ext_1 ^ new_msk_1, res_1,
288                                 tmp_1);
289                 res_2 = _mm512_mask_blend_epi64(msk_ext_2 ^ new_msk_2, res_2,
290                                 tmp_2);
291                 tmp_1 = _mm512_srli_epi64(tmp_1, 1);
292                 tmp_2 = _mm512_srli_epi64(tmp_2, 1);
293                 msk_ext_1 = new_msk_1;
294                 msk_ext_2 = new_msk_2;
295
296                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, lsb);
297                 shuf_idxes = _mm512_and_epi64(shuf_idxes, three_lsb);
298                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, base_idxes);
299                 i++;
300         }
301
302         res_1 = _mm512_srli_epi64(res_1, 1);
303         res_2 = _mm512_srli_epi64(res_2, 1);
304         _mm512_storeu_si512(next_hops, res_1);
305         _mm512_storeu_si512(next_hops + 8, res_2);
306 }
307
308 void
309 rte_trie_vec_lookup_bulk_2b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
310         uint64_t *next_hops, const unsigned int n)
311 {
312         uint32_t i;
313         for (i = 0; i < (n / 32); i++) {
314                 trie_vec_lookup_x16x2(p, (uint8_t (*)[16])&ips[i * 32][0],
315                                 next_hops + i * 32, sizeof(uint16_t));
316         }
317         rte_trie_lookup_bulk_2b(p, (uint8_t (*)[16])&ips[i * 32][0],
318                         next_hops + i * 32, n - i * 32);
319 }
320
321 void
322 rte_trie_vec_lookup_bulk_4b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
323         uint64_t *next_hops, const unsigned int n)
324 {
325         uint32_t i;
326         for (i = 0; i < (n / 32); i++) {
327                 trie_vec_lookup_x16x2(p, (uint8_t (*)[16])&ips[i * 32][0],
328                                 next_hops + i * 32, sizeof(uint32_t));
329         }
330         rte_trie_lookup_bulk_4b(p, (uint8_t (*)[16])&ips[i * 32][0],
331                         next_hops + i * 32, n - i * 32);
332 }
333
334 void
335 rte_trie_vec_lookup_bulk_8b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
336         uint64_t *next_hops, const unsigned int n)
337 {
338         uint32_t i;
339         for (i = 0; i < (n / 16); i++) {
340                 trie_vec_lookup_x8x2_8b(p, (uint8_t (*)[16])&ips[i * 16][0],
341                                 next_hops + i * 16);
342         }
343         rte_trie_lookup_bulk_8b(p, (uint8_t (*)[16])&ips[i * 16][0],
344                         next_hops + i * 16, n - i * 16);
345 }