Fangjun Kuang
Committed by GitHub

treat unk as blank (#299)

@@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -57,6 +57,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
57 model_(OnlineTransducerModel::Create(config.model_config)), 57 model_(OnlineTransducerModel::Create(config.model_config)),
58 sym_(config.model_config.tokens), 58 sym_(config.model_config.tokens),
59 endpoint_(config_.endpoint_config) { 59 endpoint_(config_.endpoint_config) {
  60 + if (sym_.contains("<unk>")) {
  61 + unk_id_ = sym_["<unk>"];
  62 + }
  63 +
60 if (config.decoding_method == "modified_beam_search") { 64 if (config.decoding_method == "modified_beam_search") {
61 if (!config_.lm_config.model.empty()) { 65 if (!config_.lm_config.model.empty()) {
62 lm_ = OnlineLM::Create(config.lm_config); 66 lm_ = OnlineLM::Create(config.lm_config);
@@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -64,10 +68,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
64 68
65 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 69 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
66 model_.get(), lm_.get(), config_.max_active_paths, 70 model_.get(), lm_.get(), config_.max_active_paths,
67 - config_.lm_config.scale); 71 + config_.lm_config.scale, unk_id_);
68 } else if (config.decoding_method == "greedy_search") { 72 } else if (config.decoding_method == "greedy_search") {
69 - decoder_ =  
70 - std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); 73 + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
  74 + model_.get(), unk_id_);
71 } else { 75 } else {
72 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 76 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
73 config.decoding_method.c_str()); 77 config.decoding_method.c_str());
@@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -82,13 +86,17 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
82 model_(OnlineTransducerModel::Create(mgr, config.model_config)), 86 model_(OnlineTransducerModel::Create(mgr, config.model_config)),
83 sym_(mgr, config.model_config.tokens), 87 sym_(mgr, config.model_config.tokens),
84 endpoint_(config_.endpoint_config) { 88 endpoint_(config_.endpoint_config) {
  89 + if (sym_.contains("<unk>")) {
  90 + unk_id_ = sym_["<unk>"];
  91 + }
  92 +
85 if (config.decoding_method == "modified_beam_search") { 93 if (config.decoding_method == "modified_beam_search") {
86 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 94 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
87 model_.get(), lm_.get(), config_.max_active_paths, 95 model_.get(), lm_.get(), config_.max_active_paths,
88 - config_.lm_config.scale); 96 + config_.lm_config.scale, unk_id_);
89 } else if (config.decoding_method == "greedy_search") { 97 } else if (config.decoding_method == "greedy_search") {
90 - decoder_ =  
91 - std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); 98 + decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
  99 + model_.get(), unk_id_);
92 } else { 100 } else {
93 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 101 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
94 config.decoding_method.c_str()); 102 config.decoding_method.c_str());
@@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -268,6 +276,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
268 std::unique_ptr<OnlineTransducerDecoder> decoder_; 276 std::unique_ptr<OnlineTransducerDecoder> decoder_;
269 SymbolTable sym_; 277 SymbolTable sym_;
270 Endpoint endpoint_; 278 Endpoint endpoint_;
  279 + int32_t unk_id_ = -1;
271 }; 280 };
272 281
273 } // namespace sherpa_onnx 282 } // namespace sherpa_onnx
@@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -108,7 +108,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
108 static_cast<const float *>(p_logit), 108 static_cast<const float *>(p_logit),
109 std::max_element(static_cast<const float *>(p_logit), 109 std::max_element(static_cast<const float *>(p_logit),
110 static_cast<const float *>(p_logit) + vocab_size))); 110 static_cast<const float *>(p_logit) + vocab_size)));
111 - if (y != 0) { 111 + // blank id is hardcoded to 0
  112 + // also, it treats unk as blank
  113 + if (y != 0 && y != unk_id_) {
112 emitted = true; 114 emitted = true;
113 r.tokens.push_back(y); 115 r.tokens.push_back(y);
114 r.timestamps.push_back(t + r.frame_offset); 116 r.timestamps.push_back(t + r.frame_offset);
@@ -14,8 +14,9 @@ namespace sherpa_onnx { @@ -14,8 +14,9 @@ namespace sherpa_onnx {
14 14
15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { 15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
16 public: 16 public:
17 - explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)  
18 - : model_(model) {} 17 + OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
  18 + int32_t unk_id)
  19 + : model_(model), unk_id_(unk_id) {}
19 20
20 OnlineTransducerDecoderResult GetEmptyResult() const override; 21 OnlineTransducerDecoderResult GetEmptyResult() const override;
21 22
@@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { @@ -26,6 +27,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
26 27
27 private: 28 private:
28 OnlineTransducerModel *model_; // Not owned 29 OnlineTransducerModel *model_; // Not owned
  30 + int32_t unk_id_;
29 }; 31 };
30 32
31 } // namespace sherpa_onnx 33 } // namespace sherpa_onnx
@@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -155,7 +155,9 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
155 float context_score = 0; 155 float context_score = 0;
156 auto context_state = new_hyp.context_state; 156 auto context_state = new_hyp.context_state;
157 157
158 - if (new_token != 0) { 158 + // blank is hardcoded to 0
  159 + // also, it treats unk as blank
  160 + if (new_token != 0 && new_token != unk_id_) {
159 new_hyp.ys.push_back(new_token); 161 new_hyp.ys.push_back(new_token);
160 new_hyp.timestamps.push_back(t + frame_offset); 162 new_hyp.timestamps.push_back(t + frame_offset);
161 new_hyp.num_trailing_blanks = 0; 163 new_hyp.num_trailing_blanks = 0;
@@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -21,11 +21,12 @@ class OnlineTransducerModifiedBeamSearchDecoder
21 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, 21 OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
22 OnlineLM *lm, 22 OnlineLM *lm,
23 int32_t max_active_paths, 23 int32_t max_active_paths,
24 - float lm_scale) 24 + float lm_scale, int32_t unk_id)
25 : model_(model), 25 : model_(model),
26 lm_(lm), 26 lm_(lm),
27 max_active_paths_(max_active_paths), 27 max_active_paths_(max_active_paths),
28 - lm_scale_(lm_scale) {} 28 + lm_scale_(lm_scale),
  29 + unk_id_(unk_id) {}
29 30
30 OnlineTransducerDecoderResult GetEmptyResult() const override; 31 OnlineTransducerDecoderResult GetEmptyResult() const override;
31 32
@@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -45,6 +46,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
45 46
46 int32_t max_active_paths_; 47 int32_t max_active_paths_;
47 float lm_scale_; // used only when lm_ is not nullptr 48 float lm_scale_; // used only when lm_ is not nullptr
  49 + int32_t unk_id_;
48 }; 50 };
49 51
50 } // namespace sherpa_onnx 52 } // namespace sherpa_onnx