# test.py
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
from typing import Tuple
import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read a wav file and return ``(samples, sample_rate)``.

    Only the first channel is kept. Samples are float32 and contiguous
    in memory.
    """
    audio, rate = sf.read(filename, always_2d=True, dtype="float32")
    # Keep channel 0 only and make it C-contiguous for downstream C code.
    first_channel = np.ascontiguousarray(audio[:, 0])
    return first_channel, rate
class OnnxModel:
    """Streaming wrapper around ``./gtcrn_simple.onnx``.

    The model processes one STFT frame at a time and carries three
    recurrent caches (conv / tra / inter) between calls. STFT parameters
    (sample rate, n_fft, hop, window) are read from the model's custom
    metadata so the Python side stays in sync with the export.
    """

    def __init__(self):
        session_opts = ort.SessionOptions()
        # Single-threaded inference: one frame per call is tiny work.
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            "./gtcrn_simple.onnx",
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(meta["sample_rate"])
        self.n_fft = int(meta["n_fft"])
        self.hop_length = int(meta["hop_length"])
        self.window_length = int(meta["window_length"])
        # The model was exported with a sqrt-Hann analysis/synthesis window.
        assert meta["window_type"] == "hann_sqrt", meta["window_type"]
        self.window = torch.hann_window(self.window_length).pow(0.5)

    def get_init_states(self):
        """Return zero-initialized ``(conv_cache, tra_cache, inter_cache)``.

        Cache shapes come from the model's custom metadata, each encoded
        as a comma-separated list of ints.
        """
        meta = self.model.get_modelmeta().custom_metadata_map
        conv_cache_shape = list(map(int, meta["conv_cache_shape"].split(",")))
        tra_cache_shape = list(map(int, meta["tra_cache_shape"].split(",")))
        inter_cache_shape = list(map(int, meta["inter_cache_shape"].split(",")))

        # Fix: previously the *_shape name was reused for the cache array
        # itself; use a distinct name so shape and tensor are not confused.
        conv_cache = np.zeros(conv_cache_shape, dtype=np.float32)
        tra_cache = np.zeros(tra_cache_shape, dtype=np.float32)
        inter_cache = np.zeros(inter_cache_shape, dtype=np.float32)
        return conv_cache, tra_cache, inter_cache

    def __call__(self, x, states):
        """Run the model on a single STFT frame.

        Args:
          x: (1, n_fft/2+1, 1, 2), last axis is (real, imag).
          states: 3-tuple of cache arrays from ``get_init_states`` or a
            previous call.
        Returns:
          A pair ``(o, next_states)`` where o has the same shape as x and
          next_states is the updated 3-tuple of caches.
        """
        out, next_conv_cache, next_tra_cache, next_inter_cache = self.model.run(
            [
                self.model.get_outputs()[0].name,
                self.model.get_outputs()[1].name,
                self.model.get_outputs()[2].name,
                self.model.get_outputs()[3].name,
            ],
            {
                self.model.get_inputs()[0].name: x,
                self.model.get_inputs()[1].name: states[0],
                self.model.get_inputs()[2].name: states[1],
                self.model.get_inputs()[3].name: states[2],
            },
        )
        return out, (next_conv_cache, next_tra_cache, next_inter_cache)
def main():
    """Enhance ./inp_16k.wav frame-by-frame and write ./enhanced_16k.wav."""
    model = OnnxModel()

    filename = "./inp_16k.wav"
    samples, sr = load_audio(filename)
    if sr != model.sample_rate:
        # Lazy import: librosa is only needed when resampling is required.
        import librosa

        samples = librosa.resample(
            samples, orig_sr=sr, target_sr=model.sample_rate
        )
        sr = model.sample_rate

    config = knf.StftConfig(
        n_fft=model.n_fft,
        hop_length=model.hop_length,
        win_length=model.window_length,
        window=model.window.tolist(),
    )
    analysis = knf.Stft(config)
    spec = analysis(samples)

    num_frames = spec.num_frames
    real = np.array(spec.real, dtype=np.float32).reshape(num_frames, -1)
    imag = np.array(spec.imag, dtype=np.float32).reshape(num_frames, -1)

    states = model.get_init_states()
    frames = []
    for idx in range(num_frames):
        # Pack one frame as (1, n_fft/2+1, 1, 2): last axis = (real, imag).
        packed = np.stack([real[idx], imag[idx]], axis=-1)
        packed = packed[np.newaxis, :, np.newaxis, :]
        out, states = model(packed, states)
        frames.append(out)

    # (1, F, T, 2) -> (T, F, 2): one row of bins per frame.
    stacked = np.concatenate(frames, axis=2).squeeze(0).transpose(1, 0, 2)
    enhanced_real = stacked[..., 0]
    enhanced_imag = stacked[..., 1]

    enhanced_spec = knf.StftResult(
        real=enhanced_real.reshape(-1).tolist(),
        imag=enhanced_imag.reshape(-1).tolist(),
        num_frames=enhanced_real.shape[0],
    )
    synthesis = knf.IStft(config)
    enhanced = synthesis(enhanced_spec)
    sf.write("./enhanced_16k.wav", enhanced, model.sample_rate)
# Script entry point: run enhancement when executed directly.
if __name__ == "__main__":
    main()