Committed by
GitHub
online-transducer: reset the encoder together with 2 previous output symbols (non-blank) (#2129)
* online-transducer: reset the encoder together with 2 previous output symbols (non-blank) - added `reset_encoder` boolean member into the OnlineRecognizerConfig class - by default the encoder is not reset * pybind11, adding empty symbols for disabled modules (tts, diarization) * reset_encoder, add default value (false) [pybind11]
正在显示
6 个修改的文件
包含
53 行增加
和
10 行删除
| @@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 382 | } | 382 | } |
| 383 | } | 383 | } |
| 384 | 384 | ||
| 385 | - // reset encoder states | ||
| 386 | - // s->SetStates(model_->GetEncoderInitStates()); | ||
| 387 | - | ||
| 388 | auto r = decoder_->GetEmptyResult(); | 385 | auto r = decoder_->GetEmptyResult(); |
| 389 | auto last_result = s->GetResult(); | 386 | auto last_result = s->GetResult(); |
| 390 | - // if last result is not empty, then | ||
| 391 | - // truncate all last hyps and save as the context for next result | 387 | + |
| 392 | if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { | 388 | if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { |
| 389 | + // if last result is not empty, then | ||
| 390 | + // truncate all last hyps and save as the 'ys' context for next result | ||
| 391 | + // (the encoder state buffers are kept) | ||
| 393 | for (const auto &it : last_result.hyps) { | 392 | for (const auto &it : last_result.hyps) { |
| 394 | auto h = it.second; | 393 | auto h = it.second; |
| 395 | r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, | 394 | r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, |
| @@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 399 | 398 | ||
| 400 | r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size, | 399 | r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size, |
| 401 | last_result.tokens.end()); | 400 | last_result.tokens.end()); |
| 401 | + } else { | ||
| 402 | + if(config_.reset_encoder) { | ||
| 403 | + // reset encoder states, use blanks as 'ys' context | ||
| 404 | + s->SetStates(model_->GetEncoderInitStates()); | ||
| 405 | + } | ||
| 402 | } | 406 | } |
| 403 | 407 | ||
| 404 | // but reset all contextual biasing graph states to root | 408 | // but reset all contextual biasing graph states to root |
| @@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 121 | "rule-fars", &rule_fars, | 121 | "rule-fars", &rule_fars, |
| 122 | "If not empty, it specifies fst archives for inverse text normalization. " | 122 | "If not empty, it specifies fst archives for inverse text normalization. " |
| 123 | "If there are multiple archives, they are separated by a comma."); | 123 | "If there are multiple archives, they are separated by a comma."); |
| 124 | + | ||
| 125 | + po->Register("reset-encoder", &reset_encoder, | ||
| 126 | + "True to reset encoder_state on an endpoint after empty segment." | ||
| 127 | + "Done in `Reset()` method, after an endpoint was detected."); | ||
| 124 | } | 128 | } |
| 125 | 129 | ||
| 126 | bool OnlineRecognizerConfig::Validate() const { | 130 | bool OnlineRecognizerConfig::Validate() const { |
| @@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 198 | os << "blank_penalty=" << blank_penalty << ", "; | 202 | os << "blank_penalty=" << blank_penalty << ", "; |
| 199 | os << "temperature_scale=" << temperature_scale << ", "; | 203 | os << "temperature_scale=" << temperature_scale << ", "; |
| 200 | os << "rule_fsts=\"" << rule_fsts << "\", "; | 204 | os << "rule_fsts=\"" << rule_fsts << "\", "; |
| 201 | - os << "rule_fars=\"" << rule_fars << "\")"; | 205 | + os << "rule_fars=\"" << rule_fars << "\", "; |
| 206 | + os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")"; | ||
| 202 | 207 | ||
| 203 | return os.str(); | 208 | return os.str(); |
| 204 | } | 209 | } |
| @@ -79,6 +79,7 @@ struct OnlineRecognizerConfig { | @@ -79,6 +79,7 @@ struct OnlineRecognizerConfig { | ||
| 79 | OnlineLMConfig lm_config; | 79 | OnlineLMConfig lm_config; |
| 80 | EndpointConfig endpoint_config; | 80 | EndpointConfig endpoint_config; |
| 81 | OnlineCtcFstDecoderConfig ctc_fst_decoder_config; | 81 | OnlineCtcFstDecoderConfig ctc_fst_decoder_config; |
| 82 | + | ||
| 82 | bool enable_endpoint = true; | 83 | bool enable_endpoint = true; |
| 83 | 84 | ||
| 84 | std::string decoding_method = "greedy_search"; | 85 | std::string decoding_method = "greedy_search"; |
| @@ -101,6 +102,11 @@ struct OnlineRecognizerConfig { | @@ -101,6 +102,11 @@ struct OnlineRecognizerConfig { | ||
| 101 | // If there are multiple FST archives, they are applied from left to right. | 102 | // If there are multiple FST archives, they are applied from left to right. |
| 102 | std::string rule_fars; | 103 | std::string rule_fars; |
| 103 | 104 | ||
| 105 | + // True to reset encoder_state on an endpoint after empty segment. | ||
| 106 | + // Done in `Reset()` method, after an endpoint was detected, | ||
| 107 | + // currently only in `OnlineRecognizerTransducerImpl`. | ||
| 108 | + bool reset_encoder = false; | ||
| 109 | + | ||
| 104 | /// used only for modified_beam_search, if hotwords_buf is non-empty, | 110 | /// used only for modified_beam_search, if hotwords_buf is non-empty, |
| 105 | /// the hotwords will be loaded from the buffered string instead of from the | 111 | /// the hotwords will be loaded from the buffered string instead of from the |
| 106 | /// "hotwords_file" | 112 | /// "hotwords_file" |
| @@ -116,7 +122,8 @@ struct OnlineRecognizerConfig { | @@ -116,7 +122,8 @@ struct OnlineRecognizerConfig { | ||
| 116 | bool enable_endpoint, const std::string &decoding_method, | 122 | bool enable_endpoint, const std::string &decoding_method, |
| 117 | int32_t max_active_paths, const std::string &hotwords_file, | 123 | int32_t max_active_paths, const std::string &hotwords_file, |
| 118 | float hotwords_score, float blank_penalty, float temperature_scale, | 124 | float hotwords_score, float blank_penalty, float temperature_scale, |
| 119 | - const std::string &rule_fsts, const std::string &rule_fars) | 125 | + const std::string &rule_fsts, const std::string &rule_fars, |
| 126 | + bool reset_encoder) | ||
| 120 | : feat_config(feat_config), | 127 | : feat_config(feat_config), |
| 121 | model_config(model_config), | 128 | model_config(model_config), |
| 122 | lm_config(lm_config), | 129 | lm_config(lm_config), |
| @@ -130,7 +137,8 @@ struct OnlineRecognizerConfig { | @@ -130,7 +137,8 @@ struct OnlineRecognizerConfig { | ||
| 130 | blank_penalty(blank_penalty), | 137 | blank_penalty(blank_penalty), |
| 131 | temperature_scale(temperature_scale), | 138 | temperature_scale(temperature_scale), |
| 132 | rule_fsts(rule_fsts), | 139 | rule_fsts(rule_fsts), |
| 133 | - rule_fars(rule_fars) {} | 140 | + rule_fars(rule_fars), |
| 141 | + reset_encoder(reset_encoder) {} | ||
| 134 | 142 | ||
| 135 | void Register(ParseOptions *po); | 143 | void Register(ParseOptions *po); |
| 136 | bool Validate() const; | 144 | bool Validate() const; |
| @@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 58 | const OnlineLMConfig &, const EndpointConfig &, | 58 | const OnlineLMConfig &, const EndpointConfig &, |
| 59 | const OnlineCtcFstDecoderConfig &, bool, | 59 | const OnlineCtcFstDecoderConfig &, bool, |
| 60 | const std::string &, int32_t, const std::string &, float, | 60 | const std::string &, int32_t, const std::string &, float, |
| 61 | - float, float, const std::string &, const std::string &>(), | 61 | + float, float, const std::string &, const std::string &, bool>(), |
| 62 | py::arg("feat_config"), py::arg("model_config"), | 62 | py::arg("feat_config"), py::arg("model_config"), |
| 63 | py::arg("lm_config") = OnlineLMConfig(), | 63 | py::arg("lm_config") = OnlineLMConfig(), |
| 64 | py::arg("endpoint_config") = EndpointConfig(), | 64 | py::arg("endpoint_config") = EndpointConfig(), |
| @@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 67 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 67 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 68 | py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, | 68 | py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, |
| 69 | py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", | 69 | py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", |
| 70 | - py::arg("rule_fars") = "") | 70 | + py::arg("rule_fars") = "", py::arg("reset_encoder") = false) |
| 71 | .def_readwrite("feat_config", &PyClass::feat_config) | 71 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 72 | .def_readwrite("model_config", &PyClass::model_config) | 72 | .def_readwrite("model_config", &PyClass::model_config) |
| 73 | .def_readwrite("lm_config", &PyClass::lm_config) | 73 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | @@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { | ||
| 82 | .def_readwrite("temperature_scale", &PyClass::temperature_scale) | 82 | .def_readwrite("temperature_scale", &PyClass::temperature_scale) |
| 83 | .def_readwrite("rule_fsts", &PyClass::rule_fsts) | 83 | .def_readwrite("rule_fsts", &PyClass::rule_fsts) |
| 84 | .def_readwrite("rule_fars", &PyClass::rule_fars) | 84 | .def_readwrite("rule_fars", &PyClass::rule_fars) |
| 85 | + .def_readwrite("reset_encoder", &PyClass::reset_encoder) | ||
| 85 | .def("__str__", &PyClass::ToString); | 86 | .def("__str__", &PyClass::ToString); |
| 86 | } | 87 | } |
| 87 | 88 |
| @@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 75 | 75 | ||
| 76 | #if SHERPA_ONNX_ENABLE_TTS == 1 | 76 | #if SHERPA_ONNX_ENABLE_TTS == 1 |
| 77 | PybindOfflineTts(&m); | 77 | PybindOfflineTts(&m); |
| 78 | +#else | ||
| 79 | + /* Define "empty" TTS symbols */ | ||
| 80 | + m.attr("OfflineTtsKokoroModelConfig") = py::none(); | ||
| 81 | + m.attr("OfflineTtsMatchaModelConfig") = py::none(); | ||
| 82 | + m.attr("OfflineTtsModelConfig") = py::none(); | ||
| 83 | + m.attr("OfflineTtsVitsModelConfig") = py::none(); | ||
| 84 | + m.attr("GeneratedAudio") = py::none(); | ||
| 85 | + m.attr("OfflineTtsConfig") = py::none(); | ||
| 86 | + m.attr("OfflineTts") = py::none(); | ||
| 78 | #endif | 87 | #endif |
| 79 | 88 | ||
| 80 | PybindSpeakerEmbeddingExtractor(&m); | 89 | PybindSpeakerEmbeddingExtractor(&m); |
| @@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 85 | PybindFastClustering(&m); | 94 | PybindFastClustering(&m); |
| 86 | PybindOfflineSpeakerDiarizationResult(&m); | 95 | PybindOfflineSpeakerDiarizationResult(&m); |
| 87 | PybindOfflineSpeakerDiarization(&m); | 96 | PybindOfflineSpeakerDiarization(&m); |
| 97 | +#else | ||
| 98 | + /* Define "empty" diarization symbols */ | ||
| 99 | + m.attr("FastClusteringConfig") = py::none(); | ||
| 100 | + m.attr("FastClustering") = py::none(); | ||
| 101 | + m.attr("OfflineSpeakerDiarizationSegment") = py::none(); | ||
| 102 | + m.attr("OfflineSpeakerDiarizationResult") = py::none(); | ||
| 103 | + m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none(); | ||
| 104 | + m.attr("OfflineSpeakerSegmentationModelConfig") = py::none(); | ||
| 105 | + m.attr("OfflineSpeakerDiarizationConfig") = py::none(); | ||
| 106 | + m.attr("OfflineSpeakerDiarization") = py::none(); | ||
| 88 | #endif | 107 | #endif |
| 89 | 108 | ||
| 90 | PybindAlsa(&m); | 109 | PybindAlsa(&m); |
| @@ -68,6 +68,7 @@ class OnlineRecognizer(object): | @@ -68,6 +68,7 @@ class OnlineRecognizer(object): | ||
| 68 | lm_scale: float = 0.1, | 68 | lm_scale: float = 0.1, |
| 69 | lm_shallow_fusion: bool = True, | 69 | lm_shallow_fusion: bool = True, |
| 70 | temperature_scale: float = 2.0, | 70 | temperature_scale: float = 2.0, |
| 71 | + reset_encoder: bool = False, | ||
| 71 | debug: bool = False, | 72 | debug: bool = False, |
| 72 | rule_fsts: str = "", | 73 | rule_fsts: str = "", |
| 73 | rule_fars: str = "", | 74 | rule_fars: str = "", |
| @@ -162,6 +163,10 @@ class OnlineRecognizer(object): | @@ -162,6 +163,10 @@ class OnlineRecognizer(object): | ||
| 162 | Temperature scaling for output symbol confidence estimation. | 163 | Temperature scaling for output symbol confidence estimation. |
| 163 | It affects only confidence values, the decoding uses the original | 164 | It affects only confidence values, the decoding uses the original |
| 164 | logits without temperature. | 165 | logits without temperature. |
| 166 | + reset_encoder: | ||
| 167 | + True to reset `encoder_state` on an endpoint after empty segment. | ||
| 168 | + Done in `Reset()` method, after an endpoint was detected, | ||
| 169 | + currently only in `OnlineRecognizerTransducerImpl`. | ||
| 165 | model_type: | 170 | model_type: |
| 166 | Online transducer model type. Valid values are: conformer, lstm, | 171 | Online transducer model type. Valid values are: conformer, lstm, |
| 167 | zipformer, zipformer2. All other values lead to loading the model twice. | 172 | zipformer, zipformer2. All other values lead to loading the model twice. |
| @@ -305,6 +310,7 @@ class OnlineRecognizer(object): | @@ -305,6 +310,7 @@ class OnlineRecognizer(object): | ||
| 305 | temperature_scale=temperature_scale, | 310 | temperature_scale=temperature_scale, |
| 306 | rule_fsts=rule_fsts, | 311 | rule_fsts=rule_fsts, |
| 307 | rule_fars=rule_fars, | 312 | rule_fars=rule_fars, |
| 313 | + reset_encoder=reset_encoder, | ||
| 308 | ) | 314 | ) |
| 309 | 315 | ||
| 310 | self.recognizer = _Recognizer(recognizer_config) | 316 | self.recognizer = _Recognizer(recognizer_config) |
-
请 注册 或 登录 后发表评论