test-onnx.py
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import kaldi_native_fbank as knf
import onnxruntime as ort
import torch
import torchaudio
from typing import Tuple

from torch.nn.utils.rnn import pad_sequence


class OnnxModel:
    def __init__(
        self,
        filename: str,
    ):
        session_opts = ort.SessionOptions()
        # Run on the CPU with 1 inter-op thread and 4 intra-op threads
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4
        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

    def __call__(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). Its dtype is torch.int64.
        Returns:
          Return a tuple containing:
            - log_probs: a 3-D tensor containing the log probabilities of the tokens.
            - log_probs_lens: a 1-D tensor of shape (N,) containing the number of
              valid frames in log_probs.
        """
        log_probs, log_probs_lens = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: x_lens.numpy(),
            },
        )
        return torch.from_numpy(log_probs), torch.from_numpy(log_probs_lens)
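
# A quick sanity check one can run from a Python prompt to see which input and
# output names model.onnx actually exposes (a debugging sketch; ./model.onnx is
# the file name hard-coded in main() below):
#
#   m = OnnxModel("./model.onnx")
#   print([i.name for i in m.model.get_inputs()])
#   print([o.name for o in m.model.get_outputs()])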


def get_features(test_wav_filename):
    wave, sample_rate = torchaudio.load(test_wav_filename)
    audio = wave[0].contiguous()  # only use the first channel
    if sample_rate != 16000:
        audio = torchaudio.functional.resample(
            audio, orig_freq=sample_rate, new_freq=16000
        )

    # torchaudio.load() returns samples in [-1, 1]; scale them to the
    # 16-bit PCM range expected by kaldi-compatible fbank computation
    audio *= 32768

    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.mel_opts.num_bins = 80
    opts.frame_opts.snip_edges = False
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    fbank.accept_waveform(16000, audio.numpy())

    frames = []
    for i in range(fbank.num_frames_ready):
        frames.append(torch.from_numpy(fbank.get_frame(i)))
    frames = torch.stack(frames)
    return frames
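
# get_features() returns a 2-D tensor of shape (num_frames, 80): one
# 80-dimensional log-Mel filterbank vector per frame (80 comes from
# opts.mel_opts.num_bins above).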


def main():
    model_filename = "./model.onnx"
    model = OnnxModel(model_filename)

    filename = "./0.wav"
    x = get_features(filename)
    x = x.unsqueeze(0)

    # Note: It supports only batch size == 1
    x_lens = torch.tensor([x.shape[1]], dtype=torch.int64)
    print(x.shape, x_lens)

    log_probs, log_probs_lens = model(x, x_lens)
    log_probs = log_probs[0]
    print(log_probs.shape)

    # CTC greedy search: take the argmax token for each frame, merge
    # consecutive repeats, and drop token 0 (the blank)
    indexes = log_probs.argmax(dim=1)
    print(indexes)
    indexes = torch.unique_consecutive(indexes)
    indexes = indexes[indexes != 0].tolist()

    # Map token IDs back to symbols; units.txt has one "symbol id" pair per line
    id2word = dict()
    with open("./units.txt", encoding="utf-8") as f:
        for line in f:
            word, idx = line.strip().split()
            id2word[int(idx)] = word

    text = "".join([id2word[i] for i in indexes])
    print(text)


if __name__ == "__main__":
    main()
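

# Usage sketch (assuming model.onnx, 0.wav and units.txt sit in the current
# directory, as hard-coded in main() above):
#
#   python3 ./test-onnx.py
#
# The script prints the feature shape, the per-frame argmax token IDs, and the
# text obtained by CTC greedy search.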