Fangjun Kuang
Committed by GitHub

Generate subtitles (#315)

  1 +#!/usr/bin/env python3
  2 +#
  3 +# Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +"""
  6 +This file demonstrates how to use sherpa-onnx Python APIs to generate
  7 +subtitles.
  8 +
  9 +Supported file formats are those supported by ffmpeg; for instance,
  10 +*.mov, *.mp4, *.wav, etc.
  11 +
  12 +Note that you need a non-streaming model for this script.
  13 +
  14 +(1) For paraformer
  15 +
  16 + ./python-api-examples/generate-subtitles.py \
  17 + --silero-vad-model=/path/to/silero_vad.onnx \
  18 + --tokens=/path/to/tokens.txt \
  19 + --paraformer=/path/to/paraformer.onnx \
  20 + --num-threads=2 \
  21 + --decoding-method=greedy_search \
  22 + --debug=false \
  23 + --sample-rate=16000 \
  24 + --feature-dim=80 \
  25 + /path/to/test.mp4
  26 +
  27 +(2) For transducer models from icefall
  28 +
  29 + ./python-api-examples/generate-subtitles.py \
  30 + --silero-vad-model=/path/to/silero_vad.onnx \
  31 + --tokens=/path/to/tokens.txt \
  32 + --encoder=/path/to/encoder.onnx \
  33 + --decoder=/path/to/decoder.onnx \
  34 + --joiner=/path/to/joiner.onnx \
  35 + --num-threads=2 \
  36 + --decoding-method=greedy_search \
  37 + --debug=false \
  38 + --sample-rate=16000 \
  39 + --feature-dim=80 \
  40 + /path/to/test.mp4
  41 +
  42 +(3) For Whisper models
  43 +
  44 +./python-api-examples/generate-subtitles.py \
  45 + --silero-vad-model=/path/to/silero_vad.onnx \
  46 + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
  47 + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
  48 + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
  49 + --whisper-task=transcribe \
  50 + --num-threads=2 \
  51 + /path/to/test.mp4
  52 +
  53 +Please refer to
  54 +https://k2-fsa.github.io/sherpa/onnx/index.html
  55 +to install sherpa-onnx and to download non-streaming pre-trained models
  56 +used in this file.
  57 +"""
  58 +import argparse
  59 +import shutil
  60 +import subprocess
  61 +import sys
  62 +from dataclasses import dataclass
  63 +from datetime import timedelta
  64 +from pathlib import Path
  65 +
  66 +import numpy as np
  67 +import sherpa_onnx
  68 +
  69 +
  70 +def get_args():
  71 + parser = argparse.ArgumentParser(
  72 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  73 + )
  74 +
  75 + parser.add_argument(
  76 + "--silero-vad-model",
  77 + type=str,
  78 + required=True,
  79 + help="Path to silero_vad.onnx",
  80 + )
  81 +
  82 + parser.add_argument(
  83 + "--tokens",
  84 + type=str,
  85 + help="Path to tokens.txt",
  86 + )
  87 +
  88 + parser.add_argument(
  89 + "--encoder",
  90 + default="",
  91 + type=str,
  92 + help="Path to the transducer encoder model",
  93 + )
  94 +
  95 + parser.add_argument(
  96 + "--decoder",
  97 + default="",
  98 + type=str,
  99 + help="Path to the transducer decoder model",
  100 + )
  101 +
  102 + parser.add_argument(
  103 + "--joiner",
  104 + default="",
  105 + type=str,
  106 + help="Path to the transducer joiner model",
  107 + )
  108 +
  109 + parser.add_argument(
  110 + "--paraformer",
  111 + default="",
  112 + type=str,
  113 + help="Path to the model.onnx from Paraformer",
  114 + )
  115 +
  116 + parser.add_argument(
  117 + "--num-threads",
  118 + type=int,
  119 + default=1,
  120 + help="Number of threads for neural network computation",
  121 + )
  122 +
  123 + parser.add_argument(
  124 + "--whisper-encoder",
  125 + default="",
  126 + type=str,
  127 + help="Path to whisper encoder model",
  128 + )
  129 +
  130 + parser.add_argument(
  131 + "--whisper-decoder",
  132 + default="",
  133 + type=str,
  134 + help="Path to whisper decoder model",
  135 + )
  136 +
  137 + parser.add_argument(
  138 + "--whisper-language",
  139 + default="",
  140 + type=str,
  141 + help="""It specifies the spoken language in the input file.
  142 + Example values: en, fr, de, zh, ja.
  143 + Available languages for multilingual models can be found at
  144 + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
  145 + If not specified, we infer the language from the input audio file.
  146 + """,
  147 + )
  148 +
  149 + parser.add_argument(
  150 + "--whisper-task",
  151 + default="transcribe",
  152 + choices=["transcribe", "translate"],
  153 + type=str,
  154 + help="""For multilingual models, if you specify translate, the output
  155 + will be in English.
  156 + """,
  157 + )
  158 +
  159 + parser.add_argument(
  160 + "--decoding-method",
  161 + type=str,
  162 + default="greedy_search",
  163 + help="""Valid values are greedy_search and modified_beam_search.
  164 + modified_beam_search is valid only for transducer models.
  165 + """,
  166 + )
  167 + parser.add_argument(
  168 + "--debug",
  169 + type=lambda s: s.strip().lower() in ("true", "1", "yes"),
  170 + default=False,
  171 + help="True to show debug messages when loading models.",
  172 + )
  173 +
  174 + parser.add_argument(
  175 + "--sample-rate",
  176 + type=int,
  177 + default=16000,
  178 + help="""Sample rate of the feature extractor. Must match the one
  179 + expected by the model. Note: The input sound files can have a
  180 + different sample rate from this argument.""",
  181 + )
  182 +
  183 + parser.add_argument(
  184 + "--feature-dim",
  185 + type=int,
  186 + default=80,
  187 + help="Feature dimension. Must match the one expected by the model",
  188 + )
  189 +
  190 + parser.add_argument(
  191 + "sound_file",
  192 + type=str,
  193 + help="The input sound/video file for which to generate subtitles",
  194 + )
  195 +
  196 + return parser.parse_args()
  197 +
  198 +
  199 +def assert_file_exists(filename: str):
  200 + assert Path(filename).is_file(), (
  201 + f"{filename} does not exist!\n"
  202 + "Please refer to "
  203 + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  204 + )
  205 +
  206 +
  207 +def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
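      + # Exactly one model type must be given: a transducer (encoder/decoder/
      + # joiner), a paraformer model, or a whisper encoder/decoder pair.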
  208 + if args.encoder:
  209 + assert len(args.paraformer) == 0, args.paraformer
  210 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  211 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  212 +
  213 + assert_file_exists(args.encoder)
  214 + assert_file_exists(args.decoder)
  215 + assert_file_exists(args.joiner)
  216 +
  217 + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
  218 + encoder=args.encoder,
  219 + decoder=args.decoder,
  220 + joiner=args.joiner,
  221 + tokens=args.tokens,
  222 + num_threads=args.num_threads,
  223 + sample_rate=args.sample_rate,
  224 + feature_dim=args.feature_dim,
  225 + decoding_method=args.decoding_method,
  226 + debug=args.debug,
  227 + )
  228 + elif args.paraformer:
  229 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  230 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  231 +
  232 + assert_file_exists(args.paraformer)
  233 +
  234 + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
  235 + paraformer=args.paraformer,
  236 + tokens=args.tokens,
  237 + num_threads=args.num_threads,
  238 + sample_rate=args.sample_rate,
  239 + feature_dim=args.feature_dim,
  240 + decoding_method=args.decoding_method,
  241 + debug=args.debug,
  242 + )
  243 + elif args.whisper_encoder:
  244 + assert_file_exists(args.whisper_encoder)
  245 + assert_file_exists(args.whisper_decoder)
  246 +
  247 + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
  248 + encoder=args.whisper_encoder,
  249 + decoder=args.whisper_decoder,
  250 + tokens=args.tokens,
  251 + num_threads=args.num_threads,
  252 + decoding_method=args.decoding_method,
  253 + debug=args.debug,
  254 + language=args.whisper_language,
  255 + task=args.whisper_task,
  256 + )
  257 + else:
  258 + raise ValueError("Please specify at least one model")
  259 +
  260 + return recognizer
  261 +
  262 +
  263 +@dataclass
  264 +class Segment:
  265 + start: float
  266 + duration: float
  267 + text: str = ""
  268 +
  269 + @property
  270 + def end(self):
  271 + return self.start + self.duration
  272 +
  273 + def __str__(self):
  274 + # SRT timestamps look like HH:MM:SS,mmm. str(timedelta) omits the
  275 + # fractional part when it is 0, so pad with ",000" in that case.
  276 + def fmt(t):
  277 + s = str(timedelta(seconds=t))
  278 + return s[:-3].replace(".", ",") if "." in s else s + ",000"
  279 +
  280 + return f"{fmt(self.start)} --> {fmt(self.end)}\n{self.text}"
  281 +
  282 +
  283 +def main():
  284 + args = get_args()
  285 + assert_file_exists(args.tokens)
  286 + assert_file_exists(args.silero_vad_model)
  287 +
  288 + assert args.num_threads > 0, args.num_threads
  289 +
  290 + if not Path(args.sound_file).is_file():
  291 + raise ValueError(f"{args.sound_file} does not exist")
  292 +
  293 + assert (
  294 + args.sample_rate == 16000
  295 + ), f"Only sample rate 16000 is supported.Given: {args.sample_rate}"
  296 +
  297 + recognizer = create_recognizer(args)
  298 +
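      + # Ask ffmpeg to decode the input file and stream raw 16-bit
      + # little-endian mono PCM at the target sample rate to stdout ("-").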
  299 + ffmpeg_cmd = [
  300 + "ffmpeg",
  301 + "-i",
  302 + args.sound_file,
  303 + "-f",
  304 + "s16le",
  305 + "-acodec",
  306 + "pcm_s16le",
  307 + "-ac",
  308 + "1",
  309 + "-ar",
  310 + str(args.sample_rate),
  311 + "-",
  312 + ]
  313 +
  314 + process = subprocess.Popen(
  315 + ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
  316 + )
  317 +
  318 + frames_per_read = int(args.sample_rate * 100)  # read 100 seconds at a time
  319 +
  320 + stream = recognizer.create_stream()
  321 +
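      + # Configure the silero VAD. min_silence_duration is in seconds: a
      + # speech segment is closed only after that much trailing silence.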
  322 + config = sherpa_onnx.VadModelConfig()
  323 + config.silero_vad.model = args.silero_vad_model
  324 + config.silero_vad.min_silence_duration = 0.25
  325 + config.sample_rate = args.sample_rate
  326 +
  327 + window_size = config.silero_vad.window_size
  328 +
  329 + buffer = np.array([], dtype=np.float32)  # samples not yet fed to the VAD
  330 + vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=100)
  331 +
  332 + segment_list = []
  333 +
  334 + print("Started!")
  335 +
  336 + # TODO(fangjun): Support multithreading
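      + # Main loop: read PCM from ffmpeg, feed it to the VAD in fixed-size
      + # windows, and run the recognizer on every detected speech segment.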
  337 + while True:
  338 + # *2 because int16_t has two bytes
  339 + data = process.stdout.read(frames_per_read * 2)
  340 + if not data:
  341 + break
  342 +
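      + # Convert the 16-bit samples to float32 in the range [-1, 1).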
  343 + samples = np.frombuffer(data, dtype=np.int16)
  344 + samples = samples.astype(np.float32) / 32768
  345 +
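      + # Feed the VAD in windows of `window_size` samples; keep any
      + # leftover samples in `buffer` for the next read.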
  346 + buffer = np.concatenate([buffer, samples])
  347 + while len(buffer) > window_size:
  348 + vad.accept_waveform(buffer[:window_size])
  349 + buffer = buffer[window_size:]
  350 +
  351 + streams = []
  352 + segments = []
  353 + while not vad.empty():
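      + # vad.front.start is an offset in samples; convert the start time
      + # and duration of this speech segment to seconds.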
  354 + segment = Segment(
  355 + start=vad.front.start / args.sample_rate,
  356 + duration=len(vad.front.samples) / args.sample_rate,
  357 + )
  358 + segments.append(segment)
  359 +
  360 + stream = recognizer.create_stream()
  361 + stream.accept_waveform(args.sample_rate, vad.front.samples)
  362 +
  363 + streams.append(stream)
  364 +
  365 + vad.pop()
  366 +
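      + # Decode all segments collected from this chunk as a batch.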
  367 + recognizer.decode_streams(streams)
  368 + for seg, stream in zip(segments, streams):
  369 + seg.text = stream.result.text
  370 + segment_list.append(seg)
  371 +
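      + # Write the results in SRT format next to the input file: each cue
      + # is a running index, a "start --> end" time range, and the text.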
  372 + srt_filename = Path(args.sound_file).with_suffix(".srt")
  373 + with open(srt_filename, "w", encoding="utf-8") as f:
  374 + for i, seg in enumerate(segment_list):
  375 + print(i + 1, file=f)
  376 + print(seg, file=f)
  377 + print("", file=f)
  378 +
  379 + print(f"Saved to {srt_filename}")
  380 + print("Done!")
  381 +
  382 +
  383 +if __name__ == "__main__":
  384 + if shutil.which("ffmpeg") is None:
  385 + sys.exit("Please install ffmpeg first!")
  386 + main()