/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_

#if defined __AVX__
#include <immintrin.h>

#include <algorithm>
#include <type_traits>
// TODO(b/188702959): Remove fast_transcendentals with GRU refactor.
#include "sparse_matmul/numerics/fast_transcendentals.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/float16_types.h"
#include "sparse_matmul/numerics/type_utils.h"

namespace csrblocksparse {
namespace detail {

template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFloatTypes
    : std::integral_constant<bool, std::is_same<WeightType, float>::value &&
                                       std::is_same<RhsType, float>::value &&
                                       std::is_same<OutType, float>::value> {};

#if defined __AVX2__
// 16-bit inputs with a 32-bit output whose exponent matches the sum of the
// input exponents, OR
// 16-bit inputs with a 16-bit output, which is shifted to match its exponent.
template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFixedTypes
    : std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
                                    IsFixed16Type<RhsType>::value) &&
                                       (IsFixed32Type<OutType>::value ||
                                        IsFixed16Type<OutType>::value)> {};

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool,
          !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
              !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct IsAddableFixedTypes
    : std::integral_constant<bool, IsFixed32Type<Type>::value ||
                                       IsFixed16Type<Type>::value> {};
template <typename Type>
struct ShouldEnableGenericAdd
    : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};

#else   // No AVX2.

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct ShouldEnableGenericAdd : std::true_type {};
#endif  // __AVX2__

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
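
// For example (illustrative): ShouldEnableGenericSpMV_4x4<float, float,
// float>::value is false, because IsAllowableFloatTypes matches and the
// specialized float kernel below is used instead; any WeightType/RhsType/
// OutType combination not covered by IsAllowableFloatTypes or (with AVX2)
// IsAllowableFixedTypes falls back to the generic kernel.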

// The computational routines do NO error checking for speed.  It is assumed
// that this has been handled by CSRBlockSparseMatrix.

// In-line function to extract results from a pair of registers and store in
// memory. Note that the non-const references are registers, and are modified
// by this function!
inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
                            float** out_ptr) {
  // Horizontally add the results. We have 2 registers, |sum1| and |sum2| that
  // each contain 2 sets of 4 values that need to be added.
  sum1 = _mm256_hadd_ps(sum1, sum2);
  sum1 = _mm256_hadd_ps(sum1, sum1);
  // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|,
  // |res1|, |res3|]
  if (relu) {
    sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
  }
  // It is really hard in AVX to cross the 128 bit 'lanes' and this is the
  // *only* way to do it.
  // Get the top half of |sum1| in to bottom of |sum2|.
  sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
  // Interleave the values between the two registers.
  sum1 = _mm256_unpacklo_ps(sum1, sum2);
  // Save the lower 128 bits (4 floats).
  __m128 result = _mm256_extractf128_ps(sum1, 0);
  _mm_store_ps(*out_ptr, result);
  *out_ptr += 4;
}
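
// For reference, a scalar sketch of what Extract4Results stores (assuming, as
// in the kernels below, that each 128-bit lane of |sum1|/|sum2| holds the four
// partial sums for one output row: rows 0-1 in |sum1|, rows 2-3 in |sum2|):
//
//   inline void Extract4ResultsScalar(bool relu, const float partials[4][4],
//                                     float** out_ptr) {
//     for (int row = 0; row < 4; ++row) {
//       float total = partials[row][0] + partials[row][1] +
//                     partials[row][2] + partials[row][3];
//       (*out_ptr)[row] = relu ? std::max(total, 0.0f) : total;
//     }
//     *out_ptr += 4;
//   }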

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is a vector. Weights are stored for
// this routine by making each 4x4 block contiguous. Blocks are ordered in
// standard row-major format. Column indices are converted to deltas and then
// multiplied by sizeof(RhsType) to convert to bytes, so that the value can be
// used directly to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function.  This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Broadcast each bias to 4 lanes; the horizontal adds in Extract4Results
    // then sum it 4 times, undoing the division by 4 in the input biases.
    __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;

    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
    }
    Extract4Results(relu, sum1, sum2, &out_ptr);
  }
}
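
// Illustrative single-block usage of the float kernel (hypothetical buffers;
// real callers go through CSRBlockSparseMatrix, which builds these arrays,
// and SparseLinearLayer, which pre-scales the bias by 0.25f):
//
//   // One dense 4x4 block at columns 0-3, stored rows 0-1 then rows 2-3.
//   alignas(32) const float weights[16] = {1, 0, 0, 0, 0, 1, 0, 0,
//                                          0, 0, 1, 0, 0, 0, 0, 1};
//   const int16_t col_deltas_bytes[] = {0};  // first block starts at rhs[0]
//   const int32_t nnz_per_row[] = {1};       // one block in this block-row
//   const float rhs[4] = {1.f, 2.f, 3.f, 4.f};
//   const float bias[4] = {0.25f, 0.25f, 0.25f, 0.25f};  // 1.0f, pre-scaled
//   alignas(16) float out[4];
//   detail::SpMV_4x4<float, float, float>(
//       weights, col_deltas_bytes, nnz_per_row, rhs, bias, out,
//       /*assigned_rows=*/1, /*rows=*/4, /*cols=*/4, /*relu=*/0);
//   // out is now {2.f, 3.f, 4.f, 5.f}.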

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is a vector. b is
// broadcast across the 5 columns. Weights are stored for this routine by
// making each 4x4 block contiguous. Blocks are ordered in standard row-major
// format. Column indices are converted to deltas and then multiplied by
// sizeof(RhsType) to convert to bytes, so that the value can be used directly
// to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function.  This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;

  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;

  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will accumulate the results in 10 registers, |sum1_0| to |sum2_4|.
    // Broadcast each bias to 4 lanes; the horizontal adds in Extract4Results
    // then sum it 4 times, undoing the division by 4 in the input biases.
    __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum1_1 = sum1_0;
    __m256 sum2_1 = sum2_0;
    __m256 sum1_2 = sum1_0;
    __m256 sum2_2 = sum2_0;
    __m256 sum1_3 = sum1_0;
    __m256 sum2_3 = sum2_0;
    __m256 sum1_4 = sum1_0;
    __m256 sum2_4 = sum2_0;

    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;

      // Multiply this 4x4 block.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
      sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
      sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
      sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
      sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
      sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
      sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
      sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
      sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
    }

    Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
    Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
    Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
    Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
    Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
  }
}
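
// Layout note for the fat-vector kernel above (matching the pointer
// arithmetic on |rhs_ptrs| and |out_ptrs|): the 5 right-hand-side vectors and
// the 5 output vectors are each stored contiguously, back to back, i.e.
//
//   rhs_ptr[v * cols + c]  // element c of rhs vector v,  v in [0, 5)
//   out_ptr[v * rows + r]  // element r of output vector v
//
// so a single call produces all 5 matrix-vector products for the assigned
// block-rows.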

#ifdef __AVX2__

// In-line function to finish the computation of the result as 4x int32 in
// |sum|.
inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
  // Horizontally add the results. We have 1 register that contains results
  // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
  // cross lanes, so we end up with [0 1 0 1 2 3 2 3]
  sum = _mm256_hadd_epi32(sum, sum);
  // Permutes the middle two pairs to get the answers together.
  sum = _mm256_permute4x64_epi64(sum, 0xd8);
  if (kShiftAmount > 0) {
    // Shift right with rounding to get the right number of mantissa bits.
    __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
    sum = _mm256_add_epi32(sum, rounding);
    sum = _mm256_srai_epi32(sum, kShiftAmount);
  }
  // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|,
  // |res2|, |res3|]
  if (relu) {
    sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
  }
}
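
// Scalar equivalent of the rounding shift above, for one 32-bit lane |x| and
// kShiftAmount > 0 (a sketch; _mm256_srai_epi32 is the arithmetic shift):
//
//   int32_t RoundedShift(int32_t x, int shift) {
//     return (x + (1 << (shift - 1))) >> shift;  // round half up, then shift
//   }
//
// e.g. RoundedShift(0x1280, 8) == 0x13, whereas plain (0x1280 >> 8) == 0x12.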

// In-line function to extract the 4x int32 results from |sum| to memory.
// Non-const reference for |sum| as it is a register.
inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
                           int32_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Save the lower 128 bits (4x int32).
  __m128i result = _mm256_extractf128_si256(sum, 0);
  _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
  *out_ptr += 4;
}

// In-line function to extract the 4x int32 results from sum to 4x int16 in
// memory.
// Non-const reference for |sum| as it is a register.
inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
                           int16_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
  // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
  // bits, replicated in the next 64 bits.
  sum = _mm256_packs_epi32(sum, sum);
  // Save 4x int 16 from the bottom 64 bits.
  *reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
  *out_ptr += 4;
}

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is a vector. Weights are stored for
// this routine by making each 4x4 block contiguous. Blocks are ordered in
// standard row-major format. Column indices are converted to deltas and then
// multiplied by sizeof(RhsType) (2 here) to convert to bytes, so that the
// value can be used directly to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function.  This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1]
    // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
    // register outside the row loop.
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    // TODO(b/188702959): consider moving this to where the biases are computed.
    __m256i sum = _mm256_add_epi32(biases, biases);

    // TODO(b/188702959): People don't like the old-fashioned, close-to-the-
    // metal notation of *|nnz_per_row|++, so measure the effect of putting the
    // increment in the for loop.
    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of rhs_64.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16 lanes of 16x16=32 bit multiplies and
      // horizontally adds adjacent pairs to make 8x32 bit results. Add these
      // to the sum.
      sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum,
                     reinterpret_cast<int32_t**>(&out_ptr));
    } else {
      Extract4xint16(relu, kShiftAmount, sum,
                     reinterpret_cast<int16_t**>(&out_ptr));
    }
  }
}
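
// Worked example for the shift above (hypothetical mantissa-bit counts; the
// real values come from the fixed16/fixed32 template parameters in
// fixed_types.h and type_utils.h):
//
//   Suppose the 16x16 product type carries 24 mantissa bits and OutType is a
//   fixed16 with 12 mantissa bits. Then kShiftAmount = 24 - 12 = 12 and each
//   accumulated 32-bit sum is stored as
//     out = saturate_int16((sum + (1 << 11)) >> 12)   // Extract4xint16
//   while a fixed32 OutType with 12 mantissa bits skips the saturation:
//     out = (sum + (1 << 11)) >> 12                   // Extract4xint32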

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is a vector. b is
// broadcast across the 5 columns. Weights are stored for this routine by
// making each 4x4 block contiguous. Blocks are ordered in standard row-major
// format. Column indices are converted to deltas and then multiplied by
// sizeof(RhsType) (2 here) to convert to bytes, so that the value can be used
// directly to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function.  This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, which leads to a
// small speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;

  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;

  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will accumulate the results in 5 registers, |sum_0| to |sum_4|.
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1]
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    __m256i sum_0 = _mm256_add_epi32(biases, biases);
    __m256i sum_1 = sum_0;
    __m256i sum_2 = sum_0;
    __m256i sum_3 = sum_0;
    __m256i sum_4 = sum_0;

    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of |rhs_64|.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16 lanes of 16x16=32 bit multiplies and
      // horizontally adds adjacent pairs to make 8x32 bit results. Add these
      // to the sum.
      sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int32_t**>(&out_ptrs[0]));
      Extract4xint32(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int32_t**>(&out_ptrs[1]));
      Extract4xint32(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int32_t**>(&out_ptrs[2]));
      Extract4xint32(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int32_t**>(&out_ptrs[3]));
      Extract4xint32(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int32_t**>(&out_ptrs[4]));
    } else {
      Extract4xint16(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int16_t**>(&out_ptrs[0]));
      Extract4xint16(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int16_t**>(&out_ptrs[1]));
      Extract4xint16(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int16_t**>(&out_ptrs[2]));
      Extract4xint16(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int16_t**>(&out_ptrs[3]));
      Extract4xint16(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int16_t**>(&out_ptrs[4]));
    }
  }
}

// Processes one GRU gate input with sigmoid.
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
                              const __m256i& input,
                              const int32_t* sigmoid_table) {
  __m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
  if (SplitGates) {
    __m256i other =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
    gate = _mm256_add_epi32(gate, other);
  }
  gate = _mm256_add_epi32(gate, input);
  // Compute sigmoids on reset and update.
  return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
                                                 StateMantissaBits>(
      sigmoid_table, gate);
}

// Processes the tanh and the final combination, returning the new GRU state.
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
                            const __m256i& update,
                            const __m256i& rounding_offset,
                            const void* gate_ptr, const void* gate_other_ptr,
                            const void* gru_h_ptr, const int32_t* tanh_table) {
  // Multiply the cell GRU output and the reset. There is a slight danger of
  // loss of precision here, so use 32x32=64 bit and shift back after.
  __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
  if (SplitGates) {
    __m256i other_gru =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
    gru = _mm256_add_epi32(gru, other_gru);
  }
  // This only computes the products of the low-order 32 bits of each pair.
  __m256i gru_lo = _mm256_mul_epi32(gru, reset);
  // Swap odd and even 32-bit units and do it again to get the high products.
  gru = _mm256_shuffle_epi32(gru, 0xb1);
  __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
  // Now shift right to compensate for the multiply and re-interleave the
  // 32-bit results.
  // NOTE: There is no shift right arithmetic for 64 bit values until AVX512!
  // Fortunately it doesn't matter, as the results are being truncated to 32
  // bits and we aren't shifting right by more than 32 bits here.
  gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
  // The upper results are shifted LEFT, so we can use blend to recombine in
  // a single instruction.
  gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
  // Recombine the 32 bit results from lo and hi, alternating.
  gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
  gru = _mm256_add_epi32(cell, gru);
  // Compute tanh on the result. Although this instantly discards a bunch of
  // bits, there were only 7 surplus bits for the multiply, which isn't enough
  // to do it as 16x16=32.
  __m256i hbar =
      csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
                                           StateMantissaBits>(tanh_table, gru);
  // Load the 16-bit previous GRU state and sign-extend to 32 bits.
  gru = _mm256_cvtepi16_epi32(
      _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
  gru = _mm256_sub_epi32(gru, hbar);
  // Since |gru| is 16 bit sign-extended to 32, and |update|, as the output of
  // sigmoid, is always contained within 16 bits and never negative, we can
  // use |madd_epi16| to do a 16x16=32 multiply with horizontal adding: the
  // addend will always be zero, and this is twice as fast as a full blown
  // 32x32=32 multiply. The only possible problem is if the subtract above
  // caused overflow.
  gru = _mm256_madd_epi16(gru, update);
  // Renormalize to fixed16. This time rounding is critical, as this is the
  // output GRU state.
  gru = _mm256_add_epi32(gru, rounding_offset);
  gru = _mm256_srai_epi32(gru, StateMantissaBits);
  return _mm256_add_epi32(gru, hbar);
}
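
// Per-lane scalar sketch of GRUGateState above (illustrative; |fixed_tanh|
// stands for the table-based fixed32_tanh_fixed16 call, shifts are arithmetic
// as in the intrinsics, and the SIMD code does the final multiply as 16x16,
// relying on |h_prev - hbar| fitting in 16 bits):
//
//   int32_t g = gate[i];
//   if (SplitGates) g += gate_other[i];
//   g = int32_t((int64_t{g} * reset[i]) >> StateMantissaBits);
//   int32_t hbar   = fixed_tanh(cell[i] + g);
//   int32_t h_prev = gru_h[i];  // int16, sign-extended to 32 bits
//   int32_t h_new  = hbar + (((h_prev - hbar) * update[i] + rounding_offset)
//                            >> StateMantissaBits);
//   // i.e. h_new = update * h_prev + (1 - update) * hbar in fixed point,
//   // the standard GRU state update.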

template <typename Type>
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 8;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi32(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}

template <typename Type>
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 16;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi16(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}
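
// Usage sketch for SumVectors (hedged): both overloads assume 32-byte-aligned
// buffers (required by _mm256_load_si256/_mm256_store_si256) and a
// [start, end) range whose length is a multiple of kSIMDWidth, or buffers
// padded up to one. For a hypothetical fixed32 alias Fixed:
//
//   alignas(32) Fixed a[8], b[8], sum[8];
//   // ... fill a and b ...
//   SumVectors(/*start=*/0, /*end=*/8, a, b, sum);  // sum[i] = a[i] + b[i]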

#endif  // __AVX2__

}  // namespace detail
}  // namespace csrblocksparse

#undef LABEL_COL_LOOP
#undef LABEL_ROW_LOOP
#undef LABEL_SKIP_COL_LOOP
#undef LABEL_TOP_LOOP

#endif  // __AVX__

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_