Committed by
GitHub
Give an informative log for whisper on exceptions. (#473)
正在显示
7 个修改的文件
包含
76 行增加
和
14 行删除
| @@ -181,6 +181,17 @@ def get_args(): | @@ -181,6 +181,17 @@ def get_args(): | ||
| 181 | ) | 181 | ) |
| 182 | 182 | ||
| 183 | parser.add_argument( | 183 | parser.add_argument( |
| 184 | + "--whisper-tail-paddings", | ||
| 185 | + default=-1, | ||
| 186 | + type=int, | ||
| 187 | + help="""Number of tail padding frames. | ||
| 188 | + We have removed the 30-second constraint from whisper, so you need to | ||
| 189 | + choose the amount of tail padding frames by yourself. | ||
| 190 | + Use -1 to use a default value for tail padding. | ||
| 191 | + """, | ||
| 192 | + ) | ||
| 193 | + | ||
| 194 | + parser.add_argument( | ||
| 184 | "--decoding-method", | 195 | "--decoding-method", |
| 185 | type=str, | 196 | type=str, |
| 186 | default="greedy_search", | 197 | default="greedy_search", |
| @@ -294,6 +305,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -294,6 +305,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 294 | debug=args.debug, | 305 | debug=args.debug, |
| 295 | language=args.whisper_language, | 306 | language=args.whisper_language, |
| 296 | task=args.whisper_task, | 307 | task=args.whisper_task, |
| 308 | + tail_paddings=args.whisper_tail_paddings, | ||
| 297 | ) | 309 | ) |
| 298 | else: | 310 | else: |
| 299 | raise ValueError("Please specify at least one model") | 311 | raise ValueError("Please specify at least one model") |
| @@ -277,6 +277,17 @@ def add_whisper_model_args(parser: argparse.ArgumentParser): | @@ -277,6 +277,17 @@ def add_whisper_model_args(parser: argparse.ArgumentParser): | ||
| 277 | """, | 277 | """, |
| 278 | ) | 278 | ) |
| 279 | 279 | ||
| 280 | + parser.add_argument( | ||
| 281 | + "--whisper-tail-paddings", | ||
| 282 | + default=-1, | ||
| 283 | + type=int, | ||
| 284 | + help="""Number of tail padding frames. | ||
| 285 | + We have removed the 30-second constraint from whisper, so you need to | ||
| 286 | + choose the amount of tail padding frames by yourself. | ||
| 287 | + Use -1 to use a default value for tail padding. | ||
| 288 | + """, | ||
| 289 | + ) | ||
| 290 | + | ||
| 280 | 291 | ||
| 281 | def add_model_args(parser: argparse.ArgumentParser): | 292 | def add_model_args(parser: argparse.ArgumentParser): |
| 282 | add_transducer_model_args(parser) | 293 | add_transducer_model_args(parser) |
| @@ -913,6 +924,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -913,6 +924,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 913 | decoding_method=args.decoding_method, | 924 | decoding_method=args.decoding_method, |
| 914 | language=args.whisper_language, | 925 | language=args.whisper_language, |
| 915 | task=args.whisper_task, | 926 | task=args.whisper_task, |
| 927 | + tail_paddings=args.whisper_tail_paddings, | ||
| 916 | ) | 928 | ) |
| 917 | elif args.tdnn_model: | 929 | elif args.tdnn_model: |
| 918 | assert_file_exists(args.tdnn_model) | 930 | assert_file_exists(args.tdnn_model) |
| @@ -221,6 +221,17 @@ def get_args(): | @@ -221,6 +221,17 @@ def get_args(): | ||
| 221 | ) | 221 | ) |
| 222 | 222 | ||
| 223 | parser.add_argument( | 223 | parser.add_argument( |
| 224 | + "--whisper-tail-paddings", | ||
| 225 | + default=-1, | ||
| 226 | + type=int, | ||
| 227 | + help="""Number of tail padding frames. | ||
| 228 | + We have removed the 30-second constraint from whisper, so you need to | ||
| 229 | + choose the amount of tail padding frames by yourself. | ||
| 230 | + Use -1 to use a default value for tail padding. | ||
| 231 | + """, | ||
| 232 | + ) | ||
| 233 | + | ||
| 234 | + parser.add_argument( | ||
| 224 | "--decoding-method", | 235 | "--decoding-method", |
| 225 | type=str, | 236 | type=str, |
| 226 | default="greedy_search", | 237 | default="greedy_search", |
| @@ -391,6 +402,7 @@ def main(): | @@ -391,6 +402,7 @@ def main(): | ||
| 391 | debug=args.debug, | 402 | debug=args.debug, |
| 392 | language=args.whisper_language, | 403 | language=args.whisper_language, |
| 393 | task=args.whisper_task, | 404 | task=args.whisper_task, |
| 405 | + tail_paddings=args.whisper_tail_paddings, | ||
| 394 | ) | 406 | ) |
| 395 | elif args.tdnn_model: | 407 | elif args.tdnn_model: |
| 396 | assert_file_exists(args.tdnn_model) | 408 | assert_file_exists(args.tdnn_model) |
| @@ -195,6 +195,17 @@ def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser): | @@ -195,6 +195,17 @@ def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser): | ||
| 195 | """, | 195 | """, |
| 196 | ) | 196 | ) |
| 197 | 197 | ||
| 198 | + parser.add_argument( | ||
| 199 | + "--second-whisper-tail-paddings", | ||
| 200 | + default=-1, | ||
| 201 | + type=int, | ||
| 202 | + help="""Number of tail padding frames. | ||
| 203 | + We have removed the 30-second constraint from whisper, so you need to | ||
| 204 | + choose the amount of tail padding frames by yourself. | ||
| 205 | + Use -1 to use a default value for tail padding. | ||
| 206 | + """, | ||
| 207 | + ) | ||
| 208 | + | ||
| 198 | 209 | ||
| 199 | def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser): | 210 | def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser): |
| 200 | add_second_pass_transducer_model_args(parser) | 211 | add_second_pass_transducer_model_args(parser) |
| @@ -314,6 +325,7 @@ def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -314,6 +325,7 @@ def create_second_pass_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 314 | decoding_method="greedy_search", | 325 | decoding_method="greedy_search", |
| 315 | language=args.second_whisper_language, | 326 | language=args.second_whisper_language, |
| 316 | task=args.second_whisper_task, | 327 | task=args.second_whisper_task, |
| 328 | + tail_paddings=args.second_whisper_tail_paddings, | ||
| 317 | ) | 329 | ) |
| 318 | else: | 330 | else: |
| 319 | raise ValueError("Please specify at least one model for the second pass") | 331 | raise ValueError("Please specify at least one model for the second pass") |
| @@ -167,6 +167,17 @@ def get_args(): | @@ -167,6 +167,17 @@ def get_args(): | ||
| 167 | ) | 167 | ) |
| 168 | 168 | ||
| 169 | parser.add_argument( | 169 | parser.add_argument( |
| 170 | + "--whisper-tail-paddings", | ||
| 171 | + default=-1, | ||
| 172 | + type=int, | ||
| 173 | + help="""Number of tail padding frames. | ||
| 174 | + We have removed the 30-second constraint from whisper, so you need to | ||
| 175 | + choose the amount of tail padding frames by yourself. | ||
| 176 | + Use -1 to use a default value for tail padding. | ||
| 177 | + """, | ||
| 178 | + ) | ||
| 179 | + | ||
| 180 | + parser.add_argument( | ||
| 170 | "--decoding-method", | 181 | "--decoding-method", |
| 171 | type=str, | 182 | type=str, |
| 172 | default="greedy_search", | 183 | default="greedy_search", |
| @@ -256,6 +267,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | @@ -256,6 +267,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: | ||
| 256 | debug=args.debug, | 267 | debug=args.debug, |
| 257 | language=args.whisper_language, | 268 | language=args.whisper_language, |
| 258 | task=args.whisper_task, | 269 | task=args.whisper_task, |
| 270 | + tail_paddings=args.whisper_tail_paddings, | ||
| 259 | ) | 271 | ) |
| 260 | else: | 272 | else: |
| 261 | raise ValueError("Please specify at least one model") | 273 | raise ValueError("Please specify at least one model") |
| @@ -116,18 +116,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -116,18 +116,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 116 | 116 | ||
| 117 | NormalizeFeatures(f.data(), num_frames, feat_dim); | 117 | NormalizeFeatures(f.data(), num_frames, feat_dim); |
| 118 | 118 | ||
| 119 | - // note that 50 is an experience value. | ||
| 120 | - // see also ../../scripts/whisper/test.py | ||
| 121 | - // | ||
| 122 | - // You can replace 50 by other values, say, 100. | 119 | + // note that 1000 is an experience value. |
| 120 | + // You can replace 1000 by other values, say, 100. | ||
| 123 | // | 121 | // |
| 124 | // Since we have removed the 30 seconds constraint, we need | 122 | // Since we have removed the 30 seconds constraint, we need |
| 125 | // tail_padding_frames so that whisper is able to detect the eot token. | 123 | // tail_padding_frames so that whisper is able to detect the eot token. |
| 126 | - int32_t tail_padding_frames = 50; | ||
| 127 | - if (model_->IsMultiLingual()) { | ||
| 128 | - // 300 is an experience value. If it throws, please use a larger value. | ||
| 129 | - tail_padding_frames = 300; | ||
| 130 | - } | 124 | + int32_t tail_padding_frames = 1000; |
| 131 | 125 | ||
| 132 | if (config_.model_config.whisper.tail_paddings > 0) { | 126 | if (config_.model_config.whisper.tail_paddings > 0) { |
| 133 | tail_padding_frames = config_.model_config.whisper.tail_paddings; | 127 | tail_padding_frames = config_.model_config.whisper.tail_paddings; |
| @@ -140,11 +134,13 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -140,11 +134,13 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 140 | 134 | ||
| 141 | Ort::Value mel = Ort::Value::CreateTensor<float>( | 135 | Ort::Value mel = Ort::Value::CreateTensor<float>( |
| 142 | model_->Allocator(), shape.data(), shape.size()); | 136 | model_->Allocator(), shape.data(), shape.size()); |
| 137 | + | ||
| 143 | float *p_mel = mel.GetTensorMutableData<float>(); | 138 | float *p_mel = mel.GetTensorMutableData<float>(); |
| 144 | - std::copy(f.data(), f.data() + actual_frames * feat_dim, p_mel); | 139 | + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel); |
| 140 | + | ||
| 141 | + std::fill_n(p_mel + num_frames * feat_dim, | ||
| 142 | + (actual_frames - num_frames) * feat_dim, 0); | ||
| 145 | 143 | ||
| 146 | - memset(p_mel + f.size(), 0, | ||
| 147 | - (actual_frames - num_frames) * feat_dim * sizeof(float)); | ||
| 148 | mel = Transpose12(model_->Allocator(), &mel); | 144 | mel = Transpose12(model_->Allocator(), &mel); |
| 149 | 145 | ||
| 150 | try { | 146 | try { |
| @@ -156,8 +152,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -156,8 +152,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 156 | auto r = Convert(results[0], symbol_table_); | 152 | auto r = Convert(results[0], symbol_table_); |
| 157 | s->SetResult(r); | 153 | s->SetResult(r); |
| 158 | } catch (const Ort::Exception &ex) { | 154 | } catch (const Ort::Exception &ex) { |
| 159 | - SHERPA_ONNX_LOGE("\n\nCaught exception:\n\n%s\n\nReturn an empty result", | ||
| 160 | - ex.what()); | 155 | + SHERPA_ONNX_LOGE( |
| 156 | + "\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of " | ||
| 157 | + "input frames: %d, Current tail " | ||
| 158 | + "paddings: %d. If you see a lot of such exceptions, please consider " | ||
| 159 | + "using a larger --whisper-tail-paddings", | ||
| 160 | + ex.what(), num_frames, tail_padding_frames); | ||
| 161 | return; | 161 | return; |
| 162 | } | 162 | } |
| 163 | } | 163 | } |
| @@ -261,6 +261,7 @@ class OfflineRecognizer(object): | @@ -261,6 +261,7 @@ class OfflineRecognizer(object): | ||
| 261 | decoding_method: str = "greedy_search", | 261 | decoding_method: str = "greedy_search", |
| 262 | debug: bool = False, | 262 | debug: bool = False, |
| 263 | provider: str = "cpu", | 263 | provider: str = "cpu", |
| 264 | + tail_paddings: int = -1, | ||
| 264 | ): | 265 | ): |
| 265 | """ | 266 | """ |
| 266 | Please refer to | 267 | Please refer to |
| @@ -305,6 +306,7 @@ class OfflineRecognizer(object): | @@ -305,6 +306,7 @@ class OfflineRecognizer(object): | ||
| 305 | decoder=decoder, | 306 | decoder=decoder, |
| 306 | language=language, | 307 | language=language, |
| 307 | task=task, | 308 | task=task, |
| 309 | + tail_paddings=tail_paddings, | ||
| 308 | ), | 310 | ), |
| 309 | tokens=tokens, | 311 | tokens=tokens, |
| 310 | num_threads=num_threads, | 312 | num_threads=num_threads, |
-
请 注册 或 登录 后发表评论