Fangjun Kuang
Committed by GitHub

Add C++ runtime for speech enhancement GTCRN models (#1977)

See also https://github.com/Xiaobin-Rong/gtcrn
#!/usr/bin/env bash
#
# Smoke test for the offline speech denoiser binary.
#
# Downloads a GTCRN model and a noisy test wave, runs the denoiser on it and
# writes ./enhanced_speech_16k.wav. The binary is taken from $EXE when the
# caller exports it; otherwise the in-tree build output is used.

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# Quote the expansion so that the test is well-formed for any value of EXE.
if [ -z "${EXE:-}" ]; then
  EXE=./build/bin/sherpa-onnx-offline-denoiser
fi

echo "EXE is $EXE"
echo "PATH: $PATH"

# $EXE is intentionally left unquoted below, as in the original script, so a
# caller-supplied value keeps being word-split exactly as before.
which $EXE

log "------------------------------------------------------------"
log "Run gtcrn"
log "------------------------------------------------------------"
# --fail makes curl exit non-zero on an HTTP error instead of saving the
# server's error page under the model/wave file name.
curl --fail -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
curl --fail -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav

$EXE \
  --debug=1 \
  --speech-denoiser-gtcrn-model=./gtcrn_simple.onnx \
  --input-wav=./speech_with_noise.wav \
  --output-wav=./enhanced_speech_16k.wav

# Only the model is removed; the wave files are left in place.
rm ./gtcrn_simple.onnx
@@ -10,6 +10,7 @@ on: @@ -10,6 +10,7 @@ on:
10 - '.github/workflows/linux.yaml' 10 - '.github/workflows/linux.yaml'
11 - '.github/scripts/test-kws.sh' 11 - '.github/scripts/test-kws.sh'
12 - '.github/scripts/test-online-transducer.sh' 12 - '.github/scripts/test-online-transducer.sh'
  13 + - '.github/scripts/test-offline-speech-denoiser.sh'
13 - '.github/scripts/test-online-paraformer.sh' 14 - '.github/scripts/test-online-paraformer.sh'
14 - '.github/scripts/test-offline-transducer.sh' 15 - '.github/scripts/test-offline-transducer.sh'
15 - '.github/scripts/test-offline-ctc.sh' 16 - '.github/scripts/test-offline-ctc.sh'
@@ -31,6 +32,7 @@ on: @@ -31,6 +32,7 @@ on:
31 paths: 32 paths:
32 - '.github/workflows/linux.yaml' 33 - '.github/workflows/linux.yaml'
33 - '.github/scripts/test-kws.sh' 34 - '.github/scripts/test-kws.sh'
  35 + - '.github/scripts/test-offline-speech-denoiser.sh'
34 - '.github/scripts/test-online-transducer.sh' 36 - '.github/scripts/test-online-transducer.sh'
35 - '.github/scripts/test-online-paraformer.sh' 37 - '.github/scripts/test-online-paraformer.sh'
36 - '.github/scripts/test-offline-transducer.sh' 38 - '.github/scripts/test-offline-transducer.sh'
@@ -203,6 +205,15 @@ jobs: @@ -203,6 +205,15 @@ jobs:
203 overwrite: true 205 overwrite: true
204 file: sherpa-onnx-*.tar.bz2 206 file: sherpa-onnx-*.tar.bz2
205 207
  208 + - name: Test offline speech denoiser
  209 + shell: bash
  210 + run: |
  211 + du -h -d1 .
  212 + export PATH=$PWD/build/bin:$PATH
  213 + export EXE=sherpa-onnx-offline-denoiser
  214 +
  215 + .github/scripts/test-offline-speech-denoiser.sh
  216 +
206 - name: Test offline TTS 217 - name: Test offline TTS
207 if: matrix.with_tts == 'ON' 218 if: matrix.with_tts == 'ON'
208 shell: bash 219 shell: bash
@@ -215,6 +226,11 @@ jobs: @@ -215,6 +226,11 @@ jobs:
215 du -h -d1 . 226 du -h -d1 .
216 227
217 - uses: actions/upload-artifact@v4 228 - uses: actions/upload-artifact@v4
  229 + with:
  230 + name: speech-denoiser-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
  231 + path: ./*speech*.wav
  232 +
  233 + - uses: actions/upload-artifact@v4
218 if: matrix.with_tts == 'ON' 234 if: matrix.with_tts == 'ON'
219 with: 235 with:
220 name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 236 name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
@@ -7,6 +7,7 @@ on: @@ -7,6 +7,7 @@ on:
7 tags: 7 tags:
8 - 'v[0-9]+.[0-9]+.[0-9]+*' 8 - 'v[0-9]+.[0-9]+.[0-9]+*'
9 paths: 9 paths:
  10 + - '.github/scripts/test-offline-speech-denoiser.sh'
10 - '.github/workflows/macos.yaml' 11 - '.github/workflows/macos.yaml'
11 - '.github/scripts/test-kws.sh' 12 - '.github/scripts/test-kws.sh'
12 - '.github/scripts/test-online-transducer.sh' 13 - '.github/scripts/test-online-transducer.sh'
@@ -28,6 +29,7 @@ on: @@ -28,6 +29,7 @@ on:
28 branches: 29 branches:
29 - master 30 - master
30 paths: 31 paths:
  32 + - '.github/scripts/test-offline-speech-denoiser.sh'
31 - '.github/workflows/macos.yaml' 33 - '.github/workflows/macos.yaml'
32 - '.github/scripts/test-kws.sh' 34 - '.github/scripts/test-kws.sh'
33 - '.github/scripts/test-online-transducer.sh' 35 - '.github/scripts/test-online-transducer.sh'
@@ -160,6 +162,15 @@ jobs: @@ -160,6 +162,15 @@ jobs:
160 overwrite: true 162 overwrite: true
161 file: sherpa-onnx-*osx-universal2*.tar.bz2 163 file: sherpa-onnx-*osx-universal2*.tar.bz2
162 164
  165 + - name: Test offline speech denoiser
  166 + shell: bash
  167 + run: |
  168 + du -h -d1 .
  169 + export PATH=$PWD/build/bin:$PATH
  170 + export EXE=sherpa-onnx-offline-denoiser
  171 +
  172 + .github/scripts/test-offline-speech-denoiser.sh
  173 +
163 - name: Test offline TTS 174 - name: Test offline TTS
164 if: matrix.with_tts == 'ON' 175 if: matrix.with_tts == 'ON'
165 shell: bash 176 shell: bash
@@ -12,9 +12,9 @@ @@ -12,9 +12,9 @@
12 |--------------------------------|---------------|--------------------------| 12 |--------------------------------|---------------|--------------------------|
13 | ✔️ | ✔️ | ✔️ | 13 | ✔️ | ✔️ | ✔️ |
14 14
15 -| Keyword spotting | Add punctuation |  
16 -|------------------|-----------------|  
17 -| ✔️ | ✔️ | 15 +| Keyword spotting | Add punctuation | Speech enhancement |
  16 +|------------------|-----------------|--------------------|
  17 +| ✔️ | ✔️ | ✔️ |
18 18
19 ### Supported platforms 19 ### Supported platforms
20 20
@@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below: @@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below:
198 | Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]| 198 | Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]|
199 | Punctuation | [Address][punct-models] | 199 | Punctuation | [Address][punct-models] |
200 | Speaker segmentation | [Address][speaker-segmentation-models] | 200 | Speaker segmentation | [Address][speaker-segmentation-models] |
  201 +| Speech enhancement | [Address][speech-enhancement-models] |
201 202
202 </details> 203 </details>
203 204
@@ -442,3 +443,4 @@ sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss @@ -442,3 +443,4 @@ sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss
442 [Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 443 [Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
443 [NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9 444 [NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9
444 [NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/ 445 [NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/
  446 +[speech-enhancement-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
1 function(download_kaldi_native_fbank) 1 function(download_kaldi_native_fbank)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")  
5 - set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.20.0.tar.gz")  
6 - set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9") 4 + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.1.tar.gz")
  5 + set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.1.tar.gz")
  6 + set(kaldi_native_fbank_HASH "SHA256=37c1aa230b00fe062791d800d8fc50aa3de215918d3dce6440699e67275d859e")
7 7
8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) 8 set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 9 set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
12 # If you don't have access to the Internet, 12 # If you don't have access to the Internet,
13 # please pre-download kaldi-native-fbank 13 # please pre-download kaldi-native-fbank
14 set(possible_file_locations 14 set(possible_file_locations
15 - $ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz  
16 - ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz  
17 - ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz  
18 - /tmp/kaldi-native-fbank-1.20.0.tar.gz  
19 - /star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz 15 + $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.1.tar.gz
  16 + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.1.tar.gz
  17 + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.1.tar.gz
  18 + /tmp/kaldi-native-fbank-1.21.1.tar.gz
  19 + /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.1.tar.gz
20 ) 20 )
21 21
22 foreach(f IN LISTS possible_file_locations) 22 foreach(f IN LISTS possible_file_locations)
@@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS) @@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS)
186 ) 186 )
187 endif() 187 endif()
188 188
  189 +list(APPEND sources
  190 + offline-speech-denoiser-gtcrn-model-config.cc
  191 + offline-speech-denoiser-gtcrn-model.cc
  192 + offline-speech-denoiser-impl.cc
  193 + offline-speech-denoiser-model-config.cc
  194 + offline-speech-denoiser.cc
  195 +)
  196 +
189 if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) 197 if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
190 list(APPEND sources 198 list(APPEND sources
191 fast-clustering-config.cc 199 fast-clustering-config.cc
@@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
301 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) 309 add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
302 add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) 310 add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
303 add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) 311 add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
  312 + add_executable(sherpa-onnx-offline-denoiser sherpa-onnx-offline-denoiser.cc)
304 313
305 if(SHERPA_ONNX_ENABLE_TTS) 314 if(SHERPA_ONNX_ENABLE_TTS)
306 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) 315 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
@@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
318 sherpa-onnx-offline-language-identification 327 sherpa-onnx-offline-language-identification
319 sherpa-onnx-offline-parallel 328 sherpa-onnx-offline-parallel
320 sherpa-onnx-offline-punctuation 329 sherpa-onnx-offline-punctuation
  330 + sherpa-onnx-offline-denoiser
321 sherpa-onnx-online-punctuation 331 sherpa-onnx-online-punctuation
322 ) 332 )
323 if(SHERPA_ONNX_ENABLE_TTS) 333 if(SHERPA_ONNX_ENABLE_TTS)
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <memory>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "kaldi-native-fbank/csrc/feature-window.h"
  14 +#include "kaldi-native-fbank/csrc/istft.h"
  15 +#include "kaldi-native-fbank/csrc/stft.h"
  16 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
  17 +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
  18 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  19 +#include "sherpa-onnx/csrc/resample.h"
  20 +
  21 +namespace sherpa_onnx {
  22 +
  23 +class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl {
  24 + public:
  25 + explicit OfflineSpeechDenoiserGtcrnImpl(
  26 + const OfflineSpeechDenoiserConfig &config)
  27 + : model_(config.model) {}
  28 +
  29 + template <typename Manager>
  30 + OfflineSpeechDenoiserGtcrnImpl(Manager *mgr,
  31 + const OfflineSpeechDenoiserConfig &config)
  32 + : model_(mgr, config.model) {}
  33 +
  34 + DenoisedAudio Run(const float *samples, int32_t n,
  35 + int32_t sample_rate) const override {
  36 + SHERPA_ONNX_LOGE("n: %d, sample_rate: %d", n, sample_rate);
  37 + const auto &meta = model_.GetMetaData();
  38 +
  39 + std::vector<float> tmp;
  40 + auto p = samples;
  41 +
  42 + if (sample_rate != meta.sample_rate) {
  43 + SHERPA_ONNX_LOGE(
  44 + "Creating a resampler:\n"
  45 + " in_sample_rate: %d\n"
  46 + " output_sample_rate: %d\n",
  47 + sample_rate, meta.sample_rate);
  48 +
  49 + float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate);
  50 + float lowpass_cutoff = 0.99 * 0.5 * min_freq;
  51 +
  52 + int32_t lowpass_filter_width = 6;
  53 + auto resampler = std::make_unique<LinearResample>(
  54 + sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width);
  55 + resampler->Resample(samples, n, true, &tmp);
  56 + p = tmp.data();
  57 + n = tmp.size();
  58 + }
  59 +
  60 + knf::StftConfig stft_config;
  61 + stft_config.n_fft = meta.n_fft;
  62 + stft_config.hop_length = meta.hop_length;
  63 + stft_config.win_length = meta.window_length;
  64 + stft_config.window_type = meta.window_type;
  65 + if (stft_config.window_type == "hann_sqrt") {
  66 + auto window = knf::GetWindow("hann", stft_config.win_length);
  67 + for (auto &w : window) {
  68 + w = std::sqrt(w);
  69 + }
  70 + stft_config.window = std::move(window);
  71 + }
  72 +
  73 + knf::Stft stft(stft_config);
  74 + knf::StftResult stft_result = stft.Compute(p, n);
  75 +
  76 + auto states = model_.GetInitStates();
  77 + OfflineSpeechDenoiserGtcrnModel::States next_states;
  78 +
  79 + knf::StftResult enhanced_stft_result;
  80 + enhanced_stft_result.num_frames = stft_result.num_frames;
  81 + for (int32_t i = 0; i < stft_result.num_frames; ++i) {
  82 + auto p = Process(stft_result, i, std::move(states), &next_states);
  83 + states = std::move(next_states);
  84 +
  85 + enhanced_stft_result.real.insert(enhanced_stft_result.real.end(),
  86 + p.first.begin(), p.first.end());
  87 + enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(),
  88 + p.second.begin(), p.second.end());
  89 + }
  90 +
  91 + knf::IStft istft(stft_config);
  92 +
  93 + DenoisedAudio denoised_audio;
  94 + denoised_audio.sample_rate = meta.sample_rate;
  95 + denoised_audio.samples = istft.Compute(enhanced_stft_result);
  96 + return denoised_audio;
  97 + }
  98 +
  99 + int32_t GetSampleRate() const override {
  100 + return model_.GetMetaData().sample_rate;
  101 + }
  102 +
  103 + private:
  104 + std::pair<std::vector<float>, std::vector<float>> Process(
  105 + const knf::StftResult &stft_result, int32_t frame_index,
  106 + OfflineSpeechDenoiserGtcrnModel::States states,
  107 + OfflineSpeechDenoiserGtcrnModel::States *next_states) const {
  108 + const auto &meta = model_.GetMetaData();
  109 + int32_t n_fft = meta.n_fft;
  110 + std::vector<float> x((n_fft / 2 + 1) * 2);
  111 +
  112 + const float *p_real =
  113 + stft_result.real.data() + frame_index * (n_fft / 2 + 1);
  114 + const float *p_imag =
  115 + stft_result.imag.data() + frame_index * (n_fft / 2 + 1);
  116 +
  117 + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) {
  118 + x[2 * i] = p_real[i];
  119 + x[2 * i + 1] = p_imag[i];
  120 + }
  121 + auto memory_info =
  122 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  123 +
  124 + std::array<int64_t, 4> x_shape{1, n_fft / 2 + 1, 1, 2};
  125 + Ort::Value x_tensor = Ort::Value::CreateTensor(
  126 + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
  127 +
  128 + Ort::Value output{nullptr};
  129 + std::tie(output, *next_states) =
  130 + model_.Run(std::move(x_tensor), std::move(states));
  131 +
  132 + std::vector<float> real(n_fft / 2 + 1);
  133 + std::vector<float> imag(n_fft / 2 + 1);
  134 + const auto *p = output.GetTensorData<float>();
  135 + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) {
  136 + real[i] = p[2 * i];
  137 + imag[i] = p[2 * i + 1];
  138 + }
  139 +
  140 + return {std::move(real), std::move(imag)};
  141 + }
  142 +
  143 + private:
  144 + OfflineSpeechDenoiserGtcrnModel model_;
  145 +};
  146 +
  147 +} // namespace sherpa_onnx
  148 +
  149 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void OfflineSpeechDenoiserGtcrnModelConfig::Register(ParseOptions *po) {
  15 + po->Register("speech-denoiser-gtcrn-model", &model,
  16 + "Path to the gtcrn model for speech denoising");
  17 +}
  18 +
  19 +bool OfflineSpeechDenoiserGtcrnModelConfig::Validate() const {
  20 + if (model.empty()) {
  21 + SHERPA_ONNX_LOGE("Please provide --speech-denoiser-gtcrn-model");
  22 + return false;
  23 + }
  24 +
  25 + if (!FileExists(model)) {
  26 + SHERPA_ONNX_LOGE("gtcrn model file '%s' does not exist", model.c_str());
  27 + return false;
  28 + }
  29 + return true;
  30 +}
  31 +
  32 +std::string OfflineSpeechDenoiserGtcrnModelConfig::ToString() const {
  33 + std::ostringstream os;
  34 +
  35 + os << "OfflineSpeechDenoiserGtcrnModelConfig(";
  36 + os << "model=\"" << model << "\")";
  37 + return os.str();
  38 +}
  39 +
  40 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineSpeechDenoiserGtcrnModelConfig {
  14 + std::string model;
  15 +
  16 + void Register(ParseOptions *po);
  17 + bool Validate() const;
  18 +
  19 + std::string ToString() const;
  20 +};
  21 +
  22 +} // namespace sherpa_onnx
  23 +
  24 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
  7 +
  8 +#include <cstdint>
  9 +#include <string>
  10 +#include <vector>
  11 +
  12 +namespace sherpa_onnx {
  13 +
// please refer to
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
// NOTE(review): the link above points at the kokoro script; confirm whether a
// gtcrn-specific export script is the intended reference.
//
// Values describing a GTCRN model. All fields are plain data with the
// defaults given below.
struct OfflineSpeechDenoiserGtcrnModelMetaData {
  int32_t sample_rate{0};
  int32_t version{1};

  // STFT parameters.
  int32_t n_fft{0};
  int32_t hop_length{0};
  int32_t window_length{0};
  std::string window_type{};

  // Shapes of the three cache tensors.
  std::vector<int64_t> conv_cache_shape{};
  std::vector<int64_t> tra_cache_shape{};
  std::vector<int64_t> inter_cache_shape{};
};
  28 +
  29 +} // namespace sherpa_onnx
  30 +
  31 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
  6 +
#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
  19 +class OfflineSpeechDenoiserGtcrnModel::Impl {
  20 + public:
  21 + explicit Impl(const OfflineSpeechDenoiserModelConfig &config)
  22 + : config_(config),
  23 + env_(ORT_LOGGING_LEVEL_ERROR),
  24 + sess_opts_(GetSessionOptions(config)),
  25 + allocator_{} {
  26 + {
  27 + auto buf = ReadFile(config.gtcrn.model);
  28 + Init(buf.data(), buf.size());
  29 + }
  30 + }
  31 +
  32 + template <typename Manager>
  33 + Impl(Manager *mgr, const OfflineSpeechDenoiserModelConfig &config)
  34 + : config_(config),
  35 + env_(ORT_LOGGING_LEVEL_ERROR),
  36 + sess_opts_(GetSessionOptions(config)),
  37 + allocator_{} {
  38 + {
  39 + auto buf = ReadFile(mgr, config.gtcrn.model);
  40 + Init(buf.data(), buf.size());
  41 + }
  42 + }
  43 +
  44 + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const {
  45 + return meta_;
  46 + }
  47 +
  48 + States GetInitStates() const {
  49 + Ort::Value conv_cache = Ort::Value::CreateTensor<float>(
  50 + allocator_, meta_.conv_cache_shape.data(),
  51 + meta_.conv_cache_shape.size());
  52 +
  53 + Ort::Value tra_cache = Ort::Value::CreateTensor<float>(
  54 + allocator_, meta_.tra_cache_shape.data(), meta_.tra_cache_shape.size());
  55 +
  56 + Ort::Value inter_cache = Ort::Value::CreateTensor<float>(
  57 + allocator_, meta_.inter_cache_shape.data(),
  58 + meta_.inter_cache_shape.size());
  59 +
  60 + Fill<float>(&conv_cache, 0);
  61 + Fill<float>(&tra_cache, 0);
  62 + Fill<float>(&inter_cache, 0);
  63 +
  64 + std::vector<Ort::Value> states;
  65 +
  66 + states.reserve(3);
  67 + states.push_back(std::move(conv_cache));
  68 + states.push_back(std::move(tra_cache));
  69 + states.push_back(std::move(inter_cache));
  70 +
  71 + return states;
  72 + }
  73 +
  74 + std::pair<Ort::Value, States> Run(Ort::Value x, States states) const {
  75 + std::vector<Ort::Value> inputs;
  76 + inputs.reserve(1 + states.size());
  77 + inputs.push_back(std::move(x));
  78 + for (auto &s : states) {
  79 + inputs.push_back(std::move(s));
  80 + }
  81 +
  82 + auto out =
  83 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  84 + output_names_ptr_.data(), output_names_ptr_.size());
  85 +
  86 + std::vector<Ort::Value> next_states;
  87 + next_states.reserve(out.size() - 1);
  88 + for (int32_t k = 1; k < out.size(); ++k) {
  89 + next_states.push_back(std::move(out[k]));
  90 + }
  91 +
  92 + return {std::move(out[0]), std::move(next_states)};
  93 + }
  94 +
  95 + private:
  96 + void Init(void *model_data, size_t model_data_length) {
  97 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  98 + sess_opts_);
  99 +
  100 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  101 +
  102 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  103 +
  104 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  105 + if (config_.debug) {
  106 + std::ostringstream os;
  107 + os << "---gtcrn model---\n";
  108 + PrintModelMetadata(os, meta_data);
  109 +
  110 + os << "----------input names----------\n";
  111 + int32_t i = 0;
  112 + for (const auto &s : input_names_) {
  113 + os << i << " " << s << "\n";
  114 + ++i;
  115 + }
  116 + os << "----------output names----------\n";
  117 + i = 0;
  118 + for (const auto &s : output_names_) {
  119 + os << i << " " << s << "\n";
  120 + ++i;
  121 + }
  122 +
  123 +#if __OHOS__
  124 + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
  125 +#else
  126 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  127 +#endif
  128 + }
  129 +
  130 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  131 +
  132 + std::string model_type;
  133 + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
  134 + if (model_type != "gtcrn") {
  135 + SHERPA_ONNX_LOGE("Expect model type 'gtcrn'. Given: '%s'",
  136 + model_type.c_str());
  137 + SHERPA_ONNX_EXIT(-1);
  138 + }
  139 +
  140 + SHERPA_ONNX_READ_META_DATA(meta_.sample_rate, "sample_rate");
  141 + SHERPA_ONNX_READ_META_DATA(meta_.n_fft, "n_fft");
  142 + SHERPA_ONNX_READ_META_DATA(meta_.hop_length, "hop_length");
  143 + SHERPA_ONNX_READ_META_DATA(meta_.window_length, "window_length");
  144 + SHERPA_ONNX_READ_META_DATA_STR(meta_.window_type, "window_type");
  145 + SHERPA_ONNX_READ_META_DATA(meta_.version, "version");
  146 +
  147 + SHERPA_ONNX_READ_META_DATA_VEC(meta_.conv_cache_shape, "conv_cache_shape");
  148 + SHERPA_ONNX_READ_META_DATA_VEC(meta_.tra_cache_shape, "tra_cache_shape");
  149 + SHERPA_ONNX_READ_META_DATA_VEC(meta_.inter_cache_shape,
  150 + "inter_cache_shape");
  151 + }
  152 +
  153 + private:
  154 + OfflineSpeechDenoiserModelConfig config_;
  155 + OfflineSpeechDenoiserGtcrnModelMetaData meta_;
  156 +
  157 + Ort::Env env_;
  158 + Ort::SessionOptions sess_opts_;
  159 + Ort::AllocatorWithDefaultOptions allocator_;
  160 +
  161 + std::unique_ptr<Ort::Session> sess_;
  162 +
  163 + std::vector<std::string> input_names_;
  164 + std::vector<const char *> input_names_ptr_;
  165 +
  166 + std::vector<std::string> output_names_;
  167 + std::vector<const char *> output_names_ptr_;
  168 +};
  169 +
  170 +OfflineSpeechDenoiserGtcrnModel::~OfflineSpeechDenoiserGtcrnModel() = default;
  171 +
  172 +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel(
  173 + const OfflineSpeechDenoiserModelConfig &config)
  174 + : impl_(std::make_unique<Impl>(config)) {}
  175 +
  176 +template <typename Manager>
  177 +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel(
  178 + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config)
  179 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  180 +
  181 +OfflineSpeechDenoiserGtcrnModel::States
  182 +OfflineSpeechDenoiserGtcrnModel::GetInitStates() const {
  183 + return impl_->GetInitStates();
  184 +}
  185 +
  186 +std::pair<Ort::Value, OfflineSpeechDenoiserGtcrnModel::States>
  187 +OfflineSpeechDenoiserGtcrnModel::Run(Ort::Value x, States states) const {
  188 + return impl_->Run(std::move(x), std::move(states));
  189 +}
  190 +
  191 +const OfflineSpeechDenoiserGtcrnModelMetaData &
  192 +OfflineSpeechDenoiserGtcrnModel::GetMetaData() const {
  193 + return impl_->GetMetaData();
  194 +}
  195 +
  196 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
  6 +#include <memory>
  7 +#include <utility>
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h"
  12 +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
  13 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OfflineSpeechDenoiserGtcrnModel {
  18 + public:
  19 + ~OfflineSpeechDenoiserGtcrnModel();
  20 + explicit OfflineSpeechDenoiserGtcrnModel(
  21 + const OfflineSpeechDenoiserModelConfig &config);
  22 +
  23 + template <typename Manager>
  24 + OfflineSpeechDenoiserGtcrnModel(
  25 + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config);
  26 +
  27 + using States = std::vector<Ort::Value>;
  28 +
  29 + States GetInitStates() const;
  30 +
  31 + std::pair<Ort::Value, States> Run(Ort::Value x, States states) const;
  32 +
  33 + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const;
  34 +
  35 + private:
  36 + class Impl;
  37 + std::unique_ptr<Impl> impl_;
  38 +};
  39 +
  40 +} // namespace sherpa_onnx
  41 +
  42 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-impl.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
  5 +
  6 +#include <memory>
  7 +
  8 +#if __ANDROID_API__ >= 9
  9 +#include "android/asset_manager.h"
  10 +#include "android/asset_manager_jni.h"
  11 +#endif
  12 +
  13 +#if __OHOS__
  14 +#include "rawfile/raw_file_manager.h"
  15 +#endif
  16 +
  17 +#include "sherpa-onnx/csrc/macros.h"
  18 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create(
  23 + const OfflineSpeechDenoiserConfig &config) {
  24 + if (!config.model.gtcrn.model.empty()) {
  25 + return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(config);
  26 + }
  27 + SHERPA_ONNX_LOGE("Please provide a speech denoising model.");
  28 + return nullptr;
  29 +}
  30 +
  31 +template <typename Manager>
  32 +std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create(
  33 + Manager *mgr, const OfflineSpeechDenoiserConfig &config) {
  34 + if (!config.model.gtcrn.model.empty()) {
  35 + return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(mgr, config);
  36 + }
  37 + SHERPA_ONNX_LOGE("Please provide a speech denoising model.");
  38 + return nullptr;
  39 +}
  40 +
  41 +#if __ANDROID_API__ >= 9
  42 +template std::unique_ptr<OfflineSpeechDenoiserImpl>
  43 +OfflineSpeechDenoiserImpl::Create(AAssetManager *mgr,
  44 + const OfflineSpeechDenoiserConfig &config);
  45 +#endif
  46 +
  47 +#if __OHOS__
  48 +template std::unique_ptr<OfflineSpeechDenoiserImpl>
  49 +OfflineSpeechDenoiserImpl::Create(NativeResourceManager *mgr,
  50 + const OfflineSpeechDenoiserConfig &config);
  51 +#endif
  52 +
  53 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-impl.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
  7 +
  8 +#include <memory>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineSpeechDenoiserImpl {
  15 + public:
  16 + virtual ~OfflineSpeechDenoiserImpl() = default;
  17 +
  18 + static std::unique_ptr<OfflineSpeechDenoiserImpl> Create(
  19 + const OfflineSpeechDenoiserConfig &config);
  20 +
  21 + template <typename Manager>
  22 + static std::unique_ptr<OfflineSpeechDenoiserImpl> Create(
  23 + Manager *mgr, const OfflineSpeechDenoiserConfig &config);
  24 +
  25 + virtual DenoisedAudio Run(const float *samples, int32_t n,
  26 + int32_t sample_rate) const = 0;
  27 +
  28 + virtual int32_t GetSampleRate() const = 0;
  29 +};
  30 +
  31 +} // namespace sherpa_onnx
  32 +
  33 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +void OfflineSpeechDenoiserModelConfig::Register(ParseOptions *po) {
  12 + gtcrn.Register(po);
  13 +
  14 + po->Register("num-threads", &num_threads,
  15 + "Number of threads to run the neural network");
  16 +
  17 + po->Register("debug", &debug,
  18 + "true to print model information while loading it.");
  19 +
  20 + po->Register("provider", &provider,
  21 + "Specify a provider to use: cpu, cuda, coreml");
  22 +}
  23 +
  24 +bool OfflineSpeechDenoiserModelConfig::Validate() const {
  25 + return gtcrn.Validate();
  26 +}
  27 +
  28 +std::string OfflineSpeechDenoiserModelConfig::ToString() const {
  29 + std::ostringstream os;
  30 +
  31 + os << "OfflineSpeechDenoiserModelConfig(";
  32 + os << "gtcrn=" << gtcrn.ToString() << ", ";
  33 + os << "num_threads=" << num_threads << ", ";
  34 + os << "debug=" << (debug ? "True" : "False") << ", ";
  35 + os << "provider=\"" << provider << "\")";
  36 +
  37 + return os.str();
  38 +}
  39 +
  40 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speech-denoiser-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineSpeechDenoiserModelConfig {
  15 + OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
  16 +
  17 + int32_t num_threads = 1;
  18 + bool debug = false;
  19 + std::string provider = "cpu";
  20 +
  21 + OfflineSpeechDenoiserModelConfig() = default;
  22 +
  23 + OfflineSpeechDenoiserModelConfig(OfflineSpeechDenoiserGtcrnModelConfig gtcrn,
  24 + int32_t num_threads, bool debug,
  25 + const std::string &provider)
  26 + : gtcrn(gtcrn),
  27 + num_threads(num_threads),
  28 + debug(debug),
  29 + provider(provider) {}
  30 +
  31 + void Register(ParseOptions *po);
  32 + bool Validate() const;
  33 +
  34 + std::string ToString() const;
  35 +};
  36 +
  37 +} // namespace sherpa_onnx
  38 +
  39 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-speech-denoiser.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  6 +
  7 +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
  8 +
  9 +#if __ANDROID_API__ >= 9
  10 +#include "android/asset_manager.h"
  11 +#include "android/asset_manager_jni.h"
  12 +#endif
  13 +
  14 +#if __OHOS__
  15 +#include "rawfile/raw_file_manager.h"
  16 +#endif
  17 +
  18 +namespace sherpa_onnx {
  19 +
  20 +void OfflineSpeechDenoiserConfig::Register(ParseOptions *po) {
  21 + model.Register(po);
  22 +}
  23 +
  24 +bool OfflineSpeechDenoiserConfig::Validate() const { return model.Validate(); }
  25 +
  26 +std::string OfflineSpeechDenoiserConfig::ToString() const {
  27 + std::ostringstream os;
  28 +
  29 + os << "OfflineSpeechDenoiserConfig(";
  30 + os << "model=" << model.ToString() << ")";
  31 + return os.str();
  32 +}
  33 +
  34 +template <typename Manager>
  35 +OfflineSpeechDenoiser::OfflineSpeechDenoiser(
  36 + Manager *mgr, const OfflineSpeechDenoiserConfig &config)
  37 + : impl_(OfflineSpeechDenoiserImpl::Create(mgr, config)) {}
  38 +
  39 +OfflineSpeechDenoiser::OfflineSpeechDenoiser(
  40 + const OfflineSpeechDenoiserConfig &config)
  41 + : impl_(OfflineSpeechDenoiserImpl::Create(config)) {}
  42 +
  43 +OfflineSpeechDenoiser::~OfflineSpeechDenoiser() = default;
  44 +
  45 +DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
  46 + int32_t sample_rate) const {
  47 + return impl_->Run(samples, n, sample_rate);
  48 +}
  49 +
  50 +int32_t OfflineSpeechDenoiser::GetSampleRate() const {
  51 + return impl_->GetSampleRate();
  52 +}
  53 +
// The templated constructor is defined in this translation unit only, so it
// is instantiated explicitly for every supported resource-manager type.

// Android
#if __ANDROID_API__ >= 9
template OfflineSpeechDenoiser::OfflineSpeechDenoiser<AAssetManager>(
    AAssetManager *mgr, const OfflineSpeechDenoiserConfig &config);
#endif

// OHOS
#if __OHOS__
template OfflineSpeechDenoiser::OfflineSpeechDenoiser<NativeResourceManager>(
    NativeResourceManager *mgr, const OfflineSpeechDenoiserConfig &config);
#endif
  63 +
  64 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speech-denoiser.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
  12 +#include "sherpa-onnx/csrc/parse-options.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
// Result of a denoising run.
struct DenoisedAudio {
  // Audio samples produced by the denoiser.
  std::vector<float> samples;

  // Sample rate of `samples`. Defaults to 0 so that a default-constructed
  // object never carries an indeterminate value.
  int32_t sample_rate = 0;
};
  20 +
  21 +struct OfflineSpeechDenoiserConfig {
  22 + OfflineSpeechDenoiserModelConfig model;
  23 +
  24 + void Register(ParseOptions *po);
  25 + bool Validate() const;
  26 +
  27 + std::string ToString() const;
  28 +};
  29 +
  30 +class OfflineSpeechDenoiserImpl;
  31 +
  32 +class OfflineSpeechDenoiser {
  33 + public:
  34 + explicit OfflineSpeechDenoiser(const OfflineSpeechDenoiserConfig &config);
  35 + ~OfflineSpeechDenoiser();
  36 +
  37 + template <typename Manager>
  38 + OfflineSpeechDenoiser(Manager *mgr,
  39 + const OfflineSpeechDenoiserConfig &config);
  40 +
  41 + /*
  42 + * @param samples 1-D array of audio samples. Each sample is in the
  43 + * range [-1, 1].
  44 + * @param n Number of samples
  45 + * @param sample_rate Sample rate of the input samples
  46 + *
  47 + */
  48 + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
  49 +
  50 + /*
  51 + * Return the sample rate of the denoised audio
  52 + */
  53 + int32_t GetSampleRate() const;
  54 +
  55 + private:
  56 + std::unique_ptr<OfflineSpeechDenoiserImpl> impl_;
  57 +};
  58 +
  59 +} // namespace sherpa_onnx
  60 +
  61 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
1 -// sherpa-onnx/csrc/offline-tts-kokoro-model-metadata.h 1 +// sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h
2 // 2 //
3 // Copyright (c) 2025 Xiaomi Corporation 3 // Copyright (c) 2025 Xiaomi Corporation
4 4
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +#include <stdio.h>
  5 +
  6 +#include <chrono> // NOLINT
  7 +
  8 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  9 +#include "sherpa-onnx/csrc/wave-reader.h"
  10 +#include "sherpa-onnx/csrc/wave-writer.h"
  11 +
  12 +int main(int32_t argc, char *argv[]) {
  13 + const char *kUsageMessage = R"usage(
  14 +Non-stremaing speech denoising with sherpa-onnx.
  15 +
  16 +Please visit
  17 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
  18 +to download models.
  19 +
  20 +Usage:
  21 +
  22 +(1) Use gtcrn models
  23 +
  24 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
  25 +./bin/sherpa-onnx-offline-denoiser \
  26 + --speech-denoiser-gtcrn-model=gtcrn_simple.onnx \
  27 + --input-wav input.wav \
  28 + --output-wav output_16k.wav
  29 +)usage";
  30 +
  31 + sherpa_onnx::ParseOptions po(kUsageMessage);
  32 + sherpa_onnx::OfflineSpeechDenoiserConfig config;
  33 + std::string input_wave;
  34 + std::string output_wave;
  35 +
  36 + config.Register(&po);
  37 + po.Register("input-wav", &input_wave, "Path to input wav.");
  38 + po.Register("output-wav", &output_wave, "Path to output wav");
  39 +
  40 + po.Read(argc, argv);
  41 + if (po.NumArgs() != 0) {
  42 + fprintf(stderr, "Please don't give positional arguments\n");
  43 + po.PrintUsage();
  44 + exit(EXIT_FAILURE);
  45 + }
  46 + fprintf(stderr, "%s\n", config.ToString().c_str());
  47 +
  48 + if (input_wave.empty()) {
  49 + fprintf(stderr, "Please provide --input-wav\n");
  50 + po.PrintUsage();
  51 + exit(EXIT_FAILURE);
  52 + }
  53 +
  54 + if (output_wave.empty()) {
  55 + fprintf(stderr, "Please provide --output-wav\n");
  56 + po.PrintUsage();
  57 + exit(EXIT_FAILURE);
  58 + }
  59 +
  60 + sherpa_onnx::OfflineSpeechDenoiser denoiser(config);
  61 + int32_t sampling_rate = -1;
  62 + bool is_ok = false;
  63 + std::vector<float> samples =
  64 + sherpa_onnx::ReadWave(input_wave, &sampling_rate, &is_ok);
  65 + if (!is_ok) {
  66 + fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
  67 + return -1;
  68 + }
  69 +
  70 + fprintf(stderr, "Started\n");
  71 + const auto begin = std::chrono::steady_clock::now();
  72 + auto result = denoiser.Run(samples.data(), samples.size(), sampling_rate);
  73 + const auto end = std::chrono::steady_clock::now();
  74 +
  75 + float elapsed_seconds =
  76 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  77 + .count() /
  78 + 1000.;
  79 +
  80 + fprintf(stderr, "Done\n");
  81 + is_ok = sherpa_onnx::WriteWave(output_wave, result.sample_rate,
  82 + result.samples.data(), result.samples.size());
  83 + if (is_ok) {
  84 + fprintf(stderr, "Saved to %s\n", output_wave.c_str());
  85 + } else {
  86 + fprintf(stderr, "Failed to save to %s\n", output_wave.c_str());
  87 + }
  88 +
  89 + float duration = samples.size() / static_cast<float>(sampling_rate);
  90 + fprintf(stderr, "num threads: %d\n", config.model.num_threads);
  91 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  92 + float rtf = elapsed_seconds / duration;
  93 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  94 + elapsed_seconds, duration, rtf);
  95 +}