Karel Vesely
Committed by GitHub

online-transducer: reset the encoder together with 2 previous output symbols (non-blank) (#2129)

* online-transducer: reset the encoder together with 2 previous output symbols (non-blank)

- added `reset_encoder` boolean member into the OnlineRecognizerConfig class
- by default the encoder is not reset

* pybind11, adding empty symbols for disabled modules (tts, diarization)

* reset_encoder, add default value (false) [pybind11]
@@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -382,14 +382,13 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
382 } 382 }
383 } 383 }
384 384
385 - // reset encoder states  
386 - // s->SetStates(model_->GetEncoderInitStates());  
387 -  
388 auto r = decoder_->GetEmptyResult(); 385 auto r = decoder_->GetEmptyResult();
389 auto last_result = s->GetResult(); 386 auto last_result = s->GetResult();
390 - // if last result is not empty, then  
391 - // truncate all last hyps and save as the context for next result 387 +
392 if (static_cast<int32_t>(last_result.tokens.size()) > context_size) { 388 if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
  389 + // if last result is not empty, then
  390 + // truncate all last hyps and save as the 'ys' context for next result
  391 + // (the encoder state buffers are kept)
393 for (const auto &it : last_result.hyps) { 392 for (const auto &it : last_result.hyps) {
394 auto h = it.second; 393 auto h = it.second;
395 r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size, 394 r.hyps.Add({std::vector<int64_t>(h.ys.end() - context_size,
@@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -399,6 +398,11 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
399 398
400 r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size, 399 r.tokens = std::vector<int64_t> (last_result.tokens.end() - context_size,
401 last_result.tokens.end()); 400 last_result.tokens.end());
  401 + } else {
  402 + if(config_.reset_encoder) {
  403 + // reset encoder states, use blanks as 'ys' context
  404 + s->SetStates(model_->GetEncoderInitStates());
  405 + }
402 } 406 }
403 407
404 // but reset all contextual biasing graph states to root 408 // but reset all contextual biasing graph states to root
@@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -121,6 +121,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
121 "rule-fars", &rule_fars, 121 "rule-fars", &rule_fars,
122 "If not empty, it specifies fst archives for inverse text normalization. " 122 "If not empty, it specifies fst archives for inverse text normalization. "
123 "If there are multiple archives, they are separated by a comma."); 123 "If there are multiple archives, they are separated by a comma.");
  124 +
  125 + po->Register("reset-encoder", &reset_encoder,
  126 + "True to reset encoder_state on an endpoint after empty segment."
  127 + "Done in `Reset()` method, after an endpoint was detected.");
124 } 128 }
125 129
126 bool OnlineRecognizerConfig::Validate() const { 130 bool OnlineRecognizerConfig::Validate() const {
@@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -198,7 +202,8 @@ std::string OnlineRecognizerConfig::ToString() const {
198 os << "blank_penalty=" << blank_penalty << ", "; 202 os << "blank_penalty=" << blank_penalty << ", ";
199 os << "temperature_scale=" << temperature_scale << ", "; 203 os << "temperature_scale=" << temperature_scale << ", ";
200 os << "rule_fsts=\"" << rule_fsts << "\", "; 204 os << "rule_fsts=\"" << rule_fsts << "\", ";
201 - os << "rule_fars=\"" << rule_fars << "\")"; 205 + os << "rule_fars=\"" << rule_fars << "\", ";
  206 + os << "reset_encoder=\"" << (reset_encoder ? "True" : "False") << "\")";
202 207
203 return os.str(); 208 return os.str();
204 } 209 }
@@ -79,6 +79,7 @@ struct OnlineRecognizerConfig { @@ -79,6 +79,7 @@ struct OnlineRecognizerConfig {
79 OnlineLMConfig lm_config; 79 OnlineLMConfig lm_config;
80 EndpointConfig endpoint_config; 80 EndpointConfig endpoint_config;
81 OnlineCtcFstDecoderConfig ctc_fst_decoder_config; 81 OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
  82 +
82 bool enable_endpoint = true; 83 bool enable_endpoint = true;
83 84
84 std::string decoding_method = "greedy_search"; 85 std::string decoding_method = "greedy_search";
@@ -101,6 +102,11 @@ struct OnlineRecognizerConfig { @@ -101,6 +102,11 @@ struct OnlineRecognizerConfig {
101 // If there are multiple FST archives, they are applied from left to right. 102 // If there are multiple FST archives, they are applied from left to right.
102 std::string rule_fars; 103 std::string rule_fars;
103 104
  105 + // True to reset encoder_state on an endpoint after empty segment.
  106 + // Done in `Reset()` method, after an endpoint was detected,
  107 + // currently only in `OnlineRecognizerTransducerImpl`.
  108 + bool reset_encoder = false;
  109 +
104 /// used only for modified_beam_search, if hotwords_buf is non-empty, 110 /// used only for modified_beam_search, if hotwords_buf is non-empty,
105 /// the hotwords will be loaded from the buffered string instead of from the 111 /// the hotwords will be loaded from the buffered string instead of from the
106 /// "hotwords_file" 112 /// "hotwords_file"
@@ -116,7 +122,8 @@ struct OnlineRecognizerConfig { @@ -116,7 +122,8 @@ struct OnlineRecognizerConfig {
116 bool enable_endpoint, const std::string &decoding_method, 122 bool enable_endpoint, const std::string &decoding_method,
117 int32_t max_active_paths, const std::string &hotwords_file, 123 int32_t max_active_paths, const std::string &hotwords_file,
118 float hotwords_score, float blank_penalty, float temperature_scale, 124 float hotwords_score, float blank_penalty, float temperature_scale,
119 - const std::string &rule_fsts, const std::string &rule_fars) 125 + const std::string &rule_fsts, const std::string &rule_fars,
  126 + bool reset_encoder)
120 : feat_config(feat_config), 127 : feat_config(feat_config),
121 model_config(model_config), 128 model_config(model_config),
122 lm_config(lm_config), 129 lm_config(lm_config),
@@ -130,7 +137,8 @@ struct OnlineRecognizerConfig { @@ -130,7 +137,8 @@ struct OnlineRecognizerConfig {
130 blank_penalty(blank_penalty), 137 blank_penalty(blank_penalty),
131 temperature_scale(temperature_scale), 138 temperature_scale(temperature_scale),
132 rule_fsts(rule_fsts), 139 rule_fsts(rule_fsts),
133 - rule_fars(rule_fars) {} 140 + rule_fars(rule_fars),
  141 + reset_encoder(reset_encoder) {}
134 142
135 void Register(ParseOptions *po); 143 void Register(ParseOptions *po);
136 bool Validate() const; 144 bool Validate() const;
@@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -58,7 +58,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
58 const OnlineLMConfig &, const EndpointConfig &, 58 const OnlineLMConfig &, const EndpointConfig &,
59 const OnlineCtcFstDecoderConfig &, bool, 59 const OnlineCtcFstDecoderConfig &, bool,
60 const std::string &, int32_t, const std::string &, float, 60 const std::string &, int32_t, const std::string &, float,
61 - float, float, const std::string &, const std::string &>(), 61 + float, float, const std::string &, const std::string &, bool>(),
62 py::arg("feat_config"), py::arg("model_config"), 62 py::arg("feat_config"), py::arg("model_config"),
63 py::arg("lm_config") = OnlineLMConfig(), 63 py::arg("lm_config") = OnlineLMConfig(),
64 py::arg("endpoint_config") = EndpointConfig(), 64 py::arg("endpoint_config") = EndpointConfig(),
@@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -67,7 +67,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
67 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 67 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
68 py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, 68 py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
69 py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", 69 py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
70 - py::arg("rule_fars") = "") 70 + py::arg("rule_fars") = "", py::arg("reset_encoder") = false)
71 .def_readwrite("feat_config", &PyClass::feat_config) 71 .def_readwrite("feat_config", &PyClass::feat_config)
72 .def_readwrite("model_config", &PyClass::model_config) 72 .def_readwrite("model_config", &PyClass::model_config)
73 .def_readwrite("lm_config", &PyClass::lm_config) 73 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -82,6 +82,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
82 .def_readwrite("temperature_scale", &PyClass::temperature_scale) 82 .def_readwrite("temperature_scale", &PyClass::temperature_scale)
83 .def_readwrite("rule_fsts", &PyClass::rule_fsts) 83 .def_readwrite("rule_fsts", &PyClass::rule_fsts)
84 .def_readwrite("rule_fars", &PyClass::rule_fars) 84 .def_readwrite("rule_fars", &PyClass::rule_fars)
  85 + .def_readwrite("reset_encoder", &PyClass::reset_encoder)
85 .def("__str__", &PyClass::ToString); 86 .def("__str__", &PyClass::ToString);
86 } 87 }
87 88
@@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -75,6 +75,15 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
75 75
76 #if SHERPA_ONNX_ENABLE_TTS == 1 76 #if SHERPA_ONNX_ENABLE_TTS == 1
77 PybindOfflineTts(&m); 77 PybindOfflineTts(&m);
  78 +#else
  79 + /* Define "empty" TTS symbols */
  80 + m.attr("OfflineTtsKokoroModelConfig") = py::none();
  81 + m.attr("OfflineTtsMatchaModelConfig") = py::none();
  82 + m.attr("OfflineTtsModelConfig") = py::none();
  83 + m.attr("OfflineTtsVitsModelConfig") = py::none();
  84 + m.attr("GeneratedAudio") = py::none();
  85 + m.attr("OfflineTtsConfig") = py::none();
  86 + m.attr("OfflineTts") = py::none();
78 #endif 87 #endif
79 88
80 PybindSpeakerEmbeddingExtractor(&m); 89 PybindSpeakerEmbeddingExtractor(&m);
@@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -85,6 +94,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
85 PybindFastClustering(&m); 94 PybindFastClustering(&m);
86 PybindOfflineSpeakerDiarizationResult(&m); 95 PybindOfflineSpeakerDiarizationResult(&m);
87 PybindOfflineSpeakerDiarization(&m); 96 PybindOfflineSpeakerDiarization(&m);
  97 +#else
  98 + /* Define "empty" diarization symbols */
  99 + m.attr("FastClusteringConfig") = py::none();
  100 + m.attr("FastClustering") = py::none();
  101 + m.attr("OfflineSpeakerDiarizationSegment") = py::none();
  102 + m.attr("OfflineSpeakerDiarizationResult") = py::none();
  103 + m.attr("OfflineSpeakerSegmentationPyannoteModelConfig") = py::none();
  104 + m.attr("OfflineSpeakerSegmentationModelConfig") = py::none();
  105 + m.attr("OfflineSpeakerDiarizationConfig") = py::none();
  106 + m.attr("OfflineSpeakerDiarization") = py::none();
88 #endif 107 #endif
89 108
90 PybindAlsa(&m); 109 PybindAlsa(&m);
@@ -68,6 +68,7 @@ class OnlineRecognizer(object): @@ -68,6 +68,7 @@ class OnlineRecognizer(object):
68 lm_scale: float = 0.1, 68 lm_scale: float = 0.1,
69 lm_shallow_fusion: bool = True, 69 lm_shallow_fusion: bool = True,
70 temperature_scale: float = 2.0, 70 temperature_scale: float = 2.0,
  71 + reset_encoder: bool = False,
71 debug: bool = False, 72 debug: bool = False,
72 rule_fsts: str = "", 73 rule_fsts: str = "",
73 rule_fars: str = "", 74 rule_fars: str = "",
@@ -162,6 +163,10 @@ class OnlineRecognizer(object): @@ -162,6 +163,10 @@ class OnlineRecognizer(object):
162 Temperature scaling for output symbol confidence estimation. 163 Temperature scaling for output symbol confidence estimation.
163 It affects only confidence values, the decoding uses the original 164 It affects only confidence values, the decoding uses the original
164 logits without temperature. 165 logits without temperature.
  166 + reset_encoder:
  167 + True to reset `encoder_state` on an endpoint after empty segment.
  168 + Done in `Reset()` method, after an endpoint was detected,
  169 + currently only in `OnlineRecognizerTransducerImpl`.
165 model_type: 170 model_type:
166 Online transducer model type. Valid values are: conformer, lstm, 171 Online transducer model type. Valid values are: conformer, lstm,
167 zipformer, zipformer2. All other values lead to loading the model twice. 172 zipformer, zipformer2. All other values lead to loading the model twice.
@@ -305,6 +310,7 @@ class OnlineRecognizer(object): @@ -305,6 +310,7 @@ class OnlineRecognizer(object):
305 temperature_scale=temperature_scale, 310 temperature_scale=temperature_scale,
306 rule_fsts=rule_fsts, 311 rule_fsts=rule_fsts,
307 rule_fars=rule_fars, 312 rule_fars=rule_fars,
  313 + reset_encoder=reset_encoder,
308 ) 314 )
309 315
310 self.recognizer = _Recognizer(recognizer_config) 316 self.recognizer = _Recognizer(recognizer_config)