NTT123
add fast cpp wavegru
d1a84ee
raw
history blame
No virus
8.4 kB
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
#include <cstdint>
#include <vector>
#include "absl/time/time.h"
#include "sparse_matmul/compute/matmul_fixed_avx2.h"
#include "sparse_matmul/compute/matmul_generic.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
#include <cpuid.h>
#endif
namespace csrblocksparse {
// Number of elements in a block.
// NOTE(review): the 4x4 / 8x4 kernels below suggest this is a per-dimension
// block size rather than a total element count — confirm against callers.
constexpr int kBlockSize{4};
// Base class for Matmul containing the members that are not type-specific:
// the run-time CPU feature flags used to pick a kernel implementation.
class MatmulBase {
 public:
  // Constructor initializes the flags that determine which implementation to
  // use at run-time, constrained by both compiler flags and cpuid.
  MatmulBase() {
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
    // Code tested to work on Linux systems and multiple Android emulators.
    unsigned int eax, ebx, ecx, edx;
    // CPUID leaf 1: basic feature flags; ECX carries the AVX bit.
    if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
      using_avx_ = (ecx & bit_AVX) != 0;
      if (using_avx_) {
        // CPUID leaf 7, subleaf 0: extended feature flags (AVX2 / AVX-512 in
        // EBX). __get_cpuid_count returns 0 without touching the registers
        // when leaf 7 is unsupported; without this check EBX would still hold
        // leaf-1 data and its bits would be misread as feature flags.
        if (__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx) != 0) {
          using_avx2_ = (ebx & bit_AVX2) != 0;
          using_avx512_ = (ebx & bit_AVX512F) != 0 &&
                          (ebx & bit_AVX512DQ) != 0 &&
                          (ebx & bit_AVX512BW) != 0;
        }
        // NOTE(review): this relies on the CPUID AVX bit alone and does not
        // check OSXSAVE/XGETBV for OS support of the extended register state.
        VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
      } else {
        LOG(ERROR) << "AVX not found at all!";
      }
    }
#else
    // NOTE(review): this branch is taken on every non-x86 target, not only
    // aarch64.
    using_aarch64_ = true;
#endif
  }

 protected:
  // Flags that define what (runtime) architectures are available. Flags that
  // are set are limited by both the compiler flags and runtime environment.
  bool using_avx512_ = false;
  bool using_avx2_ = false;
  bool using_avx_ = false;
  bool using_aarch64_ = false;
};
// The master template is really a catch-all for the unimplmented cases to
// report an error.
template <typename WeightType, typename RhsType>
class Matmul : public MatmulBase {
public:
// Sparse inputs, outputs replicated strided for each thread.
template <typename OutType>
void MatVec4x4(const WeightType* weights, const RhsType* rhs,
const typename TypeOfProduct<WeightType, RhsType>::type* bias,
const int32_t* nnz_per_row, const int16_t* rhs_indices,
int start_row, int end_row, bool relu, int replicas,
int stride, OutType* output) {
// The specializations should take care of every real case.
CHECK(false) << "Unsupported combination of types used!";
}
template <typename OutType>
void MatVec8x4(const WeightType* weights, const RhsType* rhs,
const typename TypeOfProduct<WeightType, RhsType>::type* bias,
const int32_t* nnz_per_row, const int16_t* rhs_indices,
int start_row, int end_row, bool relu, int replicas,
int stride, OutType* output) {
// The specializations should take care of every real case.
CHECK(false) << "Unsupported combination of types used!";
}
};
// Full specialization for float. Both block shapes delegate to the generic
// float kernel and differ only in the block height they pass to it.
template <>
class Matmul<float, float> : public MatmulBase {
 public:
  // 4x4-block kernel via detail::MatVecFloatGeneric.
  void MatVec4x4(const float* weights, const float* rhs, const float* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, float* output) {
    MatVecGeneric(/*block_height=*/4, weights, rhs, bias, nnz_per_row,
                  rhs_indices, start_row, end_row, relu, replicas, stride,
                  output);
  }

  // 8x4-block kernel via detail::MatVecFloatGeneric.
  void MatVec8x4(const float* weights, const float* rhs, const float* bias,
                 const int32_t* nnz_per_row, const int16_t* rhs_indices,
                 int start_row, int end_row, bool relu, int replicas,
                 int stride, float* output) {
    MatVecGeneric(/*block_height=*/8, weights, rhs, bias, nnz_per_row,
                  rhs_indices, start_row, end_row, relu, replicas, stride,
                  output);
  }

 private:
  // Forwards all arguments to the generic float kernel with a block width of
  // 4 and the caller-selected block height.
  static void MatVecGeneric(int block_height, const float* weights,
                            const float* rhs, const float* bias,
                            const int32_t* nnz_per_row,
                            const int16_t* rhs_indices, int start_row,
                            int end_row, bool relu, int replicas, int stride,
                            float* output) {
    detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, block_height,
                               /*block_width=*/4, relu, replicas, stride,
                               output);
  }
};
// Partial specialization for fixed-point types. Covers
// fixed16 x fixed16 = OutType, where OutType should be fixed16 or fixed32.
// The mantissa bits of the two inputs don't have to match.
template <int WeightBits, int RhsBits>
class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
 public:
  using WeightType = fixed16<WeightBits>;
  using RhsType = fixed16<RhsBits>;

  // 4x4-block kernel. Selected at compile time: the AVX2 kernel when built
  // with __AVX2__, a fatal "not implemented" on aarch64, otherwise the generic
  // kernel. The kernel receives the shift amount computed by ShiftAmount.
  template <typename OutType>
  void MatVec4x4(const int16_t* weights, const int16_t* rhs,
                 const int32_t* bias, const int32_t* nnz_per_row,
                 const int16_t* rhs_indices, int start_row, int end_row,
                 bool relu, int replicas, int stride, OutType* output) {
    constexpr int kShiftAmount = ShiftAmount<OutType>();
    static_assert(kShiftAmount >= 0,
                  "OutType must not have more mantissa bits than inputs");
#if defined(__AVX2__)
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    // Outputs that are not 4 bytes wide are written through an int16_t view;
    // 4-byte outputs through an int32_t view.
    if (sizeof(OutType) != 4) {
      detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                                 start_row, end_row, relu, kShiftAmount,
                                 replicas, stride,
                                 reinterpret_cast<int16_t*>(output));
    } else {
      detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                                 start_row, end_row, relu, kShiftAmount,
                                 replicas, stride,
                                 reinterpret_cast<int32_t*>(output));
    }
#elif defined(__aarch64__)
    // NOTE(review): if using_aarch64_ were false this would return without
    // writing `output`; the MatmulBase constructor sets it on non-x86 builds.
    if (using_aarch64_) {
      LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
    }
#else
    detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/4,
                               /*block_width=*/4, relu, sizeof(OutType),
                               kShiftAmount, replicas, stride, output);
#endif  // __AVX2__
  }

  // 8x4-block kernel. Same compile-time selection as MatVec4x4; the AVX2 path
  // only supports replicas == 1 with 4-byte (fixed32) outputs.
  template <typename OutType>
  void MatVec8x4(const int16_t* weights, const int16_t* rhs,
                 const int32_t* bias, const int32_t* nnz_per_row,
                 const int16_t* rhs_indices, int start_row, int end_row,
                 bool relu, int replicas, int stride, OutType* output) {
    constexpr int kShiftAmount = ShiftAmount<OutType>();
    static_assert(kShiftAmount >= 0,
                  "OutType must not have more mantissa bits than inputs");
#if defined(__AVX2__)
    CHECK(replicas == 1 && sizeof(OutType) == 4)
        << "Only replicas == 1 and fixed32 output are implemented for AVX2!";
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, relu, kShiftAmount,
                               reinterpret_cast<int32_t*>(output));
#elif defined(__aarch64__)
    // NOTE(review): same silent no-op as MatVec4x4 if using_aarch64_ is false.
    if (using_aarch64_) {
      LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
    }
#else
    detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
                               start_row, end_row, /*block_height=*/8,
                               /*block_width=*/4, relu, sizeof(OutType),
                               kShiftAmount, replicas, stride, output);
#endif  // __AVX2__
  }

 private:
  // Number of mantissa bits the weight*rhs product type has beyond OutType.
  // A member template, so it is only instantiated when a kernel is used.
  template <typename OutType>
  static constexpr int ShiftAmount() {
    return TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
           OutType::kMantissaBits;
  }
};
} // namespace csrblocksparse
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_