keyword-spotter.py 3.7 KB
#!/usr/bin/env python3

"""
This file demonstrates how to use the sherpa-onnx Python API to do keyword
spotting on a wave file.

Please refer to
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
to download pre-trained models.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple

import numpy as np
import sherpa_onnx


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Read a single-channel, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
       normalized to the range [-1, 1).
       - sample rate of the wave file
    Raises:
      ValueError:
        If the file is not single channel or its samples are not 16-bit.
    """
    with wave.open(wave_filename) as f:
        # Validate explicitly instead of with ``assert``: assertions are
        # stripped under ``python -O``, which would let a stereo or non-16-bit
        # file be silently misread as mono 16-bit audio.
        num_channels = f.getnchannels()
        if num_channels != 1:
            raise ValueError(
                f"{wave_filename}: expected 1 channel, got {num_channels}"
            )
        sample_width = f.getsampwidth()  # it is in bytes
        if sample_width != 2:
            raise ValueError(
                f"{wave_filename}: expected 16-bit samples "
                f"(sample width 2 bytes), got {sample_width} bytes"
            )
        sample_rate = f.getframerate()
        frames = f.readframes(f.getnframes())

    # WAV PCM data is little-endian; an explicit "<i2" dtype keeps this correct
    # on big-endian hosts too (plain np.int16 means native byte order).
    samples_int16 = np.frombuffer(frames, dtype="<i2")
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate


def create_keyword_spotter(
    model_dir: str = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01",
    *,
    keywords_file=None,
    num_threads: int = 2,
    provider: str = "cpu",
):
    """
    Create a ``sherpa_onnx.KeywordSpotter`` from a pre-trained model directory.

    Args:
      model_dir:
        Directory containing ``tokens.txt`` and the encoder/decoder/joiner
        ``*-epoch-12-avg-2-chunk-16-left-64.onnx`` files. Defaults to the
        zipformer-wenetspeech-3.3M model directory used by this demo.
      keywords_file:
        Path to the pre-defined keywords file. ``None`` (the default) means
        ``<model_dir>/test_wavs/test_keywords.txt``.
      num_threads:
        Number of threads passed to the spotter.
      provider:
        Execution provider passed to the spotter.
    Returns:
      The constructed keyword spotter.
    """
    model_path = Path(model_dir)
    # All three model components share the same checkpoint suffix.
    suffix = "epoch-12-avg-2-chunk-16-left-64.onnx"
    if keywords_file is None:
        keywords_file = model_path / "test_wavs" / "test_keywords.txt"

    kws = sherpa_onnx.KeywordSpotter(
        tokens=str(model_path / "tokens.txt"),
        encoder=str(model_path / f"encoder-{suffix}"),
        decoder=str(model_path / f"decoder-{suffix}"),
        joiner=str(model_path / f"joiner-{suffix}"),
        num_threads=num_threads,
        keywords_file=str(keywords_file),
        provider=provider,
    )

    return kws


def _detect_keywords(kws, samples, sample_rate, keywords=None):
    """
    Run keyword spotting over ``samples`` on a fresh stream and print each hit.

    Args:
      kws:
        A keyword spotter, e.g. the one returned by ``create_keyword_spotter()``.
      samples:
        1-D float32 samples, as returned by ``read_wave()``.
      sample_rate:
        Sample rate of ``samples``.
      keywords:
        Optional extra keywords passed to ``kws.create_stream()``; ``None``
        uses only the spotter's pre-defined keywords.
    """
    # 0.66 s of trailing silence appended after the real audio; presumably
    # this lets the streaming model flush its final frames — TODO confirm.
    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)

    if keywords is None:
        s = kws.create_stream()
    else:
        s = kws.create_stream(keywords)
    s.accept_waveform(sample_rate, samples)
    s.accept_waveform(sample_rate, tail_paddings)
    s.input_finished()
    while kws.is_ready(s):
        kws.decode_stream(s)
        r = kws.get_result(s)
        if r != "":
            # Remember to call reset right after a keyword is detected.
            kws.reset_stream(s)

            print(f"Detected {r}")


def main():
    """Spot keywords in a test wave file, first with the pre-defined keywords
    only and then with one and two extra keywords added per stream."""
    kws = create_keyword_spotter()

    wave_filename = (
        "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
    )

    samples, sample_rate = read_wave(wave_filename)

    print("----------Use pre-defined keywords----------")
    _detect_keywords(kws, samples, sample_rate)

    # Each extra keyword is written as space-separated tokens followed by
    # "@<phrase>"; several keywords are joined with "/".
    print("----------Use pre-defined keywords + add a new keyword----------")
    _detect_keywords(kws, samples, sample_rate, "y ǎn y uán @演员")

    print("----------Use pre-defined keywords + add 2 new keywords----------")
    _detect_keywords(
        kws, samples, sample_rate, "y ǎn y uán @演员/zh ī m íng @知名"
    )


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()