Committed by
GitHub
Configurable low_freq high_freq, dithering (#664)
正在显示
10 个修改的文件
包含
96 行增加
和
15 行删除
| 1 | function(download_kaldi_native_fbank) | 1 | function(download_kaldi_native_fbank) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.7.tar.gz") | ||
| 5 | - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.7.tar.gz") | ||
| 6 | - set(kaldi_native_fbank_HASH "SHA256=e78fd9d481d83d7d6d1be0012752e6531cb614e030558a3491e3c033cb8e0e4e") | 4 | + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz") |
| 5 | + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz") | ||
| 6 | + set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904") | ||
| 7 | 7 | ||
| 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) |
| 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| @@ -25,6 +25,19 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { | @@ -25,6 +25,19 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { | ||
| 25 | 25 | ||
| 26 | po->Register("feat-dim", &feature_dim, | 26 | po->Register("feat-dim", &feature_dim, |
| 27 | "Feature dimension. Must match the one expected by the model."); | 27 | "Feature dimension. Must match the one expected by the model."); |
| 28 | + | ||
| 29 | + po->Register("low-freq", &low_freq, | ||
| 30 | + "Low cutoff frequency for mel bins"); | ||
| 31 | + | ||
| 32 | + po->Register("high-freq", &high_freq, | ||
| 33 | + "High cutoff frequency for mel bins " | ||
| 34 | + "(if <= 0, offset from Nyquist)"); | ||
| 35 | + | ||
| 36 | + po->Register("dither", &dither, | ||
| 37 | + "Dithering constant (0.0 means no dither). " | ||
| 38 | + "By default the audio samples are in range [-1,+1], " | ||
| 39 | + "so 0.00003 is a good value, " | ||
| 40 | + "equivalent to the default 1.0 from kaldi"); | ||
| 28 | } | 41 | } |
| 29 | 42 | ||
| 30 | std::string FeatureExtractorConfig::ToString() const { | 43 | std::string FeatureExtractorConfig::ToString() const { |
| @@ -32,7 +45,10 @@ std::string FeatureExtractorConfig::ToString() const { | @@ -32,7 +45,10 @@ std::string FeatureExtractorConfig::ToString() const { | ||
| 32 | 45 | ||
| 33 | os << "FeatureExtractorConfig("; | 46 | os << "FeatureExtractorConfig("; |
| 34 | os << "sampling_rate=" << sampling_rate << ", "; | 47 | os << "sampling_rate=" << sampling_rate << ", "; |
| 35 | - os << "feature_dim=" << feature_dim << ")"; | 48 | + os << "feature_dim=" << feature_dim << ", "; |
| 49 | + os << "low_freq=" << low_freq << ", "; | ||
| 50 | + os << "high_freq=" << high_freq << ", "; | ||
| 51 | + os << "dither=" << dither << ")"; | ||
| 36 | 52 | ||
| 37 | return os.str(); | 53 | return os.str(); |
| 38 | } | 54 | } |
| @@ -40,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const { | @@ -40,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const { | ||
| 40 | class FeatureExtractor::Impl { | 56 | class FeatureExtractor::Impl { |
| 41 | public: | 57 | public: |
| 42 | explicit Impl(const FeatureExtractorConfig &config) : config_(config) { | 58 | explicit Impl(const FeatureExtractorConfig &config) : config_(config) { |
| 43 | - opts_.frame_opts.dither = 0; | 59 | + opts_.frame_opts.dither = config.dither; |
| 44 | opts_.frame_opts.snip_edges = config.snip_edges; | 60 | opts_.frame_opts.snip_edges = config.snip_edges; |
| 45 | opts_.frame_opts.samp_freq = config.sampling_rate; | 61 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| 46 | opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; | 62 | opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; |
| @@ -50,13 +66,9 @@ class FeatureExtractor::Impl { | @@ -50,13 +66,9 @@ class FeatureExtractor::Impl { | ||
| 50 | 66 | ||
| 51 | opts_.mel_opts.num_bins = config.feature_dim; | 67 | opts_.mel_opts.num_bins = config.feature_dim; |
| 52 | 68 | ||
| 53 | - // Please see | ||
| 54 | - // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27 | ||
| 55 | - // and | ||
| 56 | - // https://github.com/k2-fsa/sherpa-onnx/issues/514 | ||
| 57 | - opts_.mel_opts.high_freq = -400; | 69 | + opts_.mel_opts.high_freq = config.high_freq; |
| 70 | + opts_.mel_opts.low_freq = config.low_freq; | ||
| 58 | 71 | ||
| 59 | - opts_.mel_opts.low_freq = config.low_freq; | ||
| 60 | opts_.mel_opts.is_librosa = config.is_librosa; | 72 | opts_.mel_opts.is_librosa = config.is_librosa; |
| 61 | 73 | ||
| 62 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 74 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| @@ -21,6 +21,27 @@ struct FeatureExtractorConfig { | @@ -21,6 +21,27 @@ struct FeatureExtractorConfig { | ||
| 21 | // Feature dimension | 21 | // Feature dimension |
| 22 | int32_t feature_dim = 80; | 22 | int32_t feature_dim = 80; |
| 23 | 23 | ||
| 24 | + // minimal frequency for Mel-filterbank, in Hz | ||
| 25 | + float low_freq = 20.0f; | ||
| 26 | + | ||
| 27 | + // maximal frequency of Mel-filterbank | ||
| 28 | + // in Hz; negative value is subtracted from Nyquist freq.: | ||
| 29 | + // e.g. for sampling_rate 16000: 16000 / 2 - 400 = 7600 Hz | ||
| 30 | + // | ||
| 31 | + // Please see | ||
| 32 | + // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27 | ||
| 33 | + // and | ||
| 34 | + // https://github.com/k2-fsa/sherpa-onnx/issues/514 | ||
| 35 | + float high_freq = -400.0f; | ||
| 36 | + | ||
| 37 | + // dithering constant, useful for signals with hard-zeroes in non-speech parts | ||
| 38 | + // this prevents large negative values in log-mel filterbanks | ||
| 39 | + // | ||
| 40 | + // In k2, audio samples are in range [-1..+1], in kaldi the range was | ||
| 41 | + // [-32k..+32k], so the value 0.00003 is equivalent to kaldi default 1.0 | ||
| 42 | + // | ||
| 43 | + float dither = 0.0f; // dithering disabled by default | ||
| 44 | + | ||
| 24 | // Set internally by some models, e.g., paraformer sets it to false. | 45 | // Set internally by some models, e.g., paraformer sets it to false. |
| 25 | // This parameter is not exposed to users from the commandline | 46 | // This parameter is not exposed to users from the commandline |
| 26 | // If true, the feature extractor expects inputs to be normalized to | 47 | // If true, the feature extractor expects inputs to be normalized to |
| @@ -31,7 +52,6 @@ struct FeatureExtractorConfig { | @@ -31,7 +52,6 @@ struct FeatureExtractorConfig { | ||
| 31 | bool snip_edges = false; | 52 | bool snip_edges = false; |
| 32 | float frame_shift_ms = 10.0f; // in milliseconds. | 53 | float frame_shift_ms = 10.0f; // in milliseconds. |
| 33 | float frame_length_ms = 25.0f; // in milliseconds. | 54 | float frame_length_ms = 25.0f; // in milliseconds. |
| 34 | - int32_t low_freq = 20; | ||
| 35 | bool is_librosa = false; | 55 | bool is_librosa = false; |
| 36 | bool remove_dc_offset = true; // Subtract mean of wave before FFT. | 56 | bool remove_dc_offset = true; // Subtract mean of wave before FFT. |
| 37 | std::string window_type = "povey"; // e.g. Hamming window | 57 | std::string window_type = "povey"; // e.g. Hamming window |
| @@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | @@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 72 | unk_id_ = sym_["<unk>"]; | 72 | unk_id_ = sym_["<unk>"]; |
| 73 | } | 73 | } |
| 74 | 74 | ||
| 75 | + model_->SetFeatureDim(config.feat_config.feature_dim); | ||
| 76 | + | ||
| 75 | InitKeywords(); | 77 | InitKeywords(); |
| 76 | 78 | ||
| 77 | decoder_ = std::make_unique<TransducerKeywordDecoder>( | 79 | decoder_ = std::make_unique<TransducerKeywordDecoder>( |
| @@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | @@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { | ||
| 89 | unk_id_ = sym_["<unk>"]; | 91 | unk_id_ = sym_["<unk>"]; |
| 90 | } | 92 | } |
| 91 | 93 | ||
| 94 | + model_->SetFeatureDim(config.feat_config.feature_dim); | ||
| 95 | + | ||
| 92 | InitKeywords(mgr); | 96 | InitKeywords(mgr); |
| 93 | 97 | ||
| 94 | decoder_ = std::make_unique<TransducerKeywordDecoder>( | 98 | decoder_ = std::make_unique<TransducerKeywordDecoder>( |
| @@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 90 | unk_id_ = sym_["<unk>"]; | 90 | unk_id_ = sym_["<unk>"]; |
| 91 | } | 91 | } |
| 92 | 92 | ||
| 93 | + model_->SetFeatureDim(config.feat_config.feature_dim); | ||
| 94 | + | ||
| 93 | if (config.decoding_method == "modified_beam_search") { | 95 | if (config.decoding_method == "modified_beam_search") { |
| 94 | if (!config_.hotwords_file.empty()) { | 96 | if (!config_.hotwords_file.empty()) { |
| 95 | InitHotwords(); | 97 | InitHotwords(); |
| @@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 123 | unk_id_ = sym_["<unk>"]; | 125 | unk_id_ = sym_["<unk>"]; |
| 124 | } | 126 | } |
| 125 | 127 | ||
| 128 | + model_->SetFeatureDim(config.feat_config.feature_dim); | ||
| 129 | + | ||
| 126 | if (config.decoding_method == "modified_beam_search") { | 130 | if (config.decoding_method == "modified_beam_search") { |
| 127 | #if 0 | 131 | #if 0 |
| 128 | // TODO(fangjun): Implement it | 132 | // TODO(fangjun): Implement it |
| @@ -61,6 +61,16 @@ class OnlineTransducerModel { | @@ -61,6 +61,16 @@ class OnlineTransducerModel { | ||
| 61 | */ | 61 | */ |
| 62 | virtual std::vector<Ort::Value> GetEncoderInitStates() = 0; | 62 | virtual std::vector<Ort::Value> GetEncoderInitStates() = 0; |
| 63 | 63 | ||
| 64 | + /** Set feature dim. | ||
| 65 | + * | ||
| 66 | + * This is used in `OnlineZipformer2TransducerModel`, | ||
| 67 | + * to pass `feature_dim` for `GetEncoderInitStates()`. | ||
| 68 | + * | ||
| 69 | + * This has to be called before GetEncoderInitStates(), so the `encoder_embed` | ||
| 70 | + * init state has the correct `embed_dim` of its output. | ||
| 71 | + */ | ||
| 72 | + virtual void SetFeatureDim(int32_t feature_dim) { } | ||
| 73 | + | ||
| 64 | /** Run the encoder. | 74 | /** Run the encoder. |
| 65 | * | 75 | * |
| 66 | * @param features A tensor of shape (N, T, C). It is changed in-place. | 76 | * @param features A tensor of shape (N, T, C). It is changed in-place. |
| @@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() { | @@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() { | ||
| 403 | } | 403 | } |
| 404 | 404 | ||
| 405 | { | 405 | { |
| 406 | - std::array<int64_t, 4> s{1, 128, 3, 19}; | 406 | + SHERPA_ONNX_CHECK_NE(feature_dim_, 0); |
| 407 | + int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2; | ||
| 408 | + std::array<int64_t, 4> s{1, 128, 3, embed_dim}; | ||
| 409 | + | ||
| 407 | auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); | 410 | auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size()); |
| 408 | Fill(&v, 0); | 411 | Fill(&v, 0); |
| 409 | ans.push_back(std::move(v)); | 412 | ans.push_back(std::move(v)); |
| @@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | @@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | ||
| 37 | 37 | ||
| 38 | std::vector<Ort::Value> GetEncoderInitStates() override; | 38 | std::vector<Ort::Value> GetEncoderInitStates() override; |
| 39 | 39 | ||
| 40 | + void SetFeatureDim(int32_t feature_dim) override { | ||
| 41 | + feature_dim_ = feature_dim; | ||
| 42 | + } | ||
| 43 | + | ||
| 40 | std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( | 44 | std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( |
| 41 | Ort::Value features, std::vector<Ort::Value> states, | 45 | Ort::Value features, std::vector<Ort::Value> states, |
| 42 | Ort::Value processed_frames) override; | 46 | Ort::Value processed_frames) override; |
| @@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | @@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { | ||
| 101 | 105 | ||
| 102 | int32_t context_size_ = 0; | 106 | int32_t context_size_ = 0; |
| 103 | int32_t vocab_size_ = 0; | 107 | int32_t vocab_size_ = 0; |
| 108 | + int32_t feature_dim_ = 0; | ||
| 104 | }; | 109 | }; |
| 105 | 110 | ||
| 106 | } // namespace sherpa_onnx | 111 | } // namespace sherpa_onnx |
| @@ -11,10 +11,17 @@ namespace sherpa_onnx { | @@ -11,10 +11,17 @@ namespace sherpa_onnx { | ||
| 11 | static void PybindFeatureExtractorConfig(py::module *m) { | 11 | static void PybindFeatureExtractorConfig(py::module *m) { |
| 12 | using PyClass = FeatureExtractorConfig; | 12 | using PyClass = FeatureExtractorConfig; |
| 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") | 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") |
| 14 | - .def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000, | ||
| 15 | - py::arg("feature_dim") = 80) | 14 | + .def(py::init<int32_t, int32_t, float, float, float>(), |
| 15 | + py::arg("sampling_rate") = 16000, | ||
| 16 | + py::arg("feature_dim") = 80, | ||
| 17 | + py::arg("low_freq") = 20.0f, | ||
| 18 | + py::arg("high_freq") = -400.0f, | ||
| 19 | + py::arg("dither") = 0.0f) | ||
| 16 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) | 20 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) |
| 17 | .def_readwrite("feature_dim", &PyClass::feature_dim) | 21 | .def_readwrite("feature_dim", &PyClass::feature_dim) |
| 22 | + .def_readwrite("low_freq", &PyClass::low_freq) | ||
| 23 | + .def_readwrite("high_freq", &PyClass::high_freq) | ||
| 24 | + .def_readwrite("dither", &PyClass::dither) | ||
| 18 | .def("__str__", &PyClass::ToString); | 25 | .def("__str__", &PyClass::ToString); |
| 19 | } | 26 | } |
| 20 | 27 |
| @@ -41,6 +41,9 @@ class OnlineRecognizer(object): | @@ -41,6 +41,9 @@ class OnlineRecognizer(object): | ||
| 41 | num_threads: int = 2, | 41 | num_threads: int = 2, |
| 42 | sample_rate: float = 16000, | 42 | sample_rate: float = 16000, |
| 43 | feature_dim: int = 80, | 43 | feature_dim: int = 80, |
| 44 | + low_freq: float = 20.0, | ||
| 45 | + high_freq: float = -400.0, | ||
| 46 | + dither: float = 0.0, | ||
| 44 | enable_endpoint_detection: bool = False, | 47 | enable_endpoint_detection: bool = False, |
| 45 | rule1_min_trailing_silence: float = 2.4, | 48 | rule1_min_trailing_silence: float = 2.4, |
| 46 | rule2_min_trailing_silence: float = 1.2, | 49 | rule2_min_trailing_silence: float = 1.2, |
| @@ -80,6 +83,16 @@ class OnlineRecognizer(object): | @@ -80,6 +83,16 @@ class OnlineRecognizer(object): | ||
| 80 | Sample rate of the training data used to train the model. | 83 | Sample rate of the training data used to train the model. |
| 81 | feature_dim: | 84 | feature_dim: |
| 82 | Dimension of the feature used to train the model. | 85 | Dimension of the feature used to train the model. |
| 86 | + low_freq: | ||
| 87 | + Low cutoff frequency for mel bins in feature extraction. | ||
| 88 | + high_freq: | ||
| 89 | + High cutoff frequency for mel bins in feature extraction | ||
| 90 | + (if <= 0, offset from Nyquist) | ||
| 91 | + dither: | ||
| 92 | + Dithering constant (0.0 means no dither). | ||
| 93 | + By default the audio samples are in range [-1,+1], | ||
| 94 | + so dithering constant 0.00003 is a good value, | ||
| 95 | + equivalent to the default 1.0 from kaldi | ||
| 83 | enable_endpoint_detection: | 96 | enable_endpoint_detection: |
| 84 | True to enable endpoint detection. False to disable endpoint | 97 | True to enable endpoint detection. False to disable endpoint |
| 85 | detection. | 98 | detection. |
| @@ -140,6 +153,9 @@ class OnlineRecognizer(object): | @@ -140,6 +153,9 @@ class OnlineRecognizer(object): | ||
| 140 | feat_config = FeatureExtractorConfig( | 153 | feat_config = FeatureExtractorConfig( |
| 141 | sampling_rate=sample_rate, | 154 | sampling_rate=sample_rate, |
| 142 | feature_dim=feature_dim, | 155 | feature_dim=feature_dim, |
| 156 | + low_freq=low_freq, | ||
| 157 | + high_freq=high_freq, | ||
| 158 | + dither=dither, | ||
| 143 | ) | 159 | ) |
| 144 | 160 | ||
| 145 | endpoint_config = EndpointConfig( | 161 | endpoint_config = EndpointConfig( |
-
请 注册 或 登录 后发表评论