Zhong-Yi Li
Committed by GitHub

offline transducer: treat unk as blank (#1005)

Co-authored-by: chungyi.li <chungyi.li@ailabs.tw>
@@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
78 config_(config), 78 config_(config),
79 symbol_table_(config_.model_config.tokens), 79 symbol_table_(config_.model_config.tokens),
80 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) { 80 model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
  81 + if (symbol_table_.Contains("<unk>")) {
  82 + unk_id_ = symbol_table_["<unk>"];
  83 + }
  84 +
81 if (config_.decoding_method == "greedy_search") { 85 if (config_.decoding_method == "greedy_search") {
82 decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( 86 decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
83 - model_.get(), config_.blank_penalty); 87 + model_.get(), unk_id_, config_.blank_penalty);
84 } else if (config_.decoding_method == "modified_beam_search") { 88 } else if (config_.decoding_method == "modified_beam_search") {
85 if (!config_.lm_config.model.empty()) { 89 if (!config_.lm_config.model.empty()) {
86 lm_ = OfflineLM::Create(config.lm_config); 90 lm_ = OfflineLM::Create(config.lm_config);
@@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
97 101
98 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( 102 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
99 model_.get(), lm_.get(), config_.max_active_paths, 103 model_.get(), lm_.get(), config_.max_active_paths,
100 - config_.lm_config.scale, config_.blank_penalty); 104 + config_.lm_config.scale, unk_id_, config_.blank_penalty);
101 } else { 105 } else {
102 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 106 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
103 config_.decoding_method.c_str()); 107 config_.decoding_method.c_str());
@@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
113 symbol_table_(mgr, config_.model_config.tokens), 117 symbol_table_(mgr, config_.model_config.tokens),
114 model_(std::make_unique<OfflineTransducerModel>(mgr, 118 model_(std::make_unique<OfflineTransducerModel>(mgr,
115 config_.model_config)) { 119 config_.model_config)) {
  120 + if (symbol_table_.Contains("<unk>")) {
  121 + unk_id_ = symbol_table_["<unk>"];
  122 + }
  123 +
116 if (config_.decoding_method == "greedy_search") { 124 if (config_.decoding_method == "greedy_search") {
117 decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( 125 decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
118 - model_.get(), config_.blank_penalty); 126 + model_.get(), unk_id_, config_.blank_penalty);
119 } else if (config_.decoding_method == "modified_beam_search") { 127 } else if (config_.decoding_method == "modified_beam_search") {
120 if (!config_.lm_config.model.empty()) { 128 if (!config_.lm_config.model.empty()) {
121 lm_ = OfflineLM::Create(mgr, config.lm_config); 129 lm_ = OfflineLM::Create(mgr, config.lm_config);
@@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
133 141
134 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( 142 decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
135 model_.get(), lm_.get(), config_.max_active_paths, 143 model_.get(), lm_.get(), config_.max_active_paths,
136 - config_.lm_config.scale, config_.blank_penalty); 144 + config_.lm_config.scale, unk_id_, config_.blank_penalty);
137 } else { 145 } else {
138 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 146 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
139 config_.decoding_method.c_str()); 147 config_.decoding_method.c_str());
@@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
293 std::unique_ptr<OfflineTransducerModel> model_; 301 std::unique_ptr<OfflineTransducerModel> model_;
294 std::unique_ptr<OfflineTransducerDecoder> decoder_; 302 std::unique_ptr<OfflineTransducerDecoder> decoder_;
295 std::unique_ptr<OfflineLM> lm_; 303 std::unique_ptr<OfflineLM> lm_;
  304 + int32_t unk_id_ = -1;
296 }; 305 };
297 306
298 } // namespace sherpa_onnx 307 } // namespace sherpa_onnx
@@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, @@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
57 std::max_element(static_cast<const float *>(p_logit), 57 std::max_element(static_cast<const float *>(p_logit),
58 static_cast<const float *>(p_logit) + vocab_size))); 58 static_cast<const float *>(p_logit) + vocab_size)));
59 p_logit += vocab_size; 59 p_logit += vocab_size;
60 - if (y != 0) { 60 + // blank id is hardcoded to 0
  61 + // also, it treats unk as blank
  62 + if (y != 0 && y != unk_id_) {
61 ans[i].tokens.push_back(y); 63 ans[i].tokens.push_back(y);
62 ans[i].timestamps.push_back(t); 64 ans[i].timestamps.push_back(t);
63 emitted = true; 65 emitted = true;
@@ -15,8 +15,9 @@ namespace sherpa_onnx { @@ -15,8 +15,9 @@ namespace sherpa_onnx {
15 class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { 15 class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
16 public: 16 public:
17 OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, 17 OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
  18 + int32_t unk_id,
18 float blank_penalty) 19 float blank_penalty)
19 - : model_(model), blank_penalty_(blank_penalty) {} 20 + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
20 21
21 std::vector<OfflineTransducerDecoderResult> Decode( 22 std::vector<OfflineTransducerDecoderResult> Decode(
22 Ort::Value encoder_out, Ort::Value encoder_out_length, 23 Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { @@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
24 25
25 private: 26 private:
26 OfflineTransducerModel *model_; // Not owned 27 OfflineTransducerModel *model_; // Not owned
  28 + int32_t unk_id_;
27 float blank_penalty_; 29 float blank_penalty_;
28 }; 30 };
29 31
@@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
131 131
132 float context_score = 0; 132 float context_score = 0;
133 auto context_state = new_hyp.context_state; 133 auto context_state = new_hyp.context_state;
134 - if (new_token != 0) {  
135 - // blank id is fixed to 0 134 + // blank is hardcoded to 0
  135 + // also, it treats unk as blank
  136 + if (new_token != 0 && new_token != unk_id_) {
136 new_hyp.ys.push_back(new_token); 137 new_hyp.ys.push_back(new_token);
137 new_hyp.timestamps.push_back(t); 138 new_hyp.timestamps.push_back(t);
138 if (context_graphs[i] != nullptr) { 139 if (context_graphs[i] != nullptr) {
@@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder @@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
19 OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, 19 OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
20 OfflineLM *lm, 20 OfflineLM *lm,
21 int32_t max_active_paths, 21 int32_t max_active_paths,
22 - float lm_scale, 22 + float lm_scale, int32_t unk_id,
23 float blank_penalty) 23 float blank_penalty)
24 : model_(model), 24 : model_(model),
25 lm_(lm), 25 lm_(lm),
26 max_active_paths_(max_active_paths), 26 max_active_paths_(max_active_paths),
27 lm_scale_(lm_scale), 27 lm_scale_(lm_scale),
  28 + unk_id_(unk_id),
28 blank_penalty_(blank_penalty) {} 29 blank_penalty_(blank_penalty) {}
29 30
30 std::vector<OfflineTransducerDecoderResult> Decode( 31 std::vector<OfflineTransducerDecoderResult> Decode(
@@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder @@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
37 38
38 int32_t max_active_paths_; 39 int32_t max_active_paths_;
39 float lm_scale_; // used only when lm_ is not nullptr 40 float lm_scale_; // used only when lm_ is not nullptr
  41 + int32_t unk_id_;
40 float blank_penalty_; 42 float blank_penalty_;
41 }; 43 };
42 44