fib6: add AVX512 lookup
[dpdk.git] / lib / librte_fib / trie_avx512.c
1 /* SPDX-License-Identifier: BSD-3-Clause
2  * Copyright(c) 2020 Intel Corporation
3  */
4
5 #include <rte_vect.h>
6 #include <rte_fib6.h>
7
8 #include "trie.h"
9 #include "trie_avx512.h"
10
11 static __rte_always_inline void
12 transpose_x16(uint8_t ips[16][RTE_FIB6_IPV6_ADDR_SIZE],
13         __m512i *first, __m512i *second, __m512i *third, __m512i *fourth)
14 {
15         __m512i tmp1, tmp2, tmp3, tmp4;
16         __m512i tmp5, tmp6, tmp7, tmp8;
17         const __rte_x86_zmm_t perm_idxes = {
18                 .u32 = { 0, 4, 8, 12, 2, 6, 10, 14,
19                         1, 5, 9, 13, 3, 7, 11, 15
20                 },
21         };
22
23         /* load all ip addresses */
24         tmp1 = _mm512_loadu_si512(&ips[0][0]);
25         tmp2 = _mm512_loadu_si512(&ips[4][0]);
26         tmp3 = _mm512_loadu_si512(&ips[8][0]);
27         tmp4 = _mm512_loadu_si512(&ips[12][0]);
28
29         /* transpose 4 byte chunks of 16 ips */
30         tmp5 = _mm512_unpacklo_epi32(tmp1, tmp2);
31         tmp7 = _mm512_unpackhi_epi32(tmp1, tmp2);
32         tmp6 = _mm512_unpacklo_epi32(tmp3, tmp4);
33         tmp8 = _mm512_unpackhi_epi32(tmp3, tmp4);
34
35         tmp1 = _mm512_unpacklo_epi32(tmp5, tmp6);
36         tmp3 = _mm512_unpackhi_epi32(tmp5, tmp6);
37         tmp2 = _mm512_unpacklo_epi32(tmp7, tmp8);
38         tmp4 = _mm512_unpackhi_epi32(tmp7, tmp8);
39
40         /* first 4-byte chunks of ips[] */
41         *first = _mm512_permutexvar_epi32(perm_idxes.z, tmp1);
42         /* second 4-byte chunks of ips[] */
43         *second = _mm512_permutexvar_epi32(perm_idxes.z, tmp3);
44         /* third 4-byte chunks of ips[] */
45         *third = _mm512_permutexvar_epi32(perm_idxes.z, tmp2);
46         /* fourth 4-byte chunks of ips[] */
47         *fourth = _mm512_permutexvar_epi32(perm_idxes.z, tmp4);
48 }
49
50 static __rte_always_inline void
51 transpose_x8(uint8_t ips[8][RTE_FIB6_IPV6_ADDR_SIZE],
52         __m512i *first, __m512i *second)
53 {
54         __m512i tmp1, tmp2, tmp3, tmp4;
55         const __rte_x86_zmm_t perm_idxes = {
56                 .u64 = { 0, 2, 4, 6, 1, 3, 5, 7
57                 },
58         };
59
60         tmp1 = _mm512_loadu_si512(&ips[0][0]);
61         tmp2 = _mm512_loadu_si512(&ips[4][0]);
62
63         tmp3 = _mm512_unpacklo_epi64(tmp1, tmp2);
64         *first = _mm512_permutexvar_epi64(perm_idxes.z, tmp3);
65         tmp4 = _mm512_unpackhi_epi64(tmp1, tmp2);
66         *second = _mm512_permutexvar_epi64(perm_idxes.z, tmp4);
67 }
68
69 static __rte_always_inline void
70 trie_vec_lookup_x16(void *p, uint8_t ips[16][RTE_FIB6_IPV6_ADDR_SIZE],
71         uint64_t *next_hops, int size)
72 {
73         struct rte_trie_tbl *dp = (struct rte_trie_tbl *)p;
74         const __m512i zero = _mm512_set1_epi32(0);
75         const __m512i lsb = _mm512_set1_epi32(1);
76         const __m512i two_lsb = _mm512_set1_epi32(3);
77         __m512i first, second, third, fourth; /*< IPv6 four byte chunks */
78         __m512i idxes, res, shuf_idxes;
79         __m512i tmp, tmp2, bytes, byte_chunk, base_idxes;
80         /* used to mask gather values if size is 2 (16 bit next hops) */
81         const __m512i res_msk = _mm512_set1_epi32(UINT16_MAX);
82         const __rte_x86_zmm_t bswap = {
83                 .u8 = { 2, 1, 0, 255, 6, 5, 4, 255,
84                         10, 9, 8, 255, 14, 13, 12, 255,
85                         2, 1, 0, 255, 6, 5, 4, 255,
86                         10, 9, 8, 255, 14, 13, 12, 255,
87                         2, 1, 0, 255, 6, 5, 4, 255,
88                         10, 9, 8, 255, 14, 13, 12, 255,
89                         2, 1, 0, 255, 6, 5, 4, 255,
90                         10, 9, 8, 255, 14, 13, 12, 255
91                         },
92         };
93         const __mmask64 k = 0x1111111111111111;
94         int i = 3;
95         __mmask16 msk_ext, new_msk;
96         __mmask16 exp_msk = 0x5555;
97
98         transpose_x16(ips, &first, &second, &third, &fourth);
99
100         /* get_tbl24_idx() for every 4 byte chunk */
101         idxes = _mm512_shuffle_epi8(first, bswap.z);
102
103         /**
104          * lookup in tbl24
105          * Put it inside branch to make compiller happy with -O0
106          */
107         if (size == sizeof(uint16_t)) {
108                 res = _mm512_i32gather_epi32(idxes, (const int *)dp->tbl24, 2);
109                 res = _mm512_and_epi32(res, res_msk);
110         } else
111                 res = _mm512_i32gather_epi32(idxes, (const int *)dp->tbl24, 4);
112
113
114         /* get extended entries indexes */
115         msk_ext = _mm512_test_epi32_mask(res, lsb);
116
117         tmp = _mm512_srli_epi32(res, 1);
118
119         /* idxes to retrieve bytes */
120         shuf_idxes = _mm512_setr_epi32(3, 7, 11, 15,
121                                 19, 23, 27, 31,
122                                 35, 39, 43, 47,
123                                 51, 55, 59, 63);
124
125         base_idxes = _mm512_setr_epi32(0, 4, 8, 12,
126                                 16, 20, 24, 28,
127                                 32, 36, 40, 44,
128                                 48, 52, 56, 60);
129
130         /* traverse down the trie */
131         while (msk_ext) {
132                 idxes = _mm512_maskz_slli_epi32(msk_ext, tmp, 8);
133                 byte_chunk = (i < 8) ?
134                         ((i >= 4) ? second : first) :
135                         ((i >= 12) ? fourth : third);
136                 bytes = _mm512_maskz_shuffle_epi8(k, byte_chunk, shuf_idxes);
137                 idxes = _mm512_maskz_add_epi32(msk_ext, idxes, bytes);
138                 if (size == sizeof(uint16_t)) {
139                         tmp = _mm512_mask_i32gather_epi32(zero, msk_ext,
140                                 idxes, (const int *)dp->tbl8, 2);
141                         tmp = _mm512_and_epi32(tmp, res_msk);
142                 } else
143                         tmp = _mm512_mask_i32gather_epi32(zero, msk_ext,
144                                 idxes, (const int *)dp->tbl8, 4);
145                 new_msk = _mm512_test_epi32_mask(tmp, lsb);
146                 res = _mm512_mask_blend_epi32(msk_ext ^ new_msk, res, tmp);
147                 tmp = _mm512_srli_epi32(tmp, 1);
148                 msk_ext = new_msk;
149
150                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, lsb);
151                 shuf_idxes = _mm512_and_epi32(shuf_idxes, two_lsb);
152                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, base_idxes);
153                 i++;
154         }
155
156         res = _mm512_srli_epi32(res, 1);
157         tmp = _mm512_maskz_expand_epi32(exp_msk, res);
158         __m256i tmp256;
159         tmp256 = _mm512_extracti32x8_epi32(res, 1);
160         tmp2 = _mm512_maskz_expand_epi32(exp_msk,
161                 _mm512_castsi256_si512(tmp256));
162         _mm512_storeu_si512(next_hops, tmp);
163         _mm512_storeu_si512(next_hops + 8, tmp2);
164 }
165
166 static void
167 trie_vec_lookup_x8_8b(void *p, uint8_t ips[8][RTE_FIB6_IPV6_ADDR_SIZE],
168         uint64_t *next_hops)
169 {
170         struct rte_trie_tbl *dp = (struct rte_trie_tbl *)p;
171         const __m512i zero = _mm512_set1_epi32(0);
172         const __m512i lsb = _mm512_set1_epi32(1);
173         const __m512i three_lsb = _mm512_set1_epi32(7);
174         __m512i first, second; /*< IPv6 eight byte chunks */
175         __m512i idxes, res, shuf_idxes;
176         __m512i tmp, bytes, byte_chunk, base_idxes;
177         const __rte_x86_zmm_t bswap = {
178                 .u8 = { 2, 1, 0, 255, 255, 255, 255, 255,
179                         10, 9, 8, 255, 255, 255, 255, 255,
180                         2, 1, 0, 255, 255, 255, 255, 255,
181                         10, 9, 8, 255, 255, 255, 255, 255,
182                         2, 1, 0, 255, 255, 255, 255, 255,
183                         10, 9, 8, 255, 255, 255, 255, 255,
184                         2, 1, 0, 255, 255, 255, 255, 255,
185                         10, 9, 8, 255, 255, 255, 255, 255
186                         },
187         };
188         const __mmask64 k = 0x101010101010101;
189         int i = 3;
190         __mmask8 msk_ext, new_msk;
191
192         transpose_x8(ips, &first, &second);
193
194         /* get_tbl24_idx() for every 4 byte chunk */
195         idxes = _mm512_shuffle_epi8(first, bswap.z);
196
197         /* lookup in tbl24 */
198         res = _mm512_i64gather_epi64(idxes, (const void *)dp->tbl24, 8);
199         /* get extended entries indexes */
200         msk_ext = _mm512_test_epi64_mask(res, lsb);
201
202         tmp = _mm512_srli_epi64(res, 1);
203
204         /* idxes to retrieve bytes */
205         shuf_idxes = _mm512_setr_epi64(3, 11, 19, 27, 35, 43, 51, 59);
206
207         base_idxes = _mm512_setr_epi64(0, 8, 16, 24, 32, 40, 48, 56);
208
209         /* traverse down the trie */
210         while (msk_ext) {
211                 idxes = _mm512_maskz_slli_epi64(msk_ext, tmp, 8);
212                 byte_chunk = (i < 8) ? first : second;
213                 bytes = _mm512_maskz_shuffle_epi8(k, byte_chunk, shuf_idxes);
214                 idxes = _mm512_maskz_add_epi64(msk_ext, idxes, bytes);
215                 tmp = _mm512_mask_i64gather_epi64(zero, msk_ext,
216                                 idxes, (const void *)dp->tbl8, 8);
217                 new_msk = _mm512_test_epi64_mask(tmp, lsb);
218                 res = _mm512_mask_blend_epi64(msk_ext ^ new_msk, res, tmp);
219                 tmp = _mm512_srli_epi64(tmp, 1);
220                 msk_ext = new_msk;
221
222                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, lsb);
223                 shuf_idxes = _mm512_and_epi64(shuf_idxes, three_lsb);
224                 shuf_idxes = _mm512_maskz_add_epi8(k, shuf_idxes, base_idxes);
225                 i++;
226         }
227
228         res = _mm512_srli_epi64(res, 1);
229         _mm512_storeu_si512(next_hops, res);
230 }
231
232 void
233 rte_trie_vec_lookup_bulk_2b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
234         uint64_t *next_hops, const unsigned int n)
235 {
236         uint32_t i;
237         for (i = 0; i < (n / 16); i++) {
238                 trie_vec_lookup_x16(p, (uint8_t (*)[16])&ips[i * 16][0],
239                                 next_hops + i * 16, sizeof(uint16_t));
240         }
241         rte_trie_lookup_bulk_2b(p, (uint8_t (*)[16])&ips[i * 16][0],
242                         next_hops + i * 16, n - i * 16);
243 }
244
245 void
246 rte_trie_vec_lookup_bulk_4b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
247         uint64_t *next_hops, const unsigned int n)
248 {
249         uint32_t i;
250         for (i = 0; i < (n / 16); i++) {
251                 trie_vec_lookup_x16(p, (uint8_t (*)[16])&ips[i * 16][0],
252                                 next_hops + i * 16, sizeof(uint32_t));
253         }
254         rte_trie_lookup_bulk_4b(p, (uint8_t (*)[16])&ips[i * 16][0],
255                         next_hops + i * 16, n - i * 16);
256 }
257
258 void
259 rte_trie_vec_lookup_bulk_8b(void *p, uint8_t ips[][RTE_FIB6_IPV6_ADDR_SIZE],
260         uint64_t *next_hops, const unsigned int n)
261 {
262         uint32_t i;
263         for (i = 0; i < (n / 8); i++) {
264                 trie_vec_lookup_x8_8b(p, (uint8_t (*)[16])&ips[i * 8][0],
265                                 next_hops + i * 8);
266         }
267         rte_trie_lookup_bulk_8b(p, (uint8_t (*)[16])&ips[i * 8][0],
268                         next_hops + i * 8, n - i * 8);
269 }