Fangjun Kuang
Committed by GitHub

Generate subtitles with FireRedAsr models (#2112)

@@ -79,8 +79,17 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v @@ -79,8 +79,17 @@ wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_v
79 --num-threads=2 \ 79 --num-threads=2 \
80 /path/to/test.mp4 80 /path/to/test.mp4
81 81
  82 +(6) For FireRedAsr models
82 83
83 -(6) For WeNet CTC models 84 +./python-api-examples/generate-subtitles.py \
  85 + --silero-vad-model=/path/to/silero_vad.onnx \
  86 + --tokens=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt \
  87 + --fire-red-asr-encoder=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx \
  88 + --fire-red-asr-decoder=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx \
  89 + --num-threads=2 \
  90 + /path/to/test.mp4
  91 +
  92 +(7) For WeNet CTC models
84 93
85 ./python-api-examples/generate-subtitles.py \ 94 ./python-api-examples/generate-subtitles.py \
86 --silero-vad-model=/path/to/silero_vad.onnx \ 95 --silero-vad-model=/path/to/silero_vad.onnx \
@@ -175,6 +184,20 @@ def get_args(): @@ -175,6 +184,20 @@ def get_args():
175 ) 184 )
176 185
177 parser.add_argument( 186 parser.add_argument(
  187 + "--fire-red-asr-encoder",
  188 + default="",
  189 + type=str,
  190 + help="Path to FireRedAsr encoder model",
  191 + )
  192 +
  193 + parser.add_argument(
  194 + "--fire-red-asr-decoder",
  195 + default="",
  196 + type=str,
  197 + help="Path to FireRedAsr decoder model",
  198 + )
  199 +
  200 + parser.add_argument(
178 "--whisper-encoder", 201 "--whisper-encoder",
179 default="", 202 default="",
180 type=str, 203 type=str,
@@ -304,6 +327,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -304,6 +327,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
304 assert len(args.wenet_ctc) == 0, args.wenet_ctc 327 assert len(args.wenet_ctc) == 0, args.wenet_ctc
305 assert len(args.whisper_encoder) == 0, args.whisper_encoder 328 assert len(args.whisper_encoder) == 0, args.whisper_encoder
306 assert len(args.whisper_decoder) == 0, args.whisper_decoder 329 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  330 + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder
  331 + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder
307 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor 332 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
308 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder 333 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
309 assert ( 334 assert (
@@ -331,6 +356,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -331,6 +356,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
331 assert len(args.wenet_ctc) == 0, args.wenet_ctc 356 assert len(args.wenet_ctc) == 0, args.wenet_ctc
332 assert len(args.whisper_encoder) == 0, args.whisper_encoder 357 assert len(args.whisper_encoder) == 0, args.whisper_encoder
333 assert len(args.whisper_decoder) == 0, args.whisper_decoder 358 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  359 + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder
  360 + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder
334 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor 361 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
335 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder 362 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
336 assert ( 363 assert (
@@ -353,6 +380,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -353,6 +380,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
353 assert len(args.wenet_ctc) == 0, args.wenet_ctc 380 assert len(args.wenet_ctc) == 0, args.wenet_ctc
354 assert len(args.whisper_encoder) == 0, args.whisper_encoder 381 assert len(args.whisper_encoder) == 0, args.whisper_encoder
355 assert len(args.whisper_decoder) == 0, args.whisper_decoder 382 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  383 + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder
  384 + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder
356 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor 385 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
357 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder 386 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
358 assert ( 387 assert (
@@ -371,6 +400,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -371,6 +400,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
371 elif args.wenet_ctc: 400 elif args.wenet_ctc:
372 assert len(args.whisper_encoder) == 0, args.whisper_encoder 401 assert len(args.whisper_encoder) == 0, args.whisper_encoder
373 assert len(args.whisper_decoder) == 0, args.whisper_decoder 402 assert len(args.whisper_decoder) == 0, args.whisper_decoder
  403 + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder
  404 + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder
374 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor 405 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
375 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder 406 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
376 assert ( 407 assert (
@@ -392,6 +423,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -392,6 +423,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
392 elif args.whisper_encoder: 423 elif args.whisper_encoder:
393 assert_file_exists(args.whisper_encoder) 424 assert_file_exists(args.whisper_encoder)
394 assert_file_exists(args.whisper_decoder) 425 assert_file_exists(args.whisper_decoder)
  426 + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder
  427 + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder
395 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor 428 assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
396 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder 429 assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
397 assert ( 430 assert (
@@ -411,6 +444,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -411,6 +444,8 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
411 tail_paddings=args.whisper_tail_paddings, 444 tail_paddings=args.whisper_tail_paddings,
412 ) 445 )
413 elif args.moonshine_preprocessor: 446 elif args.moonshine_preprocessor:
  447 + assert len(args.fire_red_asr_encoder) == 0, args.fire_red_asr_encoder
  448 + assert len(args.fire_red_asr_decoder) == 0, args.fire_red_asr_decoder
414 assert_file_exists(args.moonshine_preprocessor) 449 assert_file_exists(args.moonshine_preprocessor)
415 assert_file_exists(args.moonshine_encoder) 450 assert_file_exists(args.moonshine_encoder)
416 assert_file_exists(args.moonshine_uncached_decoder) 451 assert_file_exists(args.moonshine_uncached_decoder)
@@ -426,6 +461,15 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: @@ -426,6 +461,15 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
426 decoding_method=args.decoding_method, 461 decoding_method=args.decoding_method,
427 debug=args.debug, 462 debug=args.debug,
428 ) 463 )
  464 + elif args.fire_red_asr_encoder:
  465 + recognizer = sherpa_onnx.OfflineRecognizer.from_fire_red_asr(
  466 + encoder=args.fire_red_asr_encoder,
  467 + decoder=args.fire_red_asr_decoder,
  468 + tokens=args.tokens,
  469 + num_threads=args.num_threads,
  470 + decoding_method=args.decoding_method,
  471 + debug=args.debug,
  472 + )
429 else: 473 else:
430 raise ValueError("Please specify at least one model") 474 raise ValueError("Please specify at least one model")
431 475