Karel Vesely
Committed by GitHub

Configurable low_freq high_freq, dithering (#664)

1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.7.tar.gz")  
5 - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.7.tar.gz")  
6 - set(kaldi_native_fbank_HASH "SHA256=e78fd9d481d83d7d6d1be0012752e6531cb614e030558a3491e3c033cb8e0e4e") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
  5 + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
  6 + set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
7 7
8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -25,6 +25,19 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { @@ -25,6 +25,19 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
25 25
26 po->Register("feat-dim", &feature_dim, 26 po->Register("feat-dim", &feature_dim,
27 "Feature dimension. Must match the one expected by the model."); 27 "Feature dimension. Must match the one expected by the model.");
  28 +
  29 + po->Register("low-freq", &low_freq,
  30 + "Low cutoff frequency for mel bins");
  31 +
  32 + po->Register("high-freq", &high_freq,
  33 + "High cutoff frequency for mel bins "
  34 + "(if <= 0, offset from Nyquist)");
  35 +
  36 + po->Register("dither", &dither,
  37 + "Dithering constant (0.0 means no dither). "
  38 + "By default the audio samples are in range [-1,+1], "
  39 + "so 0.00003 is a good value, "
  40 + "equivalent to the default 1.0 from kaldi");
28 } 41 }
29 42
30 std::string FeatureExtractorConfig::ToString() const { 43 std::string FeatureExtractorConfig::ToString() const {
@@ -32,7 +45,10 @@ std::string FeatureExtractorConfig::ToString() const { @@ -32,7 +45,10 @@ std::string FeatureExtractorConfig::ToString() const {
32 45
33 os << "FeatureExtractorConfig("; 46 os << "FeatureExtractorConfig(";
34 os << "sampling_rate=" << sampling_rate << ", "; 47 os << "sampling_rate=" << sampling_rate << ", ";
35 - os << "feature_dim=" << feature_dim << ")"; 48 + os << "feature_dim=" << feature_dim << ", ";
  49 + os << "low_freq=" << low_freq << ", ";
  50 + os << "high_freq=" << high_freq << ", ";
  51 + os << "dither=" << dither << ")";
36 52
37 return os.str(); 53 return os.str();
38 } 54 }
@@ -40,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const { @@ -40,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const {
40 class FeatureExtractor::Impl { 56 class FeatureExtractor::Impl {
41 public: 57 public:
42 explicit Impl(const FeatureExtractorConfig &config) : config_(config) { 58 explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
43 - opts_.frame_opts.dither = 0; 59 + opts_.frame_opts.dither = config.dither;
44 opts_.frame_opts.snip_edges = config.snip_edges; 60 opts_.frame_opts.snip_edges = config.snip_edges;
45 opts_.frame_opts.samp_freq = config.sampling_rate; 61 opts_.frame_opts.samp_freq = config.sampling_rate;
46 opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; 62 opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
@@ -50,13 +66,9 @@ class FeatureExtractor::Impl { @@ -50,13 +66,9 @@ class FeatureExtractor::Impl {
50 66
51 opts_.mel_opts.num_bins = config.feature_dim; 67 opts_.mel_opts.num_bins = config.feature_dim;
52 68
53 - // Please see  
54 - // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27  
55 - // and  
56 - // https://github.com/k2-fsa/sherpa-onnx/issues/514  
57 - opts_.mel_opts.high_freq = -400; 69 + opts_.mel_opts.high_freq = config.high_freq;
  70 + opts_.mel_opts.low_freq = config.low_freq;
58 71
59 - opts_.mel_opts.low_freq = config.low_freq;  
60 opts_.mel_opts.is_librosa = config.is_librosa; 72 opts_.mel_opts.is_librosa = config.is_librosa;
61 73
62 fbank_ = std::make_unique<knf::OnlineFbank>(opts_); 74 fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
@@ -21,6 +21,27 @@ struct FeatureExtractorConfig { @@ -21,6 +21,27 @@ struct FeatureExtractorConfig {
21 // Feature dimension 21 // Feature dimension
22 int32_t feature_dim = 80; 22 int32_t feature_dim = 80;
23 23
  24 + // minimal frequency for Mel-filterbank, in Hz
  25 + float low_freq = 20.0f;
  26 +
  27 + // maximal frequency of Mel-filterbank
  28 + // in Hz; negative value is subtracted from Nyquist freq.:
  29 + // i.e. for sampling_rate 16000 / 2 - 400 = 7600Hz
  30 + //
  31 + // Please see
  32 + // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
  33 + // and
  34 + // https://github.com/k2-fsa/sherpa-onnx/issues/514
  35 + float high_freq = -400.0f;
  36 +
  37 + // dithering constant, useful for signals with hard-zeroes in non-speech parts
  38 + // this prevents large negative values in log-mel filterbanks
  39 + //
  40 + // In k2, audio samples are in range [-1..+1], in kaldi the range was
  41 + // [-32k..+32k], so the value 0.00003 is equivalent to kaldi default 1.0
  42 + //
  43 + float dither = 0.0f; // dithering disabled by default
  44 +
24 // Set internally by some models, e.g., paraformer sets it to false. 45 // Set internally by some models, e.g., paraformer sets it to false.
25 // This parameter is not exposed to users from the commandline 46 // This parameter is not exposed to users from the commandline
26 // If true, the feature extractor expects inputs to be normalized to 47 // If true, the feature extractor expects inputs to be normalized to
@@ -31,7 +52,6 @@ struct FeatureExtractorConfig { @@ -31,7 +52,6 @@ struct FeatureExtractorConfig {
31 bool snip_edges = false; 52 bool snip_edges = false;
32 float frame_shift_ms = 10.0f; // in milliseconds. 53 float frame_shift_ms = 10.0f; // in milliseconds.
33 float frame_length_ms = 25.0f; // in milliseconds. 54 float frame_length_ms = 25.0f; // in milliseconds.
34 - int32_t low_freq = 20;  
35 bool is_librosa = false; 55 bool is_librosa = false;
36 bool remove_dc_offset = true; // Subtract mean of wave before FFT. 56 bool remove_dc_offset = true; // Subtract mean of wave before FFT.
37 std::string window_type = "povey"; // e.g. Hamming window 57 std::string window_type = "povey"; // e.g. Hamming window
@@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { @@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
72 unk_id_ = sym_["<unk>"]; 72 unk_id_ = sym_["<unk>"];
73 } 73 }
74 74
  75 + model_->SetFeatureDim(config.feat_config.feature_dim);
  76 +
75 InitKeywords(); 77 InitKeywords();
76 78
77 decoder_ = std::make_unique<TransducerKeywordDecoder>( 79 decoder_ = std::make_unique<TransducerKeywordDecoder>(
@@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { @@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
89 unk_id_ = sym_["<unk>"]; 91 unk_id_ = sym_["<unk>"];
90 } 92 }
91 93
  94 + model_->SetFeatureDim(config.feat_config.feature_dim);
  95 +
92 InitKeywords(mgr); 96 InitKeywords(mgr);
93 97
94 decoder_ = std::make_unique<TransducerKeywordDecoder>( 98 decoder_ = std::make_unique<TransducerKeywordDecoder>(
@@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
90 unk_id_ = sym_["<unk>"]; 90 unk_id_ = sym_["<unk>"];
91 } 91 }
92 92
  93 + model_->SetFeatureDim(config.feat_config.feature_dim);
  94 +
93 if (config.decoding_method == "modified_beam_search") { 95 if (config.decoding_method == "modified_beam_search") {
94 if (!config_.hotwords_file.empty()) { 96 if (!config_.hotwords_file.empty()) {
95 InitHotwords(); 97 InitHotwords();
@@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
123 unk_id_ = sym_["<unk>"]; 125 unk_id_ = sym_["<unk>"];
124 } 126 }
125 127
  128 + model_->SetFeatureDim(config.feat_config.feature_dim);
  129 +
126 if (config.decoding_method == "modified_beam_search") { 130 if (config.decoding_method == "modified_beam_search") {
127 #if 0 131 #if 0
128 // TODO(fangjun): Implement it 132 // TODO(fangjun): Implement it
@@ -61,6 +61,16 @@ class OnlineTransducerModel { @@ -61,6 +61,16 @@ class OnlineTransducerModel {
61 */ 61 */
62 virtual std::vector<Ort::Value> GetEncoderInitStates() = 0; 62 virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;
63 63
  64 + /** Set feature dim.
  65 + *
  66 + * This is used in `OnlineZipformer2TransducerModel`,
  67 + * to pass `feature_dim` for `GetEncoderInitStates()`.
  68 + *
  69 + * This has to be called before GetEncoderInitStates(), so the `encoder_embed`
  70 + * init state has the correct `embed_dim` of its output.
  71 + */
  72 + virtual void SetFeatureDim(int32_t feature_dim) { }
  73 +
64 /** Run the encoder. 74 /** Run the encoder.
65 * 75 *
66 * @param features A tensor of shape (N, T, C). It is changed in-place. 76 * @param features A tensor of shape (N, T, C). It is changed in-place.
@@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() { @@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() {
403 } 403 }
404 404
405 { 405 {
406 - std::array<int64_t, 4> s{1, 128, 3, 19}; 406 + SHERPA_ONNX_CHECK_NE(feature_dim_, 0);
  407 + int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2;
  408 + std::array<int64_t, 4> s{1, 128, 3, embed_dim};
  409 +
407 auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); 410 auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
408 Fill(&v, 0); 411 Fill(&v, 0);
409 ans.push_back(std::move(v)); 412 ans.push_back(std::move(v));
@@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { @@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
37 37
38 std::vector<Ort::Value> GetEncoderInitStates() override; 38 std::vector<Ort::Value> GetEncoderInitStates() override;
39 39
  40 + void SetFeatureDim(int32_t feature_dim) override {
  41 + feature_dim_ = feature_dim;
  42 + }
  43 +
40 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 44 std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
41 Ort::Value features, std::vector<Ort::Value> states, 45 Ort::Value features, std::vector<Ort::Value> states,
42 Ort::Value processed_frames) override; 46 Ort::Value processed_frames) override;
@@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { @@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
101 105
102 int32_t context_size_ = 0; 106 int32_t context_size_ = 0;
103 int32_t vocab_size_ = 0; 107 int32_t vocab_size_ = 0;
  108 + int32_t feature_dim_ = 0;
104 }; 109 };
105 110
106 } // namespace sherpa_onnx 111 } // namespace sherpa_onnx
@@ -11,10 +11,17 @@ namespace sherpa_onnx { @@ -11,10 +11,17 @@ namespace sherpa_onnx {
11 static void PybindFeatureExtractorConfig(py::module *m) { 11 static void PybindFeatureExtractorConfig(py::module *m) {
12 using PyClass = FeatureExtractorConfig; 12 using PyClass = FeatureExtractorConfig;
13 py::class_<PyClass>(*m, "FeatureExtractorConfig") 13 py::class_<PyClass>(*m, "FeatureExtractorConfig")
14 - .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,  
15 - py::arg("feature_dim") = 80) 14 + .def(py::init<int32_t, int32_t, float, float, float>(),
  15 + py::arg("sampling_rate") = 16000,
  16 + py::arg("feature_dim") = 80,
  17 + py::arg("low_freq") = 20.0f,
  18 + py::arg("high_freq") = -400.0f,
  19 + py::arg("dither") = 0.0f)
16 .def_readwrite("sampling_rate", &PyClass::sampling_rate) 20 .def_readwrite("sampling_rate", &PyClass::sampling_rate)
17 .def_readwrite("feature_dim", &PyClass::feature_dim) 21 .def_readwrite("feature_dim", &PyClass::feature_dim)
  22 + .def_readwrite("low_freq", &PyClass::low_freq)
  23 + .def_readwrite("high_freq", &PyClass::high_freq)
  24 + .def_readwrite("dither", &PyClass::high_freq)
18 .def("__str__", &PyClass::ToString); 25 .def("__str__", &PyClass::ToString);
19 } 26 }
20 27
@@ -41,6 +41,9 @@ class OnlineRecognizer(object): @@ -41,6 +41,9 @@ class OnlineRecognizer(object):
41 num_threads: int = 2, 41 num_threads: int = 2,
42 sample_rate: float = 16000, 42 sample_rate: float = 16000,
43 feature_dim: int = 80, 43 feature_dim: int = 80,
  44 + low_freq: float = 20.0,
  45 + high_freq: float = -400.0,
  46 + dither: float = 0.0,
44 enable_endpoint_detection: bool = False, 47 enable_endpoint_detection: bool = False,
45 rule1_min_trailing_silence: float = 2.4, 48 rule1_min_trailing_silence: float = 2.4,
46 rule2_min_trailing_silence: float = 1.2, 49 rule2_min_trailing_silence: float = 1.2,
@@ -80,6 +83,16 @@ class OnlineRecognizer(object): @@ -80,6 +83,16 @@ class OnlineRecognizer(object):
80 Sample rate of the training data used to train the model. 83 Sample rate of the training data used to train the model.
81 feature_dim: 84 feature_dim:
82 Dimension of the feature used to train the model. 85 Dimension of the feature used to train the model.
  86 + low_freq:
  87 + Low cutoff frequency for mel bins in feature extraction.
  88 + high_freq:
  89 + High cutoff frequency for mel bins in feature extraction
  90 + (if <= 0, offset from Nyquist)
  91 + dither:
  92 + Dithering constant (0.0 means no dither).
  93 + By default the audio samples are in range [-1,+1],
  94 + so dithering constant 0.00003 is a good value,
  95 + equivalent to the default 1.0 from kaldi
83 enable_endpoint_detection: 96 enable_endpoint_detection:
84 True to enable endpoint detection. False to disable endpoint 97 True to enable endpoint detection. False to disable endpoint
85 detection. 98 detection.
@@ -140,6 +153,9 @@ class OnlineRecognizer(object): @@ -140,6 +153,9 @@ class OnlineRecognizer(object):
140 feat_config = FeatureExtractorConfig( 153 feat_config = FeatureExtractorConfig(
141 sampling_rate=sample_rate, 154 sampling_rate=sample_rate,
142 feature_dim=feature_dim, 155 feature_dim=feature_dim,
  156 + low_freq=low_freq,
  157 + high_freq=high_freq,
  158 + dither=dither,
143 ) 159 )
144 160
145 endpoint_config = EndpointConfig( 161 endpoint_config = EndpointConfig(