Peakyxh

Add speaker identification with VAD and non-streaming ASR using ALSA (#1463)

  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This script works only on Linux. It uses ALSA for recording.
  5 +
  6 +This script shows how to use Python APIs for speaker identification with
  7 +a microphone, a VAD model, and a non-streaming ASR model.
  8 +
  9 +Please see also ./generate-subtitles.py
  10 +
  11 +Usage:
  12 +
  13 +(1) Prepare a text file containing speaker-related wave files.
  14 +
  15 +Each line in the text file contains two columns. The first column is the
  16 +speaker name, while the second column contains the wave file of the speaker.
  17 +
  18 +If the text file contains multiple wave files for the same speaker, then the
  19 +embeddings of these files are averaged.
  20 +
  21 +An example text file is given below:
  22 +
  23 + foo /path/to/a.wav
  24 + bar /path/to/b.wav
  25 + foo /path/to/c.wav
  26 + foobar /path/to/d.wav
  27 +
  28 +Each wave file should contain only a single channel; the sample format
  29 +should be int16_t; the sample rate can be arbitrary.
  30 +
  31 +(2) Download a model for computing speaker embeddings
  32 +
  33 +Please visit
  34 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  35 +to download a model. An example is given below:
  36 +
  37 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx
  38 +
  39 +Note that `zh` means Chinese, while `en` means English.
  40 +
  41 +(3) Download the VAD model
  42 +Please visit
  43 +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx
  44 +to download silero_vad.onnx
  45 +
  46 +For instance,
  47 +
  48 +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx
  49 +
  50 +(4) Please refer to ./generate-subtitles.py
  51 +to download a non-streaming ASR model.
  52 +
  53 +(5) Run this script
  54 +
  55 +Assume the filename of the text file is speaker.txt. Note that you must also
+pass a non-streaming ASR model from step (4) and the ALSA device name (see the
+help message of --device-name); a Paraformer model and plughw:3,0 are used
+below only as placeholders.
  56 +
  57 +python3 ./python-api-examples/speaker-identification-with-vad-non-streaming-asr.py \
  58 + --silero-vad-model=/path/to/silero_vad.onnx \
  59 + --speaker-file ./speaker.txt \
  60 + --model ./wespeaker_zh_cnceleb_resnet34.onnx \
+ --tokens /path/to/tokens.txt \
+ --paraformer /path/to/model.onnx \
+ --device-name plughw:3,0
  61 +"""
  62 +import argparse
  63 +from collections import defaultdict
  64 +from pathlib import Path
  65 +from typing import Dict, List, Tuple
  66 +
  67 +import numpy as np
  68 +import sherpa_onnx
  69 +import soundfile as sf
  70 +
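+# Sample rate (in Hz) of the audio fed to the VAD and to the ASR and
+# speaker-embedding streams created below.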
  71 +g_sample_rate = 16000
  72 +
  73 +
  74 +def register_non_streaming_asr_model_args(parser):
  75 + parser.add_argument(
  76 + "--tokens",
  77 + type=str,
  78 + help="Path to tokens.txt",
  79 + )
  80 +
  81 + parser.add_argument(
  82 + "--encoder",
  83 + default="",
  84 + type=str,
  85 + help="Path to the transducer encoder model",
  86 + )
  87 +
  88 + parser.add_argument(
  89 + "--decoder",
  90 + default="",
  91 + type=str,
  92 + help="Path to the transducer decoder model",
  93 + )
  94 +
  95 + parser.add_argument(
  96 + "--joiner",
  97 + default="",
  98 + type=str,
  99 + help="Path to the transducer joiner model",
  100 + )
  101 +
  102 + parser.add_argument(
  103 + "--paraformer",
  104 + default="",
  105 + type=str,
  106 + help="Path to the model.onnx from Paraformer",
  107 + )
  108 +
  109 + parser.add_argument(
  110 + "--wenet-ctc",
  111 + default="",
  112 + type=str,
  113 + help="Path to the CTC model.onnx from WeNet",
  114 + )
  115 +
  116 + parser.add_argument(
  117 + "--whisper-encoder",
  118 + default="",
  119 + type=str,
  120 + help="Path to whisper encoder model",
  121 + )
  122 +
  123 + parser.add_argument(
  124 + "--whisper-decoder",
  125 + default="",
  126 + type=str,
  127 + help="Path to whisper decoder model",
  128 + )
  129 +
  130 + parser.add_argument(
  131 + "--whisper-language",
  132 + default="",
  133 + type=str,
  134 + help="""It specifies the spoken language in the input file.
  135 + Example values: en, fr, de, zh, ja.
  136 + Available languages for multilingual models can be found at
  137 + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
  138 + If not specified, we infer the language from the input audio file.
  139 + """,
  140 + )
  141 +
  142 + parser.add_argument(
  143 + "--whisper-task",
  144 + default="transcribe",
  145 + choices=["transcribe", "translate"],
  146 + type=str,
  147 + help="""For multilingual models, if you specify translate, the output
  148 + will be in English.
  149 + """,
  150 + )
  151 +
  152 + parser.add_argument(
  153 + "--whisper-tail-paddings",
  154 + default=-1,
  155 + type=int,
  156 + help="""Number of tail padding frames.
  157 + We have removed the 30-second constraint from whisper, so you need to
  158 + choose the number of tail padding frames yourself.
  159 + Use -1 to use a default value for tail padding.
  160 + """,
  161 + )
  162 +
  163 + parser.add_argument(
  164 + "--decoding-method",
  165 + type=str,
  166 + default="greedy_search",
  167 + help="""Valid values are greedy_search and modified_beam_search.
  168 + modified_beam_search is valid only for transducer models.
  169 + """,
  170 + )
  171 +
  172 + parser.add_argument(
  173 + "--feature-dim",
  174 + type=int,
  175 + default=80,
  176 + help="Feature dimension. Must match the one expected by the model",
  177 + )
  178 +
  179 +
  180 +def get_args():
  181 + parser = argparse.ArgumentParser(
  182 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  183 + )
  184 +
  185 + register_non_streaming_asr_model_args(parser)
  186 +
  187 + parser.add_argument(
  188 + "--speaker-file",
  189 + type=str,
  190 + required=True,
  191 + help="""Path to the speaker file. Read the help doc at the beginning of this
  192 + file for the format.""",
  193 + )
  194 +
  195 + parser.add_argument(
  196 + "--model",
  197 + type=str,
  198 + required=True,
  199 + help="Path to the speaker embedding model file.",
  200 + )
  201 +
  202 + parser.add_argument(
  203 + "--silero-vad-model",
  204 + type=str,
  205 + required=True,
  206 + help="Path to silero_vad.onnx",
  207 + )
  208 +
  209 + parser.add_argument("--threshold", type=float, default=0.6)
  210 +
  211 + parser.add_argument(
  212 + "--num-threads",
  213 + type=int,
  214 + default=1,
  215 + help="Number of threads for neural network computation",
  216 + )
  217 +
  218 + parser.add_argument(
  219 + "--debug",
  220 + type=bool,
  221 + default=False,
  222 + help="True to show debug messages",
  223 + )
  224 +
  225 + parser.add_argument(
  226 + "--provider",
  227 + type=str,
  228 + default="cpu",
  229 + help="Valid values: cpu, cuda, coreml",
  230 + )
  231 +
  232 + parser.add_argument(
  233 + "--device-name",
  234 + type=str,
  235 + required=True,
  236 + help="""
  237 +The device name specifies which microphone to use in case there are several
  238 +on your system. You can use
  239 +
  240 + arecord -l
  241 +
  242 +to find all available microphones on your computer. For instance, if it outputs
  243 +
  244 +**** List of CAPTURE Hardware Devices ****
  245 +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio]
  246 + Subdevices: 1/1
  247 + Subdevice #0: subdevice #0
  248 +
  249 +and if you want to select card 3 and device 0 on that card, please use:
  250 +
  251 + plughw:3,0
  252 +
  253 +as the device_name.
  254 + """,
  255 + )
  256 +
  257 + return parser.parse_args()
  258 +
  259 +
  260 +def assert_file_exists(filename: str):
  261 + assert Path(filename).is_file(), (
  262 + f"{filename} does not exist!\n"
  263 + "Please refer to "
  264 + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  265 + )
  266 +
  267 +
  268 +def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
  269 + if args.encoder:
  270 + assert len(args.paraformer) == 0, args.paraformer
  271 + assert len(args.wenet_ctc) == 0, args.wenet_ctc
  272 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  273 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  274 +
  275 + assert_file_exists(args.encoder)
  276 + assert_file_exists(args.decoder)
  277 + assert_file_exists(args.joiner)
  278 +
  279 + recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
  280 + encoder=args.encoder,
  281 + decoder=args.decoder,
  282 + joiner=args.joiner,
  283 + tokens=args.tokens,
  284 + num_threads=args.num_threads,
  285 + sample_rate=g_sample_rate,
  286 + feature_dim=args.feature_dim,
  287 + decoding_method=args.decoding_method,
  288 + debug=args.debug,
  289 + )
  290 + elif args.paraformer:
  291 + assert len(args.wenet_ctc) == 0, args.wenet_ctc
  292 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  293 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  294 +
  295 + assert_file_exists(args.paraformer)
  296 +
  297 + recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
  298 + paraformer=args.paraformer,
  299 + tokens=args.tokens,
  300 + num_threads=args.num_threads,
  301 + sample_rate=g_sample_rate,
  302 + feature_dim=args.feature_dim,
  303 + decoding_method=args.decoding_method,
  304 + debug=args.debug,
  305 + )
  306 + elif args.wenet_ctc:
  307 + assert len(args.whisper_encoder) == 0, args.whisper_encoder
  308 + assert len(args.whisper_decoder) == 0, args.whisper_decoder
  309 +
  310 + assert_file_exists(args.wenet_ctc)
  311 +
  312 + recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
  313 + model=args.wenet_ctc,
  314 + tokens=args.tokens,
  315 + num_threads=args.num_threads,
  316 + sample_rate=g_sample_rate,
  317 + feature_dim=args.feature_dim,
  318 + decoding_method=args.decoding_method,
  319 + debug=args.debug,
  320 + )
  321 + elif args.whisper_encoder:
  322 + assert_file_exists(args.whisper_encoder)
  323 + assert_file_exists(args.whisper_decoder)
  324 +
  325 + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
  326 + encoder=args.whisper_encoder,
  327 + decoder=args.whisper_decoder,
  328 + tokens=args.tokens,
  329 + num_threads=args.num_threads,
  330 + decoding_method=args.decoding_method,
  331 + debug=args.debug,
  332 + language=args.whisper_language,
  333 + task=args.whisper_task,
  334 + tail_paddings=args.whisper_tail_paddings,
  335 + )
  336 + else:
  337 + raise ValueError("Please specify at least one model")
  338 +
  339 + return recognizer
  340 +
  341 +
  342 +def load_speaker_embedding_model(args):
  343 + config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
  344 + model=args.model,
  345 + num_threads=args.num_threads,
  346 + debug=args.debug,
  347 + provider=args.provider,
  348 + )
  349 + if not config.validate():
  350 + raise ValueError(f"Invalid config. {config}")
  351 + extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
  352 + return extractor
  353 +
  354 +
  355 +def load_speaker_file(args) -> Dict[str, List[str]]:
  356 + if not Path(args.speaker_file).is_file():
  357 + raise ValueError(f"--speaker-file {args.speaker_file} does not exist")
  358 +
  359 + ans = defaultdict(list)
  360 + with open(args.speaker_file) as f:
  361 + for line in f:
  362 + line = line.strip()
  363 + if not line:
  364 + continue
  365 +
  366 + fields = line.split()
  367 + if len(fields) != 2:
  368 + raise ValueError(f"Invalid line: {line}. Fields: {fields}")
  369 +
  370 + speaker_name, filename = fields
  371 + ans[speaker_name].append(filename)
  372 + return ans
  373 +
  374 +
  375 +def load_audio(filename: str) -> Tuple[np.ndarray, int]:
  376 + data, sample_rate = sf.read(
  377 + filename,
  378 + always_2d=True,
  379 + dtype="float32",
  380 + )
  381 + data = data[:, 0] # use only the first channel
  382 + samples = np.ascontiguousarray(data)
  383 + return samples, sample_rate
  384 +
  385 +
  386 +def compute_speaker_embedding(
  387 + filenames: List[str],
  388 + extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
  389 +) -> np.ndarray:
  390 + assert len(filenames) > 0, "filenames is empty"
  391 +
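+ # If several wave files are given for the same speaker, their embeddings
+ # are summed here and averaged before returning (see the usage doc above).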
  392 + ans = None
  393 + for filename in filenames:
  394 + print(f"processing {filename}")
  395 + samples, sample_rate = load_audio(filename)
  396 + stream = extractor.create_stream()
  397 + stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
  398 + stream.input_finished()
  399 +
  400 + assert extractor.is_ready(stream)
  401 + embedding = extractor.compute(stream)
  402 + embedding = np.array(embedding)
  403 + if ans is None:
  404 + ans = embedding
  405 + else:
  406 + ans += embedding
  407 +
  408 + return ans / len(filenames)
  409 +
  410 +
  411 +def main():
  412 + args = get_args()
  413 + print(args)
  414 +
  415 + device_name = args.device_name
  416 + print(f"device_name: {device_name}")
  417 + alsa = sherpa_onnx.Alsa(device_name)
  418 +
  419 + recognizer = create_recognizer(args)
  420 + extractor = load_speaker_embedding_model(args)
  421 + speaker_file = load_speaker_file(args)
  422 +
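+ # The manager stores one (averaged) embedding per registered speaker.
+ # manager.search() below returns the best-matching speaker name, or an
+ # empty result if no score reaches --threshold.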
  423 + manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
  424 + for name, filename_list in speaker_file.items():
  425 + embedding = compute_speaker_embedding(
  426 + filenames=filename_list,
  427 + extractor=extractor,
  428 + )
  429 + status = manager.add(name, embedding)
  430 + if not status:
  431 + raise RuntimeError(f"Failed to register speaker {name}")
  432 +
  433 + vad_config = sherpa_onnx.VadModelConfig()
  434 + vad_config.silero_vad.model = args.silero_vad_model
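+ # min_silence_duration: trailing silence (in seconds) that marks the end of
+ # a speech segment; min_speech_duration: segments shorter than this number
+ # of seconds are discarded by the VAD.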
  435 + vad_config.silero_vad.min_silence_duration = 0.25
  436 + vad_config.silero_vad.min_speech_duration = 0.25
  437 + vad_config.sample_rate = g_sample_rate
  438 + if not vad_config.validate():
  439 + raise ValueError("Errors in vad config")
  440 +
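+ # The silero VAD consumes audio in fixed-size windows (typically 512 samples
+ # at 16 kHz), so the capture buffer below is drained in chunks of window_size.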
  441 + window_size = vad_config.silero_vad.window_size
  442 +
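+ # buffer_size_in_seconds bounds how much audio the VAD keeps internally.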
  443 + vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100)
  444 +
  445 + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms
  446 +
  447 + print("Started! Please speak")
  448 +
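+ # Main loop: read ~100 ms of audio from ALSA, feed it to the VAD in
+ # window_size chunks, and run speaker identification plus non-streaming ASR
+ # on every completed speech segment.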
  449 + idx = 0
  450 + buffer = []
  451 + while True:
  452 + samples = alsa.read(samples_per_read) # a blocking read
  453 + samples = np.array(samples)
  454 + buffer = np.concatenate([buffer, samples])
  455 + while len(buffer) > window_size:
  456 + vad.accept_waveform(buffer[:window_size])
  457 + buffer = buffer[window_size:]
  458 +
  459 + while not vad.empty():
  460 + if len(vad.front.samples) < 0.5 * g_sample_rate:
  461 + # this segment is too short, skip it
  462 + vad.pop()
  463 + continue
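+ # Compute the speaker embedding of this segment and look it up among the
+ # registered speakers.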
  464 + stream = extractor.create_stream()
  465 + stream.accept_waveform(
  466 + sample_rate=g_sample_rate, waveform=vad.front.samples
  467 + )
  468 + stream.input_finished()
  469 +
  470 + embedding = extractor.compute(stream)
  471 + embedding = np.array(embedding)
  472 + name = manager.search(embedding, threshold=args.threshold)
  473 + if not name:
  474 + name = "unknown"
  475 +
  476 + # Now for non-streaming ASR
  477 + asr_stream = recognizer.create_stream()
  478 + asr_stream.accept_waveform(
  479 + sample_rate=g_sample_rate, waveform=vad.front.samples
  480 + )
  481 + recognizer.decode_stream(asr_stream)
  482 + text = asr_stream.result.text
  483 +
  484 + vad.pop()
  485 +
  486 + print(f"\r{idx}-{name}: {text}")
  487 + idx += 1
  488 +
  489 +
  490 +if __name__ == "__main__":
  491 + try:
  492 + main()
  493 + except KeyboardInterrupt:
  494 + print("\nCaught Ctrl + C. Exiting")