Fangjun Kuang
Committed by GitHub

Give an informative log for whisper on exceptions. (#473)

@@ -181,6 +181,17 @@ def get_args():
181 ) 181 )
182 182
183 parser.add_argument( 183 parser.add_argument(
  184 + "--whisper-tail-paddings",
  185 + default=-1,
  186 + type=int,
  187 + help="""Number of tail padding frames.
  188 + We have removed the 30-second constraint from whisper, so you need to
  189 + choose the amount of tail padding frames by yourself.
  190 + Use -1 to use a default value for tail padding.
  191 + """,
  192 + )
  193 +
  194 + parser.add_argument(
184 "--decoding-method", 195 "--decoding-method",
185 type=str, 196 type=str,
186 default="greedy_search", 197 default="greedy_search",
@@ -294,6 +305,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
294 debug=args.debug, 305 debug=args.debug,
295 language=args.whisper_language, 306 language=args.whisper_language,
296 task=args.whisper_task, 307 task=args.whisper_task,
  308 + tail_paddings=args.whisper_tail_paddings,
297 ) 309 )
298 else: 310 else:
299 raise ValueError("Please specify at least one model") 311 raise ValueError("Please specify at least one model")
@@ -277,6 +277,17 @@ def add_whisper_model_args(parser: argparse.ArgumentParser):
277 """, 277 """,
278 ) 278 )
279 279
  280 + parser.add_argument(
  281 + "--whisper-tail-paddings",
  282 + default=-1,
  283 + type=int,
  284 + help="""Number of tail padding frames.
  285 + We have removed the 30-second constraint from whisper, so you need to
  286 + choose the amount of tail padding frames by yourself.
  287 + Use -1 to use a default value for tail padding.
  288 + """,
  289 + )
  290 +
280 291
281 def add_model_args(parser: argparse.ArgumentParser): 292 def add_model_args(parser: argparse.ArgumentParser):
282 add_transducer_model_args(parser) 293 add_transducer_model_args(parser)
@@ -913,6 +924,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
913 decoding_method=args.decoding_method, 924 decoding_method=args.decoding_method,
914 language=args.whisper_language, 925 language=args.whisper_language,
915 task=args.whisper_task, 926 task=args.whisper_task,
  927 + tail_paddings=args.whisper_tail_paddings,
916 ) 928 )
917 elif args.tdnn_model: 929 elif args.tdnn_model:
918 assert_file_exists(args.tdnn_model) 930 assert_file_exists(args.tdnn_model)
@@ -221,6 +221,17 @@ def get_args():
221 ) 221 )
222 222
223 parser.add_argument( 223 parser.add_argument(
  224 + "--whisper-tail-paddings",
  225 + default=-1,
  226 + type=int,
  227 + help="""Number of tail padding frames.
  228 + We have removed the 30-second constraint from whisper, so you need to
  229 + choose the amount of tail padding frames by yourself.
  230 + Use -1 to use a default value for tail padding.
  231 + """,
  232 + )
  233 +
  234 + parser.add_argument(
224 "--decoding-method", 235 "--decoding-method",
225 type=str, 236 type=str,
226 default="greedy_search", 237 default="greedy_search",
@@ -391,6 +402,7 @@ def main():
391 debug=args.debug, 402 debug=args.debug,
392 language=args.whisper_language, 403 language=args.whisper_language,
393 task=args.whisper_task, 404 task=args.whisper_task,
  405 + tail_paddings=args.whisper_tail_paddings,
394 ) 406 )
395 elif args.tdnn_model: 407 elif args.tdnn_model:
396 assert_file_exists(args.tdnn_model) 408 assert_file_exists(args.tdnn_model)
@@ -195,6 +195,17 @@ def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser):
195 """, 195 """,
196 ) 196 )
197 197
  198 + parser.add_argument(
  199 + "--second-whisper-tail-paddings",
  200 + default=-1,
  201 + type=int,
  202 + help="""Number of tail padding frames.
  203 + We have removed the 30-second constraint from whisper, so you need to
  204 + choose the amount of tail padding frames by yourself.
  205 + Use -1 to use a default value for tail padding.
  206 + """,
  207 + )
  208 +
198 209
199 def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser): 210 def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser):
200 add_second_pass_transducer_model_args(parser) 211 add_second_pass_transducer_model_args(parser)
@@ -314,6 +325,7 @@ def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
314 decoding_method="greedy_search", 325 decoding_method="greedy_search",
315 language=args.second_whisper_language, 326 language=args.second_whisper_language,
316 task=args.second_whisper_task, 327 task=args.second_whisper_task,
  328 + tail_paddings=args.second_whisper_tail_paddings,
317 ) 329 )
318 else: 330 else:
319 raise ValueError("Please specify at least one model for the second pass") 331 raise ValueError("Please specify at least one model for the second pass")
@@ -167,6 +167,17 @@ def get_args():
167 ) 167 )
168 168
169 parser.add_argument( 169 parser.add_argument(
  170 + "--whisper-tail-paddings",
  171 + default=-1,
  172 + type=int,
  173 + help="""Number of tail padding frames.
  174 + We have removed the 30-second constraint from whisper, so you need to
  175 + choose the amount of tail padding frames by yourself.
  176 + Use -1 to use a default value for tail padding.
  177 + """,
  178 + )
  179 +
  180 + parser.add_argument(
170 "--decoding-method", 181 "--decoding-method",
171 type=str, 182 type=str,
172 default="greedy_search", 183 default="greedy_search",
@@ -256,6 +267,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
256 debug=args.debug, 267 debug=args.debug,
257 language=args.whisper_language, 268 language=args.whisper_language,
258 task=args.whisper_task, 269 task=args.whisper_task,
  270 + tail_paddings=args.whisper_tail_paddings,
259 ) 271 )
260 else: 272 else:
261 raise ValueError("Please specify at least one model") 273 raise ValueError("Please specify at least one model")
@@ -116,18 +116,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
116 116
117 NormalizeFeatures(f.data(), num_frames, feat_dim); 117 NormalizeFeatures(f.data(), num_frames, feat_dim);
118 118
119 - // note that 50 is an experience value.  
120 - // see also ../../scripts/whisper/test.py  
121 - //  
122 - // You can replace 50 by other values, say, 100. 119 + // note that 1000 is an experience-value.
  120 + // You can replace 1000 by other values, say, 100.
123 // 121 //
124 // Since we have removed the 30 seconds constraint, we need 122 // Since we have removed the 30 seconds constraint, we need
125 // tail_padding_frames so that whisper is able to detect the eot token. 123 // tail_padding_frames so that whisper is able to detect the eot token.
126 - int32_t tail_padding_frames = 50;  
127 - if (model_->IsMultiLingual()) {  
128 - // 300 is an experience value. If it throws, please use a larger value.  
129 - tail_padding_frames = 300;  
130 - } 124 + int32_t tail_padding_frames = 1000;
131 125
132 if (config_.model_config.whisper.tail_paddings > 0) { 126 if (config_.model_config.whisper.tail_paddings > 0) {
133 tail_padding_frames = config_.model_config.whisper.tail_paddings; 127 tail_padding_frames = config_.model_config.whisper.tail_paddings;
@@ -140,11 +134,13 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
140 134
141 Ort::Value mel = Ort::Value::CreateTensor<float>( 135 Ort::Value mel = Ort::Value::CreateTensor<float>(
142 model_->Allocator(), shape.data(), shape.size()); 136 model_->Allocator(), shape.data(), shape.size());
  137 +
143 float *p_mel = mel.GetTensorMutableData<float>(); 138 float *p_mel = mel.GetTensorMutableData<float>();
144 - std::copy(f.data(), f.data() + actual_frames * feat_dim, p_mel); 139 + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
  140 +
  141 + std::fill_n(p_mel + num_frames * feat_dim,
  142 + (actual_frames - num_frames) * feat_dim, 0);
145 143
146 - memset(p_mel + f.size(), 0,  
147 - (actual_frames - num_frames) * feat_dim * sizeof(float));  
148 mel = Transpose12(model_->Allocator(), &mel); 144 mel = Transpose12(model_->Allocator(), &mel);
149 145
150 try { 146 try {
@@ -156,8 +152,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
156 auto r = Convert(results[0], symbol_table_); 152 auto r = Convert(results[0], symbol_table_);
157 s->SetResult(r); 153 s->SetResult(r);
158 } catch (const Ort::Exception &ex) { 154 } catch (const Ort::Exception &ex) {
159 - SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result",  
160 - ex.what()); 155 + SHERPA_ONNX_LOGE(
  156 + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
  157 + "input frames: %d, Current tail "
  158 + "paddings: %d. If you see a lot of such exceptions, please consider "
  159 + "using a larger --whisper-tail-paddings",
  160 + ex.what(), num_frames, tail_padding_frames);
161 return; 161 return;
162 } 162 }
163 } 163 }
@@ -261,6 +261,7 @@ class OfflineRecognizer(object):
261 decoding_method: str = "greedy_search", 261 decoding_method: str = "greedy_search",
262 debug: bool = False, 262 debug: bool = False,
263 provider: str = "cpu", 263 provider: str = "cpu",
  264 + tail_paddings: int = -1,
264 ): 265 ):
265 """ 266 """
266 Please refer to 267 Please refer to
@@ -305,6 +306,7 @@ class OfflineRecognizer(object):
305 decoder=decoder, 306 decoder=decoder,
306 language=language, 307 language=language,
307 task=task, 308 task=task,
  309 + tail_paddings=tail_paddings,
308 ), 310 ),
309 tokens=tokens, 311 tokens=tokens,
310 num_threads=num_threads, 312 num_threads=num_threads,