Fangjun Kuang
Committed by GitHub

Fix computing features for CED audio tagging models. (#1341)

See also
https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
... ... @@ -8,6 +8,7 @@
#include <cassert>
#include <cmath>
#include <iomanip>
#include <limits>
#include <utility>
#include "kaldi-native-fbank/csrc/online-feature.h"
... ... @@ -110,7 +111,7 @@ class OfflineStream::Impl {
config_.sampling_rate = opts_.frame_opts.samp_freq;
}
explicit Impl(CEDTag /*tag*/) {
explicit Impl(CEDTag /*tag*/) : is_ced_(true) {
// see
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
... ... @@ -123,7 +124,9 @@ class OfflineStream::Impl {
opts_.frame_opts.samp_freq = 16000; // fixed to 16000
opts_.mel_opts.num_bins = 64;
opts_.mel_opts.low_freq = 0;
opts_.mel_opts.high_freq = 8000;
opts_.use_log_fbank = false;
config_.sampling_rate = opts_.frame_opts.samp_freq;
... ... @@ -216,6 +219,10 @@ class OfflineStream::Impl {
NemoNormalizeFeatures(features.data(), n, feature_dim);
if (is_ced_) {
AmplitudeToDB(features.data(), features.size());
}
return features;
}
... ... @@ -226,6 +233,32 @@ class OfflineStream::Impl {
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
private:
// see
// https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359
void AmplitudeToDB(float *p, int32_t n) const {
float multiplier = 10;
float top_db = 120;
float amin = 1e-10;
float max_x = std::numeric_limits<float>::min();
for (int32_t i = 0; i != n; ++i) {
float x = p[i];
x = (x > amin) ? x : amin;
x = std::log10f(x) * multiplier;
max_x = (x > max_x) ? x : max_x;
p[i] = x;
}
float d = max_x - top_db;
for (int32_t i = 0; i != n; ++i) {
float x = p[i];
x = (x > d) ? x : d;
p[i] = x;
}
}
void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const {
if (config_.nemo_normalize_type.empty()) {
... ... @@ -266,6 +299,7 @@ class OfflineStream::Impl {
knf::MfccOptions mfcc_opts_;
OfflineRecognitionResult r_;
ContextGraphPtr context_graph_;
bool is_ced_ = false;
};
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
... ...