offline-speech-denoiser-gtcrn-impl.h
4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
#include <algorithm>
#include <array>
#include <cmath>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/istft.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/resample.h"
namespace sherpa_onnx {
/// Offline speech denoiser driven by a GTCRN model.
///
/// Pipeline used by Run():
///   1. resample the input to the model's sample rate when it differs;
///   2. compute an STFT configured from the model metadata;
///   3. enhance the spectrum one frame at a time, carrying the model states
///      returned for frame i into frame i + 1;
///   4. convert the enhanced spectrum back to samples with an inverse STFT.
class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl {
 public:
  explicit OfflineSpeechDenoiserGtcrnImpl(
      const OfflineSpeechDenoiserConfig &config)
      : model_(config.model) {}

  /// Like the constructor above, but forwards `mgr` to the model constructor
  /// so the model can be loaded through that manager.
  template <typename Manager>
  OfflineSpeechDenoiserGtcrnImpl(Manager *mgr,
                                 const OfflineSpeechDenoiserConfig &config)
      : model_(mgr, config.model) {}

  /// Denoises `n` samples recorded at `sample_rate` Hz.
  ///
  /// @param samples     Input samples; may be null only when `n <= 0`.
  /// @param n           Number of input samples.
  /// @param sample_rate Sample rate of `samples`. The input is resampled to
  ///                    the model's sample rate when the two differ.
  /// @return Denoised audio at the model's sample rate. `samples` is empty
  ///         when there is nothing to process (empty input, or an empty
  ///         result after resampling / STFT analysis).
  DenoisedAudio Run(const float *samples, int32_t n,
                    int32_t sample_rate) const override {
    const auto &meta = model_.GetMetaData();

    DenoisedAudio denoised_audio;
    denoised_audio.sample_rate = meta.sample_rate;

    // Nothing to denoise; never hand an empty or null buffer to the STFT.
    if (samples == nullptr || n <= 0) {
      return denoised_audio;
    }

    std::vector<float> resampled;
    const float *input = samples;
    if (sample_rate != meta.sample_rate) {
      SHERPA_ONNX_LOGE(
          "Creating a resampler:\n"
          " in_sample_rate: %d\n"
          " output_sample_rate: %d\n",
          sample_rate, meta.sample_rate);

      const float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate);
      const float lowpass_cutoff = static_cast<float>(0.99 * 0.5 * min_freq);
      const int32_t lowpass_filter_width = 6;

      // The resampler is only needed for this call, so keep it on the stack.
      LinearResample resampler(sample_rate, meta.sample_rate, lowpass_cutoff,
                               lowpass_filter_width);
      resampler.Resample(samples, n, /*flush=*/true, &resampled);

      input = resampled.data();
      n = static_cast<int32_t>(resampled.size());
      if (n == 0) {
        return denoised_audio;
      }
    }

    knf::StftConfig stft_config;
    stft_config.n_fft = meta.n_fft;
    stft_config.hop_length = meta.hop_length;
    stft_config.win_length = meta.window_length;
    stft_config.window_type = meta.window_type;
    if (stft_config.window_type == "hann_sqrt") {
      // "hann_sqrt" is built here as the element-wise square root of a Hann
      // window and passed to knf as an explicit window.
      auto window = knf::GetWindow("hann", stft_config.win_length);
      for (auto &w : window) {
        w = std::sqrt(w);
      }
      stft_config.window = std::move(window);
    }

    knf::Stft stft(stft_config);
    const knf::StftResult stft_result = stft.Compute(input, n);
    if (stft_result.num_frames <= 0) {
      return denoised_audio;
    }

    const int32_t num_bins = meta.n_fft / 2 + 1;
    const size_t total_bins =
        static_cast<size_t>(stft_result.num_frames) * num_bins;

    knf::StftResult enhanced;
    enhanced.num_frames = stft_result.num_frames;
    enhanced.real.reserve(total_bins);
    enhanced.imag.reserve(total_bins);

    // Loop-invariant: one CPU memory descriptor serves every frame's tensor.
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    auto states = model_.GetInitStates();
    OfflineSpeechDenoiserGtcrnModel::States next_states;
    for (int32_t i = 0; i < stft_result.num_frames; ++i) {
      Process(stft_result, i, memory_info, std::move(states), &next_states,
              &enhanced);
      states = std::move(next_states);
    }

    knf::IStft istft(stft_config);
    denoised_audio.samples = istft.Compute(enhanced);
    return denoised_audio;
  }

  int32_t GetSampleRate() const override {
    return model_.GetMetaData().sample_rate;
  }

 private:
  /// Enhances frame `frame_index` of `stft_result` and appends the enhanced
  /// bins to `enhanced->real` / `enhanced->imag`.
  ///
  /// The frame's bins are packed as interleaved (real, imag) pairs into a
  /// tensor of shape {1, num_bins, 1, 2}, where num_bins = n_fft / 2 + 1.
  /// The model output is read back assuming the same interleaved layout
  /// (NOTE(review): output shape is not validated here).
  ///
  /// @param stft_result  Spectrum of the (resampled) input signal.
  /// @param frame_index  Frame to process; must be < stft_result.num_frames.
  /// @param memory_info  CPU memory descriptor used to wrap the input buffer.
  /// @param states       Model states from the previous frame (consumed).
  /// @param next_states  Receives the states to feed into the next frame.
  /// @param enhanced     Output spectrum; this frame's bins are appended.
  void Process(const knf::StftResult &stft_result, int32_t frame_index,
               const Ort::MemoryInfo &memory_info,
               OfflineSpeechDenoiserGtcrnModel::States states,
               OfflineSpeechDenoiserGtcrnModel::States *next_states,
               knf::StftResult *enhanced) const {
    const int32_t num_bins = model_.GetMetaData().n_fft / 2 + 1;
    const size_t offset = static_cast<size_t>(frame_index) * num_bins;
    const float *frame_real = stft_result.real.data() + offset;
    const float *frame_imag = stft_result.imag.data() + offset;

    // `x` must stay alive until model_.Run() returns: the tensor created
    // below wraps this buffer instead of copying it.
    std::vector<float> x(2 * static_cast<size_t>(num_bins));
    for (int32_t i = 0; i < num_bins; ++i) {
      x[2 * i] = frame_real[i];
      x[2 * i + 1] = frame_imag[i];
    }

    std::array<int64_t, 4> x_shape{1, num_bins, 1, 2};
    Ort::Value x_tensor = Ort::Value::CreateTensor(
        memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());

    Ort::Value output{nullptr};
    std::tie(output, *next_states) =
        model_.Run(std::move(x_tensor), std::move(states));

    const float *y = output.GetTensorData<float>();
    for (int32_t i = 0; i < num_bins; ++i) {
      enhanced->real.push_back(y[2 * i]);
      enhanced->imag.push_back(y[2 * i + 1]);
    }
  }

  OfflineSpeechDenoiserGtcrnModel model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_