chiiyeh
Committed by GitHub

add blank_penalty for offline transducer (#542)

@@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser): @@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
383 """, 383 """,
384 ) 384 )
385 385
def add_blank_penalty_args(parser: argparse.ArgumentParser):
    """Register the ``--blank-penalty`` option on ``parser``.

    The option is parsed as a float and defaults to 0.0.

    Args:
      parser:
        The argument parser that receives the new option.
    """
    blank_penalty_help = """
    The penalty applied on blank symbol during decoding.
    Note: It is a positive value that would be applied to logits like
    this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
    [batch_size, vocab] and blank id is 0).
    """

    parser.add_argument(
        "--blank-penalty",
        default=0.0,
        type=float,
        help=blank_penalty_help,
    )
  398 +
386 399
387 def check_args(args): 400 def check_args(args):
388 if not Path(args.tokens).is_file(): 401 if not Path(args.tokens).is_file():
@@ -414,6 +427,7 @@ def get_args(): @@ -414,6 +427,7 @@ def get_args():
414 add_feature_config_args(parser) 427 add_feature_config_args(parser)
415 add_decoding_args(parser) 428 add_decoding_args(parser)
416 add_hotwords_args(parser) 429 add_hotwords_args(parser)
  430 + add_blank_penalty_args(parser)
417 431
418 parser.add_argument( 432 parser.add_argument(
419 "--port", 433 "--port",
@@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
862 max_active_paths=args.max_active_paths, 876 max_active_paths=args.max_active_paths,
863 hotwords_file=args.hotwords_file, 877 hotwords_file=args.hotwords_file,
864 hotwords_score=args.hotwords_score, 878 hotwords_score=args.hotwords_score,
  879 + blank_penalty=args.blank_penalty,
865 provider=args.provider, 880 provider=args.provider,
866 ) 881 )
867 elif args.paraformer: 882 elif args.paraformer:
@@ -232,6 +232,18 @@ def get_args(): @@ -232,6 +232,18 @@ def get_args():
232 ) 232 )
233 233
234 parser.add_argument( 234 parser.add_argument(
  235 + "--blank-penalty",
  236 + type=float,
  237 + default=0.0,
  238 + help="""
  239 + The penalty applied on blank symbol during decoding.
  240 + Note: It is a positive value that would be applied to logits like
  241 + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
  242 + [batch_size, vocab] and blank id is 0).
  243 + """,
  244 + )
  245 +
  246 + parser.add_argument(
235 "--decoding-method", 247 "--decoding-method",
236 type=str, 248 type=str,
237 default="greedy_search", 249 default="greedy_search",
@@ -335,6 +347,7 @@ def main(): @@ -335,6 +347,7 @@ def main():
335 decoding_method=args.decoding_method, 347 decoding_method=args.decoding_method,
336 hotwords_file=args.hotwords_file, 348 hotwords_file=args.hotwords_file,
337 hotwords_score=args.hotwords_score, 349 hotwords_score=args.hotwords_score,
  350 + blank_penalty=args.blank_penalty,
338 debug=args.debug, 351 debug=args.debug,
339 ) 352 )
340 elif args.paraformer: 353 elif args.paraformer:
@@ -178,6 +178,18 @@ def get_args(): @@ -178,6 +178,18 @@ def get_args():
178 ) 178 )
179 179
180 parser.add_argument( 180 parser.add_argument(
  181 + "--blank-penalty",
  182 + type=float,
  183 + default=0.0,
  184 + help="""
  185 + The penalty applied on blank symbol during decoding.
  186 + Note: It is a positive value that would be applied to logits like
  187 + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
  188 + [batch_size, vocab] and blank id is 0).
  189 + """,
  190 + )
  191 +
  192 + parser.add_argument(
181 "--decoding-method", 193 "--decoding-method",
182 type=str, 194 type=str,
183 default="greedy_search", 195 default="greedy_search",
@@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
237 sample_rate=args.sample_rate, 249 sample_rate=args.sample_rate,
238 feature_dim=args.feature_dim, 250 feature_dim=args.feature_dim,
239 decoding_method=args.decoding_method, 251 decoding_method=args.decoding_method,
  252 + blank_penalty=args.blank_penalty,
240 debug=args.debug, 253 debug=args.debug,
241 ) 254 )
242 elif args.paraformer: 255 elif args.paraformer:
@@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) { @@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
96 } 96 }
97 } 97 }
98 98
// Subtract `blank_penalty` from one entry in each row of a buffer, in place.
//
// The buffer is walked as `h` consecutive rows that are `w` elements apart;
// in every row the element at offset `blank_idx` is decreased by
// `blank_penalty`. No bounds checking is performed.
//
// @param in            Pointer to the first element of the first row.
// @param w             Distance, in elements, between the starts of two rows.
// @param h             Number of rows to process.
// @param blank_idx     Offset within a row of the element to penalize.
// @param blank_penalty Amount subtracted from the selected element.
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h,
                   int32_t blank_idx, float blank_penalty) {
  T *row = in;
  for (int32_t r = 0; r != h; ++r, row += w) {
    row[blank_idx] -= blank_penalty;
  }
}
  107 +
99 template <class T> 108 template <class T>
100 std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { 109 std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
101 std::vector<int32_t> vec_index(size); 110 std::vector<int32_t> vec_index(size);
@@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
79 } 79 }
80 if (config_.decoding_method == "greedy_search") { 80 if (config_.decoding_method == "greedy_search") {
81 decoder_ = 81 decoder_ =
82 - std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); 82 + std::make_unique<OfflineTransducerGreedySearchDecoder>(
  83 + model_.get(), config_.blank_penalty);
83 } else if (config_.decoding_method == "modified_beam_search") { 84 } else if (config_.decoding_method == "modified_beam_search") {
84 if (!config_.lm_config.model.empty()) { 85 if (!config_.lm_config.model.empty()) {
85 lm_ = OfflineLM::Create(config.lm_config); 86 lm_ = OfflineLM::Create(config.lm_config);
@@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
87 88
88 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( 89 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
89 model_.get(), lm_.get(), config_.max_active_paths, 90 model_.get(), lm_.get(), config_.max_active_paths,
90 - config_.lm_config.scale); 91 + config_.lm_config.scale, config_.blank_penalty);
91 } else { 92 } else {
92 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 93 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
93 config_.decoding_method.c_str()); 94 config_.decoding_method.c_str());
@@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
104 config_.model_config)) { 105 config_.model_config)) {
105 if (config_.decoding_method == "greedy_search") { 106 if (config_.decoding_method == "greedy_search") {
106 decoder_ = 107 decoder_ =
107 - std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); 108 + std::make_unique<OfflineTransducerGreedySearchDecoder>(
  109 + model_.get(), config_.blank_penalty);
108 } else if (config_.decoding_method == "modified_beam_search") { 110 } else if (config_.decoding_method == "modified_beam_search") {
109 if (!config_.lm_config.model.empty()) { 111 if (!config_.lm_config.model.empty()) {
110 lm_ = OfflineLM::Create(mgr, config.lm_config); 112 lm_ = OfflineLM::Create(mgr, config.lm_config);
@@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
112 114
113 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( 115 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
114 model_.get(), lm_.get(), config_.max_active_paths, 116 model_.get(), lm_.get(), config_.max_active_paths,
115 - config_.lm_config.scale); 117 + config_.lm_config.scale, config_.blank_penalty);
116 } else { 118 } else {
117 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 119 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
118 config_.decoding_method.c_str()); 120 config_.decoding_method.c_str());
@@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { @@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
28 po->Register("max-active-paths", &max_active_paths, 28 po->Register("max-active-paths", &max_active_paths,
29 "Used only when decoding_method is modified_beam_search"); 29 "Used only when decoding_method is modified_beam_search");
30 30
  31 + po->Register("blank-penalty", &blank_penalty,
  32 + "The penalty applied on blank symbol during decoding. "
  33 + "Note: It is a positive value. "
  34 + "Increasing value will lead to lower deletion at the cost "
  35 + "of higher insertions. "
  36 + "Currently only applicable for transducer models.");
  37 +
31 po->Register( 38 po->Register(
32 "hotwords-file", &hotwords_file, 39 "hotwords-file", &hotwords_file,
33 "The file containing hotwords, one words/phrases per line, and for each" 40 "The file containing hotwords, one words/phrases per line, and for each"
@@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const { @@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const {
74 os << "decoding_method=\"" << decoding_method << "\", "; 81 os << "decoding_method=\"" << decoding_method << "\", ";
75 os << "max_active_paths=" << max_active_paths << ", "; 82 os << "max_active_paths=" << max_active_paths << ", ";
76 os << "hotwords_file=\"" << hotwords_file << "\", "; 83 os << "hotwords_file=\"" << hotwords_file << "\", ";
77 - os << "hotwords_score=" << hotwords_score << ")"; 84 + os << "hotwords_score=" << hotwords_score << ", ";
  85 + os << "blank_penalty=" << blank_penalty << ")";
78 86
79 return os.str(); 87 return os.str();
80 } 88 }
@@ -37,6 +37,8 @@ struct OfflineRecognizerConfig { @@ -37,6 +37,8 @@ struct OfflineRecognizerConfig {
37 std::string hotwords_file; 37 std::string hotwords_file;
38 float hotwords_score = 1.5; 38 float hotwords_score = 1.5;
39 39
  40 + float blank_penalty = 0.0;
  41 +
40 // only greedy_search is implemented 42 // only greedy_search is implemented
41 // TODO(fangjun): Implement modified_beam_search 43 // TODO(fangjun): Implement modified_beam_search
42 44
@@ -46,7 +48,8 @@ struct OfflineRecognizerConfig { @@ -46,7 +48,8 @@ struct OfflineRecognizerConfig {
46 const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, 48 const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
47 const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, 49 const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
48 const std::string &decoding_method, int32_t max_active_paths, 50 const std::string &decoding_method, int32_t max_active_paths,
49 - const std::string &hotwords_file, float hotwords_score) 51 + const std::string &hotwords_file, float hotwords_score,
  52 + float blank_penalty)
50 : feat_config(feat_config), 53 : feat_config(feat_config),
51 model_config(model_config), 54 model_config(model_config),
52 lm_config(lm_config), 55 lm_config(lm_config),
@@ -54,7 +57,8 @@ struct OfflineRecognizerConfig { @@ -54,7 +57,8 @@ struct OfflineRecognizerConfig {
54 decoding_method(decoding_method), 57 decoding_method(decoding_method),
55 max_active_paths(max_active_paths), 58 max_active_paths(max_active_paths),
56 hotwords_file(hotwords_file), 59 hotwords_file(hotwords_file),
57 - hotwords_score(hotwords_score) {} 60 + hotwords_score(hotwords_score),
  61 + blank_penalty(blank_penalty) {}
58 62
59 void Register(ParseOptions *po); 63 void Register(ParseOptions *po);
60 bool Validate() const; 64 bool Validate() const;
@@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, @@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
46 start += n; 46 start += n;
47 Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out), 47 Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
48 std::move(cur_decoder_out)); 48 std::move(cur_decoder_out));
49 - const float *p_logit = logit.GetTensorData<float>(); 49 + float *p_logit = logit.GetTensorMutableData<float>();
50 bool emitted = false; 50 bool emitted = false;
51 for (int32_t i = 0; i != n; ++i) { 51 for (int32_t i = 0; i != n; ++i) {
  52 + if (blank_penalty_ > 0.0) {
  53 + p_logit[0] -= blank_penalty_; // assuming blank id is 0
  54 + }
52 auto y = static_cast<int32_t>(std::distance( 55 auto y = static_cast<int32_t>(std::distance(
53 static_cast<const float *>(p_logit), 56 static_cast<const float *>(p_logit),
54 std::max_element(static_cast<const float *>(p_logit), 57 std::max_element(static_cast<const float *>(p_logit),
@@ -14,8 +14,10 @@ namespace sherpa_onnx { @@ -14,8 +14,10 @@ namespace sherpa_onnx {
14 14
15 class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { 15 class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
16 public: 16 public:
17 - explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)  
18 - : model_(model) {} 17 + explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
  18 + float blank_penalty)
  19 + : model_(model),
  20 + blank_penalty_(blank_penalty) {}
19 21
20 std::vector<OfflineTransducerDecoderResult> Decode( 22 std::vector<OfflineTransducerDecoderResult> Decode(
21 Ort::Value encoder_out, Ort::Value encoder_out_length, 23 Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { @@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
23 25
24 private: 26 private:
25 OfflineTransducerModel *model_; // Not owned 27 OfflineTransducerModel *model_; // Not owned
  28 + float blank_penalty_;
26 }; 29 };
27 30
28 } // namespace sherpa_onnx 31 } // namespace sherpa_onnx
@@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
97 model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); 97 model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
98 98
99 float *p_logit = logit.GetTensorMutableData<float>(); 99 float *p_logit = logit.GetTensorMutableData<float>();
  100 + if (blank_penalty_ > 0.0) {
  101 + // assuming blank id is 0
  102 + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
  103 + }
100 LogSoftmax(p_logit, vocab_size, num_hyps); 104 LogSoftmax(p_logit, vocab_size, num_hyps);
101 105
102 // now p_logit contains log_softmax output, we rename it to p_logprob 106 // now p_logit contains log_softmax output, we rename it to p_logprob
@@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder @@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
19 OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, 19 OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
20 OfflineLM *lm, 20 OfflineLM *lm,
21 int32_t max_active_paths, 21 int32_t max_active_paths,
22 - float lm_scale) 22 + float lm_scale,
  23 + float blank_penalty)
23 : model_(model), 24 : model_(model),
24 lm_(lm), 25 lm_(lm),
25 max_active_paths_(max_active_paths), 26 max_active_paths_(max_active_paths),
26 - lm_scale_(lm_scale) {} 27 + lm_scale_(lm_scale),
  28 + blank_penalty_(blank_penalty) {}
27 29
28 std::vector<OfflineTransducerDecoderResult> Decode( 30 std::vector<OfflineTransducerDecoderResult> Decode(
29 Ort::Value encoder_out, Ort::Value encoder_out_length, 31 Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder @@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
35 37
36 int32_t max_active_paths_; 38 int32_t max_active_paths_;
37 float lm_scale_; // used only when lm_ is not nullptr 39 float lm_scale_; // used only when lm_ is not nullptr
  40 + float blank_penalty_;
38 }; 41 };
39 42
40 } // namespace sherpa_onnx 43 } // namespace sherpa_onnx
@@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
17 .def(py::init<const OfflineFeatureExtractorConfig &, 17 .def(py::init<const OfflineFeatureExtractorConfig &,
18 const OfflineModelConfig &, const OfflineLMConfig &, 18 const OfflineModelConfig &, const OfflineLMConfig &,
19 const OfflineCtcFstDecoderConfig &, const std::string &, 19 const OfflineCtcFstDecoderConfig &, const std::string &,
20 - int32_t, const std::string &, float>(), 20 + int32_t, const std::string &, float, float>(),
21 py::arg("feat_config"), py::arg("model_config"), 21 py::arg("feat_config"), py::arg("model_config"),
22 py::arg("lm_config") = OfflineLMConfig(), 22 py::arg("lm_config") = OfflineLMConfig(),
23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), 23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
24 py::arg("decoding_method") = "greedy_search", 24 py::arg("decoding_method") = "greedy_search",
25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
26 - py::arg("hotwords_score") = 1.5) 26 + py::arg("hotwords_score") = 1.5,
  27 + py::arg("blank_penalty") = 0.0)
27 .def_readwrite("feat_config", &PyClass::feat_config) 28 .def_readwrite("feat_config", &PyClass::feat_config)
28 .def_readwrite("model_config", &PyClass::model_config) 29 .def_readwrite("model_config", &PyClass::model_config)
29 .def_readwrite("lm_config", &PyClass::lm_config) 30 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
32 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 33 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
33 .def_readwrite("hotwords_file", &PyClass::hotwords_file) 34 .def_readwrite("hotwords_file", &PyClass::hotwords_file)
34 .def_readwrite("hotwords_score", &PyClass::hotwords_score) 35 .def_readwrite("hotwords_score", &PyClass::hotwords_score)
  36 + .def_readwrite("blank_penalty", &PyClass::blank_penalty)
35 .def("__str__", &PyClass::ToString); 37 .def("__str__", &PyClass::ToString);
36 } 38 }
37 39
@@ -48,6 +48,7 @@ class OfflineRecognizer(object): @@ -48,6 +48,7 @@ class OfflineRecognizer(object):
48 max_active_paths: int = 4, 48 max_active_paths: int = 4,
49 hotwords_file: str = "", 49 hotwords_file: str = "",
50 hotwords_score: float = 1.5, 50 hotwords_score: float = 1.5,
  51 + blank_penalty: float = 0.0,
51 debug: bool = False, 52 debug: bool = False,
52 provider: str = "cpu", 53 provider: str = "cpu",
53 ): 54 ):
@@ -81,6 +82,8 @@ class OfflineRecognizer(object): @@ -81,6 +82,8 @@ class OfflineRecognizer(object):
81 max_active_paths: 82 max_active_paths:
82 Maximum number of active paths to keep. Used only when 83 Maximum number of active paths to keep. Used only when
83 decoding_method is modified_beam_search. 84 decoding_method is modified_beam_search.
  85 + blank_penalty:
  86 + The penalty applied on blank symbol during decoding.
84 debug: 87 debug:
85 True to show debug messages. 88 True to show debug messages.
86 provider: 89 provider:
@@ -117,6 +120,7 @@ class OfflineRecognizer(object): @@ -117,6 +120,7 @@ class OfflineRecognizer(object):
117 decoding_method=decoding_method, 120 decoding_method=decoding_method,
118 hotwords_file=hotwords_file, 121 hotwords_file=hotwords_file,
119 hotwords_score=hotwords_score, 122 hotwords_score=hotwords_score,
  123 + blank_penalty=blank_penalty,
120 ) 124 )
121 self.recognizer = _Recognizer(recognizer_config) 125 self.recognizer = _Recognizer(recognizer_config)
122 self.config = recognizer_config 126 self.config = recognizer_config