chiiyeh
Committed by GitHub

add blank_penalty for online transducer (#548)

@@ -217,6 +217,18 @@ def get_args(): @@ -217,6 +217,18 @@ def get_args():
217 ) 217 )
218 218
219 parser.add_argument( 219 parser.add_argument(
  220 + "--blank-penalty",
  221 + type=float,
  222 + default=0.0,
  223 + help="""
  224 + The penalty applied on blank symbol during decoding.
  225 + Note: It is a positive value that would be applied to logits like
  226 + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
  227 + [batch_size, vocab] and blank id is 0).
  228 + """,
  229 + )
  230 +
  231 + parser.add_argument(
220 "sound_files", 232 "sound_files",
221 type=str, 233 type=str,
222 nargs="+", 234 nargs="+",
@@ -290,6 +302,7 @@ def main(): @@ -290,6 +302,7 @@ def main():
290 lm_scale=args.lm_scale, 302 lm_scale=args.lm_scale,
291 hotwords_file=args.hotwords_file, 303 hotwords_file=args.hotwords_file,
292 hotwords_score=args.hotwords_score, 304 hotwords_score=args.hotwords_score,
  305 + blank_penalty=args.blank_penalty,
293 ) 306 )
294 elif args.zipformer2_ctc: 307 elif args.zipformer2_ctc:
295 recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc( 308 recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
@@ -102,6 +102,17 @@ def get_args(): @@ -102,6 +102,17 @@ def get_args():
102 """, 102 """,
103 ) 103 )
104 104
  105 + parser.add_argument(
  106 + "--blank-penalty",
  107 + type=float,
  108 + default=0.0,
  109 + help="""
  110 + The penalty applied on blank symbol during decoding.
  111 + Note: It is a positive value that would be applied to logits like
  112 + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
  113 + [batch_size, vocab] and blank id is 0).
  114 + """,
  115 + )
105 116
106 return parser.parse_args() 117 return parser.parse_args()
107 118
@@ -130,6 +141,7 @@ def create_recognizer(args): @@ -130,6 +141,7 @@ def create_recognizer(args):
130 provider=args.provider, 141 provider=args.provider,
131 hotwords_file=args.hotwords_file, 142 hotwords_file=args.hotwords_file,
132 hotwords_score=args.hotwords_score, 143 hotwords_score=args.hotwords_score,
  144 + blank_penalty=args.blank_penalty,
133 ) 145 )
134 return recognizer 146 return recognizer
135 147
@@ -111,6 +111,17 @@ def get_args(): @@ -111,6 +111,17 @@ def get_args():
111 """, 111 """,
112 ) 112 )
113 113
  114 + parser.add_argument(
  115 + "--blank-penalty",
  116 + type=float,
  117 + default=0.0,
  118 + help="""
  119 + The penalty applied on blank symbol during decoding.
  120 + Note: It is a positive value that would be applied to logits like
  121 + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
  122 + [batch_size, vocab] and blank id is 0).
  123 + """,
  124 + )
114 125
115 return parser.parse_args() 126 return parser.parse_args()
116 127
@@ -136,6 +147,7 @@ def create_recognizer(args): @@ -136,6 +147,7 @@ def create_recognizer(args):
136 provider=args.provider, 147 provider=args.provider,
137 hotwords_file=args.hotwords_file, 148 hotwords_file=args.hotwords_file,
138 hotwords_score=args.hotwords_score, 149 hotwords_score=args.hotwords_score,
  150 + blank_penalty=args.blank_penalty,
139 ) 151 )
140 return recognizer 152 return recognizer
141 153
@@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser): @@ -241,6 +241,18 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
241 """, 241 """,
242 ) 242 )
243 243
  244 +def add_blank_penalty_args(parser: argparse.ArgumentParser):
  245 + parser.add_argument(
  246 + "--blank-penalty",
  247 + type=float,
  248 + default=0.0,
  249 + help="""
  250 + The penalty applied on blank symbol during decoding.
  251 + Note: It is a positive value that would be applied to logits like
  252 + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
  253 + [batch_size, vocab] and blank id is 0).
  254 + """,
  255 + )
244 256
245 def add_endpointing_args(parser: argparse.ArgumentParser): 257 def add_endpointing_args(parser: argparse.ArgumentParser):
246 parser.add_argument( 258 parser.add_argument(
@@ -284,6 +296,7 @@ def get_args(): @@ -284,6 +296,7 @@ def get_args():
284 add_decoding_args(parser) 296 add_decoding_args(parser)
285 add_endpointing_args(parser) 297 add_endpointing_args(parser)
286 add_hotwords_args(parser) 298 add_hotwords_args(parser)
  299 + add_blank_penalty_args(parser)
287 300
288 parser.add_argument( 301 parser.add_argument(
289 "--port", 302 "--port",
@@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: @@ -390,6 +403,7 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
390 max_active_paths=args.num_active_paths, 403 max_active_paths=args.num_active_paths,
391 hotwords_score=args.hotwords_score, 404 hotwords_score=args.hotwords_score,
392 hotwords_file=args.hotwords_file, 405 hotwords_file=args.hotwords_file,
  406 + blank_penalty=args.blank_penalty,
393 enable_endpoint_detection=args.use_endpoint != 0, 407 enable_endpoint_detection=args.use_endpoint != 0,
394 rule1_min_trailing_silence=args.rule1_min_trailing_silence, 408 rule1_min_trailing_silence=args.rule1_min_trailing_silence,
395 rule2_min_trailing_silence=args.rule2_min_trailing_silence, 409 rule2_min_trailing_silence=args.rule2_min_trailing_silence,
@@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -95,10 +95,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
95 95
96 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 96 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
97 model_.get(), lm_.get(), config_.max_active_paths, 97 model_.get(), lm_.get(), config_.max_active_paths,
98 - config_.lm_config.scale, unk_id_); 98 + config_.lm_config.scale, unk_id_, config_.blank_penalty);
99 } else if (config.decoding_method == "greedy_search") { 99 } else if (config.decoding_method == "greedy_search") {
100 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 100 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
101 - model_.get(), unk_id_); 101 + model_.get(), unk_id_, config_.blank_penalty);
102 } else { 102 } else {
103 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 103 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
104 config.decoding_method.c_str()); 104 config.decoding_method.c_str());
@@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -131,10 +131,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
131 131
132 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 132 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
133 model_.get(), lm_.get(), config_.max_active_paths, 133 model_.get(), lm_.get(), config_.max_active_paths,
134 - config_.lm_config.scale, unk_id_); 134 + config_.lm_config.scale, unk_id_, config_.blank_penalty);
135 } else if (config.decoding_method == "greedy_search") { 135 } else if (config.decoding_method == "greedy_search") {
136 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 136 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
137 - model_.get(), unk_id_); 137 + model_.get(), unk_id_, config_.blank_penalty);
138 } else { 138 } else {
139 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 139 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
140 config.decoding_method.c_str()); 140 config.decoding_method.c_str());
@@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -81,6 +81,12 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
81 "True to enable endpoint detection. False to disable it."); 81 "True to enable endpoint detection. False to disable it.");
82 po->Register("max-active-paths", &max_active_paths, 82 po->Register("max-active-paths", &max_active_paths,
83 "beam size used in modified beam search."); 83 "beam size used in modified beam search.");
  84 + po->Register("blank-penalty", &blank_penalty,
  85 + "The penalty applied on blank symbol during decoding. "
  86 + "Note: It is a positive value. "
  87 + "Increasing value will lead to lower deletion at the cost "
  88 + "of higher insertions. "
  89 + "Currently only applicable for transducer models.");
84 po->Register("hotwords-score", &hotwords_score, 90 po->Register("hotwords-score", &hotwords_score,
85 "The bonus score for each token in context word/phrase. " 91 "The bonus score for each token in context word/phrase. "
86 "Used only when decoding_method is modified_beam_search"); 92 "Used only when decoding_method is modified_beam_search");
@@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -131,7 +137,8 @@ std::string OnlineRecognizerConfig::ToString() const {
131 os << "max_active_paths=" << max_active_paths << ", "; 137 os << "max_active_paths=" << max_active_paths << ", ";
132 os << "hotwords_score=" << hotwords_score << ", "; 138 os << "hotwords_score=" << hotwords_score << ", ";
133 os << "hotwords_file=\"" << hotwords_file << "\", "; 139 os << "hotwords_file=\"" << hotwords_file << "\", ";
134 - os << "decoding_method=\"" << decoding_method << "\")"; 140 + os << "decoding_method=\"" << decoding_method << "\", ";
  141 + os << "blank_penalty=" << blank_penalty << ")";
135 142
136 return os.str(); 143 return os.str();
137 } 144 }
@@ -83,6 +83,8 @@ struct OnlineRecognizerConfig { @@ -83,6 +83,8 @@ struct OnlineRecognizerConfig {
83 float hotwords_score = 1.5; 83 float hotwords_score = 1.5;
84 std::string hotwords_file; 84 std::string hotwords_file;
85 85
  86 + float blank_penalty = 0.0;
  87 +
86 OnlineRecognizerConfig() = default; 88 OnlineRecognizerConfig() = default;
87 89
88 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, 90 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
@@ -92,7 +94,8 @@ struct OnlineRecognizerConfig { @@ -92,7 +94,8 @@ struct OnlineRecognizerConfig {
92 bool enable_endpoint, 94 bool enable_endpoint,
93 const std::string &decoding_method, 95 const std::string &decoding_method,
94 int32_t max_active_paths, 96 int32_t max_active_paths,
95 - const std::string &hotwords_file, float hotwords_score) 97 + const std::string &hotwords_file, float hotwords_score,
  98 + float blank_penalty)
96 : feat_config(feat_config), 99 : feat_config(feat_config),
97 model_config(model_config), 100 model_config(model_config),
98 lm_config(lm_config), 101 lm_config(lm_config),
@@ -101,7 +104,8 @@ struct OnlineRecognizerConfig { @@ -101,7 +104,8 @@ struct OnlineRecognizerConfig {
101 decoding_method(decoding_method), 104 decoding_method(decoding_method),
102 max_active_paths(max_active_paths), 105 max_active_paths(max_active_paths),
103 hotwords_score(hotwords_score), 106 hotwords_score(hotwords_score),
104 - hotwords_file(hotwords_file) {} 107 + hotwords_file(hotwords_file),
  108 + blank_penalty(blank_penalty) {}
105 109
106 void Register(ParseOptions *po); 110 void Register(ParseOptions *po);
107 bool Validate() const; 111 bool Validate() const;
@@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -116,11 +116,14 @@ void OnlineTransducerGreedySearchDecoder::Decode(
116 Ort::Value logit = model_->RunJoiner( 116 Ort::Value logit = model_->RunJoiner(
117 std::move(cur_encoder_out), View(&decoder_out)); 117 std::move(cur_encoder_out), View(&decoder_out));
118 118
119 - const float *p_logit = logit.GetTensorData<float>(); 119 + float *p_logit = logit.GetTensorMutableData<float>();
120 120
121 bool emitted = false; 121 bool emitted = false;
122 for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { 122 for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
123 auto &r = (*result)[i]; 123 auto &r = (*result)[i];
  124 + if (blank_penalty_ > 0.0) {
  125 + p_logit[0] -= blank_penalty_; // assuming blank id is 0
  126 + }
124 auto y = static_cast<int32_t>(std::distance( 127 auto y = static_cast<int32_t>(std::distance(
125 static_cast<const float *>(p_logit), 128 static_cast<const float *>(p_logit),
126 std::max_element(static_cast<const float *>(p_logit), 129 std::max_element(static_cast<const float *>(p_logit),
@@ -15,8 +15,9 @@ namespace sherpa_onnx { @@ -15,8 +15,9 @@ namespace sherpa_onnx {
15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { 15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
16 public: 16 public:
17 OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, 17 OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
18 - int32_t unk_id)  
19 - : model_(model), unk_id_(unk_id) {} 18 + int32_t unk_id,
  19 + float blank_penalty)
  20 + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
20 21
21 OnlineTransducerDecoderResult GetEmptyResult() const override; 22 OnlineTransducerDecoderResult GetEmptyResult() const override;
22 23
@@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { @@ -28,6 +29,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
28 private: 29 private:
29 OnlineTransducerModel *model_; // Not owned 30 OnlineTransducerModel *model_; // Not owned
30 int32_t unk_id_; 31 int32_t unk_id_;
  32 + float blank_penalty_;
31 }; 33 };
32 34
33 } // namespace sherpa_onnx 35 } // namespace sherpa_onnx
@@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -123,6 +123,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
123 model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); 123 model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
124 124
125 float *p_logit = logit.GetTensorMutableData<float>(); 125 float *p_logit = logit.GetTensorMutableData<float>();
  126 + if (blank_penalty_ > 0.0) {
  127 + // assuming blank id is 0
  128 + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
  129 + }
126 LogSoftmax(p_logit, vocab_size, num_hyps); 130 LogSoftmax(p_logit, vocab_size, num_hyps);
127 131
128 // now p_logit contains log_softmax output, we rename it to p_logprob 132 // now p_logit contains log_softmax output, we rename it to p_logprob
@@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -21,12 +21,14 @@ class OnlineTransducerModifiedBeamSearchDecoder
21 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, 21 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
22 OnlineLM *lm, 22 OnlineLM *lm,
23 int32_t max_active_paths, 23 int32_t max_active_paths,
24 - float lm_scale, int32_t unk_id) 24 + float lm_scale, int32_t unk_id,
  25 + float blank_penalty)
25 : model_(model), 26 : model_(model),
26 lm_(lm), 27 lm_(lm),
27 max_active_paths_(max_active_paths), 28 max_active_paths_(max_active_paths),
28 lm_scale_(lm_scale), 29 lm_scale_(lm_scale),
29 - unk_id_(unk_id) {} 30 + unk_id_(unk_id),
  31 + blank_penalty_(blank_penalty) {}
30 32
31 OnlineTransducerDecoderResult GetEmptyResult() const override; 33 OnlineTransducerDecoderResult GetEmptyResult() const override;
32 34
@@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -47,6 +49,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
47 int32_t max_active_paths_; 49 int32_t max_active_paths_;
48 float lm_scale_; // used only when lm_ is not nullptr 50 float lm_scale_; // used only when lm_ is not nullptr
49 int32_t unk_id_; 51 int32_t unk_id_;
  52 + float blank_penalty_;
50 }; 53 };
51 54
52 } // namespace sherpa_onnx 55 } // namespace sherpa_onnx
@@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -33,12 +33,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
33 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 33 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
34 .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, 34 .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
35 const OnlineLMConfig &, const EndpointConfig &, bool, 35 const OnlineLMConfig &, const EndpointConfig &, bool,
36 - const std::string &, int32_t, const std::string &, float>(), 36 + const std::string &, int32_t, const std::string &, float,
  37 + float>(),
37 py::arg("feat_config"), py::arg("model_config"), 38 py::arg("feat_config"), py::arg("model_config"),
38 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), 39 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
39 py::arg("enable_endpoint"), py::arg("decoding_method"), 40 py::arg("enable_endpoint"), py::arg("decoding_method"),
40 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 41 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
41 - py::arg("hotwords_score") = 0) 42 + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
42 .def_readwrite("feat_config", &PyClass::feat_config) 43 .def_readwrite("feat_config", &PyClass::feat_config)
43 .def_readwrite("model_config", &PyClass::model_config) 44 .def_readwrite("model_config", &PyClass::model_config)
44 .def_readwrite("lm_config", &PyClass::lm_config) 45 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -48,6 +49,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
48 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 49 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
49 .def_readwrite("hotwords_file", &PyClass::hotwords_file) 50 .def_readwrite("hotwords_file", &PyClass::hotwords_file)
50 .def_readwrite("hotwords_score", &PyClass::hotwords_score) 51 .def_readwrite("hotwords_score", &PyClass::hotwords_score)
  52 + .def_readwrite("blank_penalty", &PyClass::blank_penalty)
51 .def("__str__", &PyClass::ToString); 53 .def("__str__", &PyClass::ToString);
52 } 54 }
53 55
@@ -48,6 +48,7 @@ class OnlineRecognizer(object): @@ -48,6 +48,7 @@ class OnlineRecognizer(object):
48 decoding_method: str = "greedy_search", 48 decoding_method: str = "greedy_search",
49 max_active_paths: int = 4, 49 max_active_paths: int = 4,
50 hotwords_score: float = 1.5, 50 hotwords_score: float = 1.5,
  51 + blank_penalty: float = 0.0,
51 hotwords_file: str = "", 52 hotwords_file: str = "",
52 provider: str = "cpu", 53 provider: str = "cpu",
53 model_type: str = "", 54 model_type: str = "",
@@ -100,6 +101,8 @@ class OnlineRecognizer(object): @@ -100,6 +101,8 @@ class OnlineRecognizer(object):
100 max_active_paths: 101 max_active_paths:
101 Use only when decoding_method is modified_beam_search. It specifies 102 Use only when decoding_method is modified_beam_search. It specifies
102 the maximum number of active paths during beam search. 103 the maximum number of active paths during beam search.
  104 + blank_penalty:
  105 + The penalty subtracted from the blank symbol's logit during decoding. Only values greater than 0 have an effect.
103 hotwords_file: 106 hotwords_file:
104 The file containing hotwords, one word/phrase per line, and for each 107 The file containing hotwords, one word/phrase per line, and for each
105 phrase the bpe/cjkchar are separated by a space. 108 phrase the bpe/cjkchar are separated by a space.
@@ -172,6 +175,7 @@ class OnlineRecognizer(object): @@ -172,6 +175,7 @@ class OnlineRecognizer(object):
172 max_active_paths=max_active_paths, 175 max_active_paths=max_active_paths,
173 hotwords_score=hotwords_score, 176 hotwords_score=hotwords_score,
174 hotwords_file=hotwords_file, 177 hotwords_file=hotwords_file,
  178 + blank_penalty=blank_penalty,
175 ) 179 )
176 180
177 self.recognizer = _Recognizer(recognizer_config) 181 self.recognizer = _Recognizer(recognizer_config)