Fangjun Kuang
Committed by GitHub

Print a more user-friendly error message when using --hotwords-file. (#344)

@@ -46,7 +46,17 @@ bool OfflineRecognizerConfig::Validate() const { @@ -46,7 +46,17 @@ bool OfflineRecognizerConfig::Validate() const {
46 max_active_paths); 46 max_active_paths);
47 return false; 47 return false;
48 } 48 }
49 - if (!lm_config.Validate()) return false; 49 + if (!lm_config.Validate()) {
  50 + return false;
  51 + }
  52 + }
  53 +
  54 + if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
  55 + SHERPA_ONNX_LOGE(
  56 + "Please use --decoding-method=modified_beam_search if you"
  57 + " provide --hotwords-file. Given --decoding-method=%s",
  58 + decoding_method.c_str());
  59 + return false;
50 } 60 }
51 61
52 return model_config.Validate(); 62 return model_config.Validate();
@@ -156,8 +156,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -156,8 +156,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
156 bool has_context_graph = false; 156 bool has_context_graph = false;
157 157
158 for (int32_t i = 0; i != n; ++i) { 158 for (int32_t i = 0; i != n; ++i) {
159 - if (!has_context_graph && ss[i]->GetContextGraph()) 159 + if (!has_context_graph && ss[i]->GetContextGraph()) {
160 has_context_graph = true; 160 has_context_graph = true;
  161 + }
161 162
162 const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); 163 const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
163 std::vector<float> features = 164 std::vector<float> features =
@@ -102,8 +102,20 @@ bool OnlineRecognizerConfig::Validate() const { @@ -102,8 +102,20 @@ bool OnlineRecognizerConfig::Validate() const {
102 max_active_paths); 102 max_active_paths);
103 return false; 103 return false;
104 } 104 }
105 - if (!lm_config.Validate()) return false; 105 +
  106 + if (!lm_config.Validate()) {
  107 + return false;
106 } 108 }
  109 + }
  110 +
  111 + if (!hotwords_file.empty() && decoding_method != "modified_beam_search") {
  112 + SHERPA_ONNX_LOGE(
  113 + "Please use --decoding-method=modified_beam_search if you"
  114 + " provide --hotwords-file. Given --decoding-method=%s",
  115 + decoding_method.c_str());
  116 + return false;
  117 + }
  118 +
107 return model_config.Validate(); 119 return model_config.Validate();
108 } 120 }
109 121
1 -from typing import Dict, List, Optional  
2 -  
3 from _sherpa_onnx import ( 1 from _sherpa_onnx import (
4 CircularBuffer, 2 CircularBuffer,
5 Display, 3 Display,
@@ -102,6 +102,12 @@ class OfflineRecognizer(object): @@ -102,6 +102,12 @@ class OfflineRecognizer(object):
102 feature_dim=feature_dim, 102 feature_dim=feature_dim,
103 ) 103 )
104 104
  105 + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search":
  106 + raise ValueError(
  107 + "Please use --decoding-method=modified_beam_search when using "
  108 + f"--hotwords-file. Currently given: {decoding_method}"
  109 + )
  110 +
105 recognizer_config = OfflineRecognizerConfig( 111 recognizer_config = OfflineRecognizerConfig(
106 feat_config=feat_config, 112 feat_config=feat_config,
107 model_config=model_config, 113 model_config=model_config,
@@ -132,6 +132,12 @@ class OnlineRecognizer(object): @@ -132,6 +132,12 @@ class OnlineRecognizer(object):
132 rule3_min_utterance_length=rule3_min_utterance_length, 132 rule3_min_utterance_length=rule3_min_utterance_length,
133 ) 133 )
134 134
  135 + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search":
  136 + raise ValueError(
  137 + "Please use --decoding-method=modified_beam_search when using "
  138 + f"--hotwords-file. Currently given: {decoding_method}"
  139 + )
  140 +
135 recognizer_config = OnlineRecognizerConfig( 141 recognizer_config = OnlineRecognizerConfig(
136 feat_config=feat_config, 142 feat_config=feat_config,
137 model_config=model_config, 143 model_config=model_config,