File size: 3,757 Bytes
d1a84ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

// Upper limit on the output precision (in mantissa bits) for which lookup
// tables are built; it also sizes the per-precision table caches below.
static constexpr int kMaxMantissaBits{14};

// Returns (and builds if not done yet) a static data table that implements
// tanh on fixed32 input, returning another fixed32 with the given number of
// mantissa bits (which is assumed to be less than the input mantissa bits).
//
// Entry |i| of the table holds round(tanh(x) * 2^num_mantissa_bits_out) for
// x = (i - offset) / 2^num_mantissa_bits_out, where
// offset = 2^(num_mantissa_bits_out + kNumTanhExpBits); the table has
// 2 * offset + 1 entries.
//
// NOTE that this function is intended to be used only with fixed16 outputs that
// are sign-extended to 32 bits for convenience. It returns nullptr if
// |num_mantissa_bits_out| is outside [1, kMaxMantissaBits].
//
// The table is built lazily on first request for a given precision and is
// never freed. No synchronization is performed around the lazy build.
const int32_t* TanhTable(int num_mantissa_bits_out) {
  // Reject precisions that have no slot in the cache. Values below 1 would
  // index before the start of |tanh_luts| and (if negative) shift by a
  // negative count, both of which are undefined behavior.
  if (num_mantissa_bits_out < 1 || num_mantissa_bits_out > kMaxMantissaBits) {
    return nullptr;
  }
  // Static data dynamically created and never destructed. Zero-initialized,
  // so an empty slot reads as nullptr.
  static const int32_t* tanh_luts[kMaxMantissaBits];
  const int32_t*& cached_lut = tanh_luts[num_mantissa_bits_out - 1];
  if (cached_lut == nullptr) {
    // Total bits is number each side of the binary point.
    const int tanh_lut_bits = num_mantissa_bits_out + kNumTanhExpBits;
    // Offset is the number of negative numbers represented.
    const int tanh_offset = 1 << tanh_lut_bits;
    // Size is double the offset plus one more for zero.
    const int tanh_size = tanh_offset * 2 + 1;
    // Conversion between int and float.
    const float float_factor = static_cast<float>(1 << num_mantissa_bits_out);
    // Allocate with the same element type as the pointer that is returned.
    int32_t* tanh_lut = new int32_t[tanh_size];
    // Initialize the table.
    for (int i = 0; i < tanh_size; ++i) {
      const float x = (i - tanh_offset) / float_factor;
      tanh_lut[i] = static_cast<int32_t>(std::round(tanhf(x) * float_factor));
    }
    // Publish only after the table is fully populated.
    cached_lut = tanh_lut;
  }
  return cached_lut;
}

// Returns (and builds if not done yet) a static data table that implements
// the logistic sigmoid 1 / (1 + exp(-x)) with the given number of mantissa
// bits in the output.
//
// Entry |i| of the table holds round(sigmoid(x) * 2^num_mantissa_bits_out)
// for x = (i - offset) * 2^kNumExtraSigmoidShiftBits /
// 2^num_mantissa_bits_out, where offset = 2^(num_mantissa_bits_out +
// kNumSigmoidExpBits - kNumExtraSigmoidShiftBits); the table has
// 2 * offset + 1 entries.
//
// Returns nullptr if |num_mantissa_bits_out| is outside
// [1, kMaxMantissaBits].
//
// The table is built lazily on first request for a given precision and is
// never freed. No synchronization is performed around the lazy build.
const int32_t* SigmoidTable(int num_mantissa_bits_out) {
  // Reject precisions that have no slot in the cache. Values below 1 would
  // index before the start of |sigmoid_luts| and (if negative) shift by a
  // negative count, both of which are undefined behavior.
  if (num_mantissa_bits_out < 1 || num_mantissa_bits_out > kMaxMantissaBits) {
    return nullptr;
  }
  // Static data dynamically created and never destructed. Zero-initialized,
  // so an empty slot reads as nullptr.
  static const int32_t* sigmoid_luts[kMaxMantissaBits];
  const int32_t*& cached_lut = sigmoid_luts[num_mantissa_bits_out - 1];
  if (cached_lut == nullptr) {
    // Spacing between consecutive table inputs, in output-precision units.
    constexpr int kSigmoidFactor = 1 << kNumExtraSigmoidShiftBits;
    // Total bits is number each side of the binary point, reduced by the extra
    // shift bits for the fact that the gradient never exceeds 1/4, so the
    // input can be sampled more coarsely than the output.
    const int sigmoid_lut_bits =
        num_mantissa_bits_out + kNumSigmoidExpBits - kNumExtraSigmoidShiftBits;
    // Offset is the number of negative numbers represented.
    const int sigmoid_offset = 1 << sigmoid_lut_bits;
    // Size is double the offset plus one more for zero.
    const int sigmoid_size = sigmoid_offset * 2 + 1;
    // Conversion between int and float.
    const float float_factor = static_cast<float>(1 << num_mantissa_bits_out);
    // Allocate with the same element type as the pointer that is returned.
    int32_t* sigmoid_lut = new int32_t[sigmoid_size];
    // Initialize the table.
    for (int i = 0; i < sigmoid_size; ++i) {
      const float x = ((i - sigmoid_offset) * kSigmoidFactor) / float_factor;
      const float sigmoid = 1.0f / (1.0f + expf(-x));
      sigmoid_lut[i] =
          static_cast<int32_t>(std::round(sigmoid * float_factor));
    }
    // Publish only after the table is fully populated.
    cached_lut = sigmoid_lut;
  }
  return cached_lut;
}

}  // namespace csrblocksparse