/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

namespace csrblocksparse {

// MaskedSparseMatrix serves two purposes:
// 1) It is useful as a reference implementation of SpMV for correctness
//    checking the much more complicated implementations in
//    CSRBlockSparseMatrix.
// 2) It is the format in which sparse matrices are represented after pruning
//    in TF. This class provides a bridge for getting those parameters into a
//    compressed form suitable for computation and serialization.
//
//   MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf,
//                                    values_from_tf);
//   CSRBlockSparseMatrix<float, float> csr_matrix(matrix);
//   csr_matrix.Multiply(rhs, bias, &out);
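//
// A slightly fuller correctness-check sketch, assuming float weights
// throughout; the CacheAlignedVector size constructor and the
// CSRBlockSparseMatrix template arguments used below are illustrative
// assumptions, not guarantees made by this header:
//
//   MaskedSparseMatrix<float> masked(rows, cols, mask_from_tf,
//                                    values_from_tf);
//   CacheAlignedVector<float> rhs(cols), bias(rows);       // assumed ctor.
//   CacheAlignedVector<float> reference(rows), out(rows);
//   masked.SpMM_bias(rhs, bias, &reference);  // reference result.
//   CSRBlockSparseMatrix<float, float> csr_matrix(masked);
//   csr_matrix.Multiply(rhs, bias, &out);
//
// |out| should then match |reference| to within rounding tolerance; if
// csr_matrix stores lower-precision weights, CastWeights can be applied to
// |masked| first so the comparison tolerance can be tight.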
template <typename T>
class MaskedSparseMatrix {
 public:
  MaskedSparseMatrix() {}

  // Construct a MaskedSparseMatrix of the given size, sparsity and block size.
  // This is mainly useful for testing.
  MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1,
                     int block_width = 1, float constant = 1.f,
                     bool random = true)
      : rows_(rows), cols_(cols), sparsity_(sparsity) {
    CHECK_EQ(rows % block_height, 0);
    CHECK_EQ(cols % block_width, 0);
    init(sparsity, block_height, block_width, constant, random);
  }

  // Construct from an existing mask and values (most likely from a TF model).
  template <typename MaskType>
  MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values)
      : rows_(rows), cols_(cols) {
    mask_.resize(rows * cols);
    values_.resize(rows * cols);
    std::copy_n(mask, rows * cols, mask_.begin());
    std::copy_n(values, rows * cols, values_.begin());
    sparsity_ =
        1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size();
  }

  const std::vector<int>& mask() const { return mask_; }
  const std::vector<T>& values() const { return values_; }
  T* data() { return values_.data(); }
  const T* data() const { return values_.data(); }

  int rows() const { return rows_; }
  int cols() const { return cols_; }
  float sparsity() const { return sparsity_; }

  void Print() const {
    absl::PrintF("-------Values---------\n");
    for (int r = 0; r < rows_; ++r) {
      for (int c = 0; c < cols_; ++c) {
        absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c]));
      }
      absl::PrintF("\n");
    }
    absl::PrintF("-------Mask---------\n");
    for (int r = 0; r < rows_; ++r) {
      for (int c = 0; c < cols_; ++c) {
        printf("%2d ", mask_[r * cols_ + c]);
      }
      absl::PrintF("\n");
    }
  }

  // This routine is useful for rounding the possibly higher precision values
  // stored in this class to a lower precision, so that correctness checks
  // between this class and CSRBlockSparseMatrix can have a tighter tolerance.
  template <typename U>
  void CastWeights() {
    for (int i = 0; i < values_.size(); ++i) {
      values_[i] = static_cast<T>(U(values_[i]));
    }
  }

  // Only meant for correctness checking.
  // RhsClassType is meant to be either CacheAlignedVector or
  // FatCacheAlignedVector.
  // The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR.
  // |bias| is broadcast if |rhs| has more than one column.
  template <typename RhsClassType, typename BiasType, typename OutClassType,
            typename RhsType = typename RhsClassType::value_type,
            typename OutType = typename OutClassType::value_type>
  void SpMM_bias(const RhsClassType& rhs,
                 const CacheAlignedVector<BiasType>& bias, OutClassType* out,
                 bool relu = false) {
    for (int r = 0; r < rows_; ++r) {
      for (int n = 0; n < rhs.cols(); ++n) {
        float sum = 0.f;
        const RhsType* rhs_ptr = rhs.data() + n * rhs.rows();
        OutType* out_ptr = out->data() + n * out->rows();
        const int* mask_ptr = mask_.data() + r * cols_;
        const T* value_ptr = values_.data() + r * cols_;
        for (int c = 0; c < cols_; ++c) {
          sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) *
                 static_cast<float>(rhs_ptr[c]);
        }
        out_ptr[r] = static_cast<OutType>(
            relu ? std::max(sum + static_cast<float>(bias[r]), 0.f)
                 : sum + static_cast<float>(bias[r]));
      }
    }
  }

 private:
  // Generate a random matrix with the specified sparsity.
  // Useful for testing.
  void init(float sparsity, int block_height, int block_width, float constant,
            bool random = true) {
    int reduced_rows = rows_ / block_height;
    int reduced_cols = cols_ / block_width;
    mask_.resize(rows_ * cols_, 0);
    // Fill with a non-zero value to make sure masking works.
    values_.resize(rows_ * cols_, static_cast<T>(2.f));

    std::mt19937 generator(0);
    std::uniform_real_distribution<float> dist_sparsity;
    std::uniform_real_distribution<float> dist_value(-1.f, 1.f);
    int nnz = 0;
    // Repeat until the matrix has at least one non-zero block.
    while (nnz == 0) {
      for (int r = 0; r < reduced_rows; ++r) {
        for (int c = 0; c < reduced_cols; ++c) {
          if (dist_sparsity(generator) > sparsity) {
            nnz++;
            for (int i = 0; i < block_height; ++i) {
              for (int j = 0; j < block_width; ++j) {
                mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1;
                values_[(r * block_height + i) * cols_ + block_width * c + j] =
                    static_cast<T>(random ? dist_value(generator) : constant);
              }
            }
          }
        }
      }
    }
  }

  std::vector<int> mask_;
  std::vector<T> values_;

  int rows_;
  int cols_;
  float sparsity_;
};

template <typename T>
class MaskedLinearLayer {
 public:
  MaskedLinearLayer(MaskedSparseMatrix<T>&& weights,
                    CacheAlignedVector<T>&& bias)
      : weights_(std::move(weights)), bias_(std::move(bias)) {}

  MaskedLinearLayer() {}

  template <typename U>
  void CastWeights() {
    weights_.template CastWeights<U>();
  }

  // Does Ax + b where A is a masked sparse ROW MAJOR matrix and
  // x is a COLUMN MAJOR dense vector or matrix. |bias| is a vector that is
  // broadcast if |rhs| has more than one column.
  template <typename FatVector>
  void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) {
    static_assert(std::is_same<typename FatVector::value_type, T>::value,
                  "FatVector value_type must match MaskedLinearLayer type");
    weights_.SpMM_bias(rhs, bias_, out, relu);
  }

 private:
  MaskedSparseMatrix<T> weights_;
  CacheAlignedVector<T> bias_;
};

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_