Committed by
GitHub
Fix computing features for CED audio tagging models. (#1341)
See also https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
正在显示
1 个修改的文件
包含
35 行增加
和
1 行删除
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <cassert> | 8 | #include <cassert> |
| 9 | #include <cmath> | 9 | #include <cmath> |
| 10 | #include <iomanip> | 10 | #include <iomanip> |
| 11 | +#include <limits> | ||
| 11 | #include <utility> | 12 | #include <utility> |
| 12 | 13 | ||
| 13 | #include "kaldi-native-fbank/csrc/online-feature.h" | 14 | #include "kaldi-native-fbank/csrc/online-feature.h" |
| @@ -110,7 +111,7 @@ class OfflineStream::Impl { | @@ -110,7 +111,7 @@ class OfflineStream::Impl { | ||
| 110 | config_.sampling_rate = opts_.frame_opts.samp_freq; | 111 | config_.sampling_rate = opts_.frame_opts.samp_freq; |
| 111 | } | 112 | } |
| 112 | 113 | ||
| 113 | - explicit Impl(CEDTag /*tag*/) { | 114 | + explicit Impl(CEDTag /*tag*/) : is_ced_(true) { |
| 114 | // see | 115 | // see |
| 115 | // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py | 116 | // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py |
| 116 | 117 | ||
| @@ -123,7 +124,9 @@ class OfflineStream::Impl { | @@ -123,7 +124,9 @@ class OfflineStream::Impl { | ||
| 123 | 124 | ||
| 124 | opts_.frame_opts.samp_freq = 16000; // fixed to 16000 | 125 | opts_.frame_opts.samp_freq = 16000; // fixed to 16000 |
| 125 | opts_.mel_opts.num_bins = 64; | 126 | opts_.mel_opts.num_bins = 64; |
| 127 | + opts_.mel_opts.low_freq = 0; | ||
| 126 | opts_.mel_opts.high_freq = 8000; | 128 | opts_.mel_opts.high_freq = 8000; |
| 129 | + opts_.use_log_fbank = false; | ||
| 127 | 130 | ||
| 128 | config_.sampling_rate = opts_.frame_opts.samp_freq; | 131 | config_.sampling_rate = opts_.frame_opts.samp_freq; |
| 129 | 132 | ||
| @@ -216,6 +219,10 @@ class OfflineStream::Impl { | @@ -216,6 +219,10 @@ class OfflineStream::Impl { | ||
| 216 | 219 | ||
| 217 | NemoNormalizeFeatures(features.data(), n, feature_dim); | 220 | NemoNormalizeFeatures(features.data(), n, feature_dim); |
| 218 | 221 | ||
| 222 | + if (is_ced_) { | ||
| 223 | + AmplitudeToDB(features.data(), features.size()); | ||
| 224 | + } | ||
| 225 | + | ||
| 219 | return features; | 226 | return features; |
| 220 | } | 227 | } |
| 221 | 228 | ||
| @@ -226,6 +233,32 @@ class OfflineStream::Impl { | @@ -226,6 +233,32 @@ class OfflineStream::Impl { | ||
| 226 | const ContextGraphPtr &GetContextGraph() const { return context_graph_; } | 233 | const ContextGraphPtr &GetContextGraph() const { return context_graph_; } |
| 227 | 234 | ||
| 228 | private: | 235 | private: |
// Converts the values in p to decibels in place and limits the dynamic range.
//
// Each value is floored at amin, mapped to multiplier * log10(value), and
// then clamped from below so that no output is more than top_db below the
// largest output.
//
// see
// https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359
//
// @param p Pointer to n contiguous floats; overwritten with the dB values.
// @param n Number of elements in p. Nothing is done when n <= 0 or p is null.
static void AmplitudeToDB(float *p, int32_t n) {
  if (p == nullptr || n <= 0) {
    return;
  }

  const float multiplier = 10.0f;
  const float top_db = 120.0f;
  const float amin = 1e-10f;

  // Use lowest() (most negative finite float) as the seed of the running
  // maximum. numeric_limits<float>::min() is the smallest *positive* normal
  // value, which is a wrong seed whenever every dB value is negative.
  float max_db = std::numeric_limits<float>::lowest();

  for (int32_t i = 0; i != n; ++i) {
    // The comparison is false for NaN, so NaN inputs are replaced by amin.
    const float floored = (p[i] > amin) ? p[i] : amin;

    // std::log10 has a float overload; std::log10f is not declared in
    // namespace std by every standard library.
    const float db = std::log10(floored) * multiplier;

    max_db = (db > max_db) ? db : max_db;
    p[i] = db;
  }

  // Nothing may end up more than top_db below the loudest value.
  const float floor_db = max_db - top_db;
  for (int32_t i = 0; i != n; ++i) {
    p[i] = (p[i] > floor_db) ? p[i] : floor_db;
  }
}
| 261 | + | ||
| 229 | void NemoNormalizeFeatures(float *p, int32_t num_frames, | 262 | void NemoNormalizeFeatures(float *p, int32_t num_frames, |
| 230 | int32_t feature_dim) const { | 263 | int32_t feature_dim) const { |
| 231 | if (config_.nemo_normalize_type.empty()) { | 264 | if (config_.nemo_normalize_type.empty()) { |
| @@ -266,6 +299,7 @@ class OfflineStream::Impl { | @@ -266,6 +299,7 @@ class OfflineStream::Impl { | ||
| 266 | knf::MfccOptions mfcc_opts_; | 299 | knf::MfccOptions mfcc_opts_; |
| 267 | OfflineRecognitionResult r_; | 300 | OfflineRecognitionResult r_; |
| 268 | ContextGraphPtr context_graph_; | 301 | ContextGraphPtr context_graph_; |
| 302 | + bool is_ced_ = false; | ||
| 269 | }; | 303 | }; |
| 270 | 304 | ||
| 271 | OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, | 305 | OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, |
-
请 注册 或 登录 后发表评论