Fangjun Kuang
Committed by GitHub

Fix computing features for CED audio tagging models. (#1341)

See also
https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <cassert> 8 #include <cassert>
9 #include <cmath> 9 #include <cmath>
10 #include <iomanip> 10 #include <iomanip>
  11 +#include <limits>
11 #include <utility> 12 #include <utility>
12 13
13 #include "kaldi-native-fbank/csrc/online-feature.h" 14 #include "kaldi-native-fbank/csrc/online-feature.h"
@@ -110,7 +111,7 @@ class OfflineStream::Impl { @@ -110,7 +111,7 @@ class OfflineStream::Impl {
110 config_.sampling_rate = opts_.frame_opts.samp_freq; 111 config_.sampling_rate = opts_.frame_opts.samp_freq;
111 } 112 }
112 113
113 - explicit Impl(CEDTag /*tag*/) { 114 + explicit Impl(CEDTag /*tag*/) : is_ced_(true) {
114 // see 115 // see
115 // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py 116 // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
116 117
@@ -123,7 +124,9 @@ class OfflineStream::Impl { @@ -123,7 +124,9 @@ class OfflineStream::Impl {
123 124
124 opts_.frame_opts.samp_freq = 16000; // fixed to 16000 125 opts_.frame_opts.samp_freq = 16000; // fixed to 16000
125 opts_.mel_opts.num_bins = 64; 126 opts_.mel_opts.num_bins = 64;
  127 + opts_.mel_opts.low_freq = 0;
126 opts_.mel_opts.high_freq = 8000; 128 opts_.mel_opts.high_freq = 8000;
  129 + opts_.use_log_fbank = false;
127 130
128 config_.sampling_rate = opts_.frame_opts.samp_freq; 131 config_.sampling_rate = opts_.frame_opts.samp_freq;
129 132
@@ -216,6 +219,10 @@ class OfflineStream::Impl { @@ -216,6 +219,10 @@ class OfflineStream::Impl {
216 219
217 NemoNormalizeFeatures(features.data(), n, feature_dim); 220 NemoNormalizeFeatures(features.data(), n, feature_dim);
218 221
  222 + if (is_ced_) {
  223 + AmplitudeToDB(features.data(), features.size());
  224 + }
  225 +
219 return features; 226 return features;
220 } 227 }
221 228
@@ -226,6 +233,32 @@ class OfflineStream::Impl { @@ -226,6 +233,32 @@ class OfflineStream::Impl {
226 const ContextGraphPtr &GetContextGraph() const { return context_graph_; } 233 const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
227 234
228 private: 235 private:
  236 + // see
  237 + // https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359
  238 + void AmplitudeToDB(float *p, int32_t n) const {
  239 + float multiplier = 10;
  240 + float top_db = 120;
  241 + float amin = 1e-10;
  242 +
  243 + float max_x = std::numeric_limits<float>::min();
  244 +
  245 + for (int32_t i = 0; i != n; ++i) {
  246 + float x = p[i];
  247 + x = (x > amin) ? x : amin;
  248 + x = std::log10f(x) * multiplier;
  249 +
  250 + max_x = (x > max_x) ? x : max_x;
  251 + p[i] = x;
  252 + }
  253 +
  254 + float d = max_x - top_db;
  255 + for (int32_t i = 0; i != n; ++i) {
  256 + float x = p[i];
  257 + x = (x > d) ? x : d;
  258 + p[i] = x;
  259 + }
  260 + }
  261 +
229 void NemoNormalizeFeatures(float *p, int32_t num_frames, 262 void NemoNormalizeFeatures(float *p, int32_t num_frames,
230 int32_t feature_dim) const { 263 int32_t feature_dim) const {
231 if (config_.nemo_normalize_type.empty()) { 264 if (config_.nemo_normalize_type.empty()) {
@@ -266,6 +299,7 @@ class OfflineStream::Impl { @@ -266,6 +299,7 @@ class OfflineStream::Impl {
266 knf::MfccOptions mfcc_opts_; 299 knf::MfccOptions mfcc_opts_;
267 OfflineRecognitionResult r_; 300 OfflineRecognitionResult r_;
268 ContextGraphPtr context_graph_; 301 ContextGraphPtr context_graph_;
  302 + bool is_ced_ = false;
269 }; 303 };
270 304
271 OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, 305 OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,