keyword-spotter.py
3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3
"""
This file demonstrates how to use the sherpa-onnx Python API to do keyword
spotting on a wave file.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
to download pre-trained models.
"""
import argparse
import os
import time
import wave
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import sherpa_onnx
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Read a single-channel, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file, as a ``str`` or any ``os.PathLike`` (e.g.
        ``pathlib.Path``). It must be single channel and each sample must
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1).
       - sample rate of the wave file
    Raises:
      ValueError:
        If the file is not single channel or its samples are not 16-bit.
    """
    # wave.open() treats anything that is not a str as an already-open file
    # object, so convert path-like objects to a plain path first.
    with wave.open(os.fspath(wave_filename)) as f:
        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, which would let stereo / 8-bit data be misread below.
        num_channels = f.getnchannels()
        if num_channels != 1:
            raise ValueError(
                f"Expected a single-channel wave file, got {num_channels} channels"
            )
        sample_width = f.getsampwidth()  # it is in bytes
        if sample_width != 2:
            raise ValueError(
                f"Expected 16-bit (2-byte) samples, got {sample_width} byte(s)"
            )
        sample_rate = f.getframerate()
        raw = f.readframes(f.getnframes())

    samples_int16 = np.frombuffer(raw, dtype=np.int16)
    # Dividing by 32768 maps the int16 range [-32768, 32767] onto [-1, 1).
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate
def create_keyword_spotter(
    *,
    model_dir: str = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01",
    keywords_file: Optional[str] = None,
    num_threads: int = 2,
    provider: str = "cpu",
):
    """
    Create a sherpa_onnx.KeywordSpotter from the pre-trained zipformer KWS model.

    All arguments are keyword-only. Their defaults reproduce the exact paths and
    settings this example was written for, so ``create_keyword_spotter()``
    behaves as before.

    Args:
      model_dir:
        Directory holding tokens.txt and the encoder/decoder/joiner .onnx
        files. The file names inside it are fixed
        (``{encoder,decoder,joiner}-epoch-12-avg-2-chunk-16-left-64.onnx``),
        so another directory must contain files with exactly these names.
      keywords_file:
        Path to the pre-defined keywords file. If None, it defaults to
        ``{model_dir}/test_wavs/test_keywords.txt``.
      num_threads:
        Forwarded unchanged as ``num_threads`` to sherpa_onnx.KeywordSpotter.
      provider:
        Forwarded unchanged as ``provider`` to sherpa_onnx.KeywordSpotter.
    Returns:
      The constructed sherpa_onnx.KeywordSpotter.
    """
    # Shared suffix of the encoder/decoder/joiner checkpoints of this model.
    onnx_suffix = "epoch-12-avg-2-chunk-16-left-64.onnx"
    if keywords_file is None:
        keywords_file = f"{model_dir}/test_wavs/test_keywords.txt"

    kws = sherpa_onnx.KeywordSpotter(
        tokens=f"{model_dir}/tokens.txt",
        encoder=f"{model_dir}/encoder-{onnx_suffix}",
        decoder=f"{model_dir}/decoder-{onnx_suffix}",
        joiner=f"{model_dir}/joiner-{onnx_suffix}",
        num_threads=num_threads,
        keywords_file=keywords_file,
        provider=provider,
    )
    return kws
def _decode_and_report(kws, stream, sample_rate, samples, tail_paddings):
    """
    Feed one stream with the recording plus trailing silence, decode it to the
    end, and print every keyword that `kws` reports.

    Args:
      kws:
        The keyword spotter returned by create_keyword_spotter().
      stream:
        A stream created by ``kws.create_stream(...)``.
      sample_rate:
        Sample rate of `samples` and `tail_paddings`.
      samples:
        The recording, as returned by read_wave().
      tail_paddings:
        Silence (zeros) appended after `samples`.
    """
    stream.accept_waveform(sample_rate, samples)
    stream.accept_waveform(sample_rate, tail_paddings)
    stream.input_finished()
    while kws.is_ready(stream):
        kws.decode_stream(stream)
        result = kws.get_result(stream)
        if result != "":
            # Remember to call reset right after detecting a keyword
            kws.reset_stream(stream)
            print(f"Detected {result}")


def main():
    """
    Run keyword spotting on one test wave three times: with only the
    pre-defined keywords, then with one and with two extra keywords passed to
    ``kws.create_stream``.
    """
    kws = create_keyword_spotter()
    wave_filename = (
        "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
    )
    samples, sample_rate = read_wave(wave_filename)
    # 0.66 s of silence fed after the recording (presumably so the model can
    # flush its final frames -- NOTE(review): confirm the exact requirement).
    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)

    # (header to print, extra keywords for create_stream or None for none).
    # Going by the headers, "/" separates multiple extra keywords.
    runs = [
        ("----------Use pre-defined keywords----------", None),
        (
            "----------Use pre-defined keywords + add a new keyword----------",
            "y ǎn y uán @演员",
        ),
        (
            "----------Use pre-defined keywords + add 2 new keywords----------",
            "y ǎn y uán @演员/zh ī m íng @知名",
        ),
    ]
    for header, extra_keywords in runs:
        print(header)
        if extra_keywords is None:
            stream = kws.create_stream()
        else:
            stream = kws.create_stream(extra_keywords)
        _decode_and_report(kws, stream, sample_rate, samples, tail_paddings)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()