NTT123 commited on
Commit
0535752
1 Parent(s): 5ec3478

back to normal sampling method.

Browse files
Files changed (2) hide show
  1. wavegru_mod.cc +12 -12
  2. wavegru_mod.so +1 -1
wavegru_mod.cc CHANGED
@@ -122,18 +122,18 @@ struct WaveGRU {
122
  }
123
  o1.SpMM_bias(h, o1b, &fco1, true);
124
  o2.SpMM_bias(fco1, o2b, &fco2, false);
125
- auto max_logit = fco2[0];
126
- for (int i = 1; i <= 255; ++i) {
127
- max_logit = max(max_logit, fco2[i]);
128
- }
129
- float total = 0.0;
130
- for (int i = 0; i <= 255; ++i) {
131
- logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit);
132
- total += logits[i];
133
- }
134
- for (int i = 0; i <= 255; ++i) {
135
- if (logits[i] < total / 256.0) fco2[i] = -1e9;
136
- }
137
  value = fco2.Sample(temperature);
138
  signal[index] = value;
139
  }
 
122
  }
123
  o1.SpMM_bias(h, o1b, &fco1, true);
124
  o2.SpMM_bias(fco1, o2b, &fco2, false);
125
+ // auto max_logit = fco2[0];
126
+ // for (int i = 1; i <= 255; ++i) {
127
+ // max_logit = max(max_logit, fco2[i]);
128
+ // }
129
+ // float total = 0.0;
130
+ // for (int i = 0; i <= 255; ++i) {
131
+ // logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit);
132
+ // total += logits[i];
133
+ // }
134
+ // for (int i = 0; i <= 255; ++i) {
135
+ // if (logits[i] < total / 256.0) fco2[i] = -1e9;
136
+ // }
137
  value = fco2.Sample(temperature);
138
  signal[index] = value;
139
  }
wavegru_mod.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b65c5312f24f8ab9cfa51e8340a24ac1165b247046a331386d636fba9036c19c
3
  size 525536
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:700f2cade76db615b1e38bddfc9c604ff1c8ea1af3e507f879d0ceebae5d232d
3
  size 525536