Fangjun Kuang
Committed by GitHub

Add Python example to show how to register speakers dynamically for speaker ID. (#986)

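The example keeps a sherpa_onnx.SpeakerEmbeddingManager and, for every speech segment detected by the VAD, computes a speaker embedding, searches the manager for a match, and registers the voice under a new name when nothing matches. Below is a minimal sketch of that lookup-or-register step only; the embedding dimension and the random embedding are placeholders for illustration, while in the actual script the embedding comes from SpeakerEmbeddingExtractor.compute() and the dimension from extractor.dim:

import numpy as np
import sherpa_onnx

dim = 192  # placeholder; the script uses extractor.dim
manager = sherpa_onnx.SpeakerEmbeddingManager(dim)

embedding = np.random.rand(dim).astype(np.float32)  # placeholder embedding

name = manager.search(embedding, threshold=0.4)
if not name:
    # Unknown voice: register it under a generated name.
    if not manager.add("speaker_0", embedding):
        raise RuntimeError("Failed to register speaker_0")
else:
    print(f"Detected existing speaker: {name}")

The full script added by this commit follows.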
#!/usr/bin/env python3

"""
This script shows how to use Python APIs for speaker identification with
a microphone and a VAD model.

Speakers are registered on the fly: if a detected voice does not match any
registered speaker, it is added as a new speaker.

Usage:

(1) Download a model for computing speaker embeddings

Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
to download a model. An example is given below:

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx

Note that `zh` in the model filename means Chinese, while `en` means English.

(2) Download the VAD model

Please visit
https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
to download silero_vad.onnx

For instance,

  wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx

(3) Run this script

python3 ./python-api-examples/speaker-identification-with-vad-dynamic.py \
  --silero-vad-model=/path/to/silero_vad.onnx \
  --model ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
"""
import argparse
import sys

import numpy as np
import sherpa_onnx

try:
    import sounddevice as sd
except ImportError:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)
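# The microphone stream, the VAD model, and the speaker embedding model
# all use 16 kHz, single-channel audio in this example.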
g_sample_rate = 16000


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the speaker embedding model file.",
    )

    parser.add_argument(
        "--silero-vad-model",
        type=str,
        required=True,
        help="Path to silero_vad.onnx",
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=0.4,
        help="Similarity threshold for matching a voice against registered speakers.",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--debug",
        type=bool,
        default=False,
        help="True to show debug messages",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    return parser.parse_args()


def load_speaker_embedding_model(args):
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=args.model,
        num_threads=args.num_threads,
        debug=args.debug,
        provider=args.provider,
    )
    if not config.validate():
        raise ValueError(f"Invalid config. {config}")
    extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
    return extractor


def compute_speaker_embedding(
    samples: np.ndarray,
    extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
) -> np.ndarray:
    """
    Compute a speaker embedding from a 1-D float32 array of 16 kHz audio
    samples. (main() below performs essentially the same steps inline for
    each VAD segment.)

    Args:
      samples:
        A 1-D float32 array.
      extractor:
        The return value of load_speaker_embedding_model().
    Returns:
      Return a 1-D float32 array containing the speaker embedding.
    """
    if len(samples) < g_sample_rate:
        print(f"Your input contains only {len(samples)} samples!")

    stream = extractor.create_stream()
    stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples)
    stream.input_finished()

    assert extractor.is_ready(stream)
    embedding = extractor.compute(stream)
    embedding = np.array(embedding)
    return embedding


def main():
    args = get_args()
    print(args)

    devices = sd.query_devices()
    if len(devices) == 0:
        print("No microphone devices found")
        sys.exit(0)

    print(devices)
    # If you want to select a different device, please change
    # sd.default.device[0]. For instance, if you want to select device 4,
    # please use
    #
    #   sd.default.device[0] = 4
    #   print(devices)
    #

    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    extractor = load_speaker_embedding_model(args)
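    # The manager keeps the embeddings of registered speakers; search() returns
    # the best-matching name above the threshold, or an empty result when the
    # voice is unknown.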
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
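    # Silero VAD configuration; the duration values are in seconds.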
    vad_config = sherpa_onnx.VadModelConfig()
    vad_config.silero_vad.model = args.silero_vad_model
    vad_config.silero_vad.min_silence_duration = 0.25
    vad_config.silero_vad.min_speech_duration = 1.0
    vad_config.sample_rate = g_sample_rate

    window_size = vad_config.silero_vad.window_size
    vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100)

    samples_per_read = int(0.1 * g_sample_rate)  # 0.1 second = 100 ms

    print("Started! Please speak")

    line_num = 0
    speaker_id = 0
    buffer = []
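    # Read 100 ms of audio at a time from the microphone, feed it to the VAD in
    # chunks of window_size samples, and then process each completed speech segment.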
    with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s:
        while True:
            samples, _ = s.read(samples_per_read)  # a blocking read
            samples = samples.reshape(-1)
            buffer = np.concatenate([buffer, samples])
            while len(buffer) > window_size:
                vad.accept_waveform(buffer[:window_size])
                buffer = buffer[window_size:]
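            # Each item in the VAD queue is a completed speech segment. Compute its
            # embedding and either match it to a registered speaker or register it
            # as a new one.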
            while not vad.empty():
                if len(vad.front.samples) < 0.5 * g_sample_rate:
                    # this segment is too short, skip it
                    vad.pop()
                    continue
                stream = extractor.create_stream()
                stream.accept_waveform(
                    sample_rate=g_sample_rate, waveform=vad.front.samples
                )
                vad.pop()
                stream.input_finished()

                embedding = extractor.compute(stream)
                embedding = np.array(embedding)
                name = manager.search(embedding, threshold=args.threshold)
                if not name:
                    # unknown voice: register it as a new speaker
                    new_name = f"speaker_{speaker_id}"
                    status = manager.add(new_name, embedding)
                    if not status:
                        raise RuntimeError(f"Failed to register speaker {new_name}")
                    print(
                        f"{line_num}: Detected new speaker. Registering it as {new_name}"
                    )
                    speaker_id += 1
                else:
                    print(f"{line_num}: Detected existing speaker: {name}")
                line_num += 1


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")