Committed by
GitHub
Add C++ runtime for speech enhancement GTCRN models (#1977)
See also https://github.com/Xiaobin-Rong/gtcrn
正在显示
20 个修改的文件
包含
950 行增加
和
12 行删除
#!/usr/bin/env bash
#
# Smoke test for the offline speech denoiser binary.
#
# Downloads a GTCRN model and a noisy test wave, runs the denoiser on the
# wave, and finally removes the downloaded model.
#
# Environment:
#   EXE  Denoiser executable to run. Defaults to
#        ./build/bin/sherpa-onnx-offline-denoiser when unset or empty.

set -ex

# Print a timestamped message tagged with the caller's file, line and function.
log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

# The expansion is quoted on purpose: an unquoted $EXE that contains
# whitespace turns the test into a malformed expression.
if [ -z "$EXE" ]; then
  EXE=./build/bin/sherpa-onnx-offline-denoiser
fi

echo "EXE is $EXE"
echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run gtcrn"
log "------------------------------------------------------------"
# -f makes curl exit non-zero on an HTTP error, so that (with `set -e`) the
# script stops here instead of saving an error page as the model/wave file.
curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav

$EXE \
  --debug=1 \
  --speech-denoiser-gtcrn-model=./gtcrn_simple.onnx \
  --input-wav=./speech_with_noise.wav \
  --output-wav=./enhanced_speech_16k.wav

# Only the model is removed; the input and output wave files are left in the
# working directory.
rm ./gtcrn_simple.onnx
| @@ -10,6 +10,7 @@ on: | @@ -10,6 +10,7 @@ on: | ||
| 10 | - '.github/workflows/linux.yaml' | 10 | - '.github/workflows/linux.yaml' |
| 11 | - '.github/scripts/test-kws.sh' | 11 | - '.github/scripts/test-kws.sh' |
| 12 | - '.github/scripts/test-online-transducer.sh' | 12 | - '.github/scripts/test-online-transducer.sh' |
| 13 | + - '.github/scripts/test-offline-speech-denoiser.sh' | ||
| 13 | - '.github/scripts/test-online-paraformer.sh' | 14 | - '.github/scripts/test-online-paraformer.sh' |
| 14 | - '.github/scripts/test-offline-transducer.sh' | 15 | - '.github/scripts/test-offline-transducer.sh' |
| 15 | - '.github/scripts/test-offline-ctc.sh' | 16 | - '.github/scripts/test-offline-ctc.sh' |
| @@ -31,6 +32,7 @@ on: | @@ -31,6 +32,7 @@ on: | ||
| 31 | paths: | 32 | paths: |
| 32 | - '.github/workflows/linux.yaml' | 33 | - '.github/workflows/linux.yaml' |
| 33 | - '.github/scripts/test-kws.sh' | 34 | - '.github/scripts/test-kws.sh' |
| 35 | + - '.github/scripts/test-offline-speech-denoiser.sh' | ||
| 34 | - '.github/scripts/test-online-transducer.sh' | 36 | - '.github/scripts/test-online-transducer.sh' |
| 35 | - '.github/scripts/test-online-paraformer.sh' | 37 | - '.github/scripts/test-online-paraformer.sh' |
| 36 | - '.github/scripts/test-offline-transducer.sh' | 38 | - '.github/scripts/test-offline-transducer.sh' |
| @@ -203,6 +205,15 @@ jobs: | @@ -203,6 +205,15 @@ jobs: | ||
| 203 | overwrite: true | 205 | overwrite: true |
| 204 | file: sherpa-onnx-*.tar.bz2 | 206 | file: sherpa-onnx-*.tar.bz2 |
| 205 | 207 | ||
| 208 | + - name: Test offline speech denoiser | ||
| 209 | + shell: bash | ||
| 210 | + run: | | ||
| 211 | + du -h -d1 . | ||
| 212 | + export PATH=$PWD/build/bin:$PATH | ||
| 213 | + export EXE=sherpa-onnx-offline-denoiser | ||
| 214 | + | ||
| 215 | + .github/scripts/test-offline-speech-denoiser.sh | ||
| 216 | + | ||
| 206 | - name: Test offline TTS | 217 | - name: Test offline TTS |
| 207 | if: matrix.with_tts == 'ON' | 218 | if: matrix.with_tts == 'ON' |
| 208 | shell: bash | 219 | shell: bash |
| @@ -215,6 +226,11 @@ jobs: | @@ -215,6 +226,11 @@ jobs: | ||
| 215 | du -h -d1 . | 226 | du -h -d1 . |
| 216 | 227 | ||
| 217 | - uses: actions/upload-artifact@v4 | 228 | - uses: actions/upload-artifact@v4 |
| 229 | + with: | ||
| 230 | + name: speech-denoiser-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | ||
| 231 | + path: ./*speech*.wav | ||
| 232 | + | ||
| 233 | + - uses: actions/upload-artifact@v4 | ||
| 218 | if: matrix.with_tts == 'ON' | 234 | if: matrix.with_tts == 'ON' |
| 219 | with: | 235 | with: |
| 220 | name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | 236 | name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} |
| @@ -7,6 +7,7 @@ on: | @@ -7,6 +7,7 @@ on: | ||
| 7 | tags: | 7 | tags: |
| 8 | - 'v[0-9]+.[0-9]+.[0-9]+*' | 8 | - 'v[0-9]+.[0-9]+.[0-9]+*' |
| 9 | paths: | 9 | paths: |
| 10 | + - '.github/scripts/test-offline-speech-denoiser.sh' | ||
| 10 | - '.github/workflows/macos.yaml' | 11 | - '.github/workflows/macos.yaml' |
| 11 | - '.github/scripts/test-kws.sh' | 12 | - '.github/scripts/test-kws.sh' |
| 12 | - '.github/scripts/test-online-transducer.sh' | 13 | - '.github/scripts/test-online-transducer.sh' |
| @@ -28,6 +29,7 @@ on: | @@ -28,6 +29,7 @@ on: | ||
| 28 | branches: | 29 | branches: |
| 29 | - master | 30 | - master |
| 30 | paths: | 31 | paths: |
| 32 | + - '.github/scripts/test-offline-speech-denoiser.sh' | ||
| 31 | - '.github/workflows/macos.yaml' | 33 | - '.github/workflows/macos.yaml' |
| 32 | - '.github/scripts/test-kws.sh' | 34 | - '.github/scripts/test-kws.sh' |
| 33 | - '.github/scripts/test-online-transducer.sh' | 35 | - '.github/scripts/test-online-transducer.sh' |
| @@ -160,6 +162,15 @@ jobs: | @@ -160,6 +162,15 @@ jobs: | ||
| 160 | overwrite: true | 162 | overwrite: true |
| 161 | file: sherpa-onnx-*osx-universal2*.tar.bz2 | 163 | file: sherpa-onnx-*osx-universal2*.tar.bz2 |
| 162 | 164 | ||
| 165 | + - name: Test offline speech denoiser | ||
| 166 | + shell: bash | ||
| 167 | + run: | | ||
| 168 | + du -h -d1 . | ||
| 169 | + export PATH=$PWD/build/bin:$PATH | ||
| 170 | + export EXE=sherpa-onnx-offline-denoiser | ||
| 171 | + | ||
| 172 | + .github/scripts/test-offline-speech-denoiser.sh | ||
| 173 | + | ||
| 163 | - name: Test offline TTS | 174 | - name: Test offline TTS |
| 164 | if: matrix.with_tts == 'ON' | 175 | if: matrix.with_tts == 'ON' |
| 165 | shell: bash | 176 | shell: bash |
| @@ -12,9 +12,9 @@ | @@ -12,9 +12,9 @@ | ||
| 12 | |--------------------------------|---------------|--------------------------| | 12 | |--------------------------------|---------------|--------------------------| |
| 13 | | ✔️ | ✔️ | ✔️ | | 13 | | ✔️ | ✔️ | ✔️ | |
| 14 | 14 | ||
| 15 | -| Keyword spotting | Add punctuation | | ||
| 16 | -|------------------|-----------------| | ||
| 17 | -| ✔️ | ✔️ | | 15 | +| Keyword spotting | Add punctuation | Speech enhancement | |
| 16 | +|------------------|-----------------|--------------------| | ||
| 17 | +| ✔️ | ✔️ | ✔️ | | ||
| 18 | 18 | ||
| 19 | ### Supported platforms | 19 | ### Supported platforms |
| 20 | 20 | ||
| @@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below: | @@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below: | ||
| 198 | | Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]| | 198 | | Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]| |
| 199 | | Punctuation | [Address][punct-models] | | 199 | | Punctuation | [Address][punct-models] | |
| 200 | | Speaker segmentation | [Address][speaker-segmentation-models] | | 200 | | Speaker segmentation | [Address][speaker-segmentation-models] | |
| 201 | +| Speech enhancement | [Address][speech-enhancement-models] | | ||
| 201 | 202 | ||
| 202 | </details> | 203 | </details> |
| 203 | 204 | ||
| @@ -442,3 +443,4 @@ sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss | @@ -442,3 +443,4 @@ sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss | ||
| 442 | [Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 | 443 | [Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 |
| 443 | [NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9 | 444 | [NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9 |
| 444 | [NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/ | 445 | [NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/ |
| 446 | +[speech-enhancement-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models |
| 1 | function(download_kaldi_native_fbank) | 1 | function(download_kaldi_native_fbank) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz") | ||
| 5 | - set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.20.0.tar.gz") | ||
| 6 | - set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9") | 4 | + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.1.tar.gz") |
| 5 | + set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.1.tar.gz") | ||
| 6 | + set(kaldi_native_fbank_HASH "SHA256=37c1aa230b00fe062791d800d8fc50aa3de215918d3dce6440699e67275d859e") | ||
| 7 | 7 | ||
| 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) |
| 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 9 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) | ||
| 12 | # If you don't have access to the Internet, | 12 | # If you don't have access to the Internet, |
| 13 | # please pre-download kaldi-native-fbank | 13 | # please pre-download kaldi-native-fbank |
| 14 | set(possible_file_locations | 14 | set(possible_file_locations |
| 15 | - $ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz | ||
| 16 | - ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz | ||
| 17 | - ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz | ||
| 18 | - /tmp/kaldi-native-fbank-1.20.0.tar.gz | ||
| 19 | - /star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz | 15 | + $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.1.tar.gz |
| 16 | + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.1.tar.gz | ||
| 17 | + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.1.tar.gz | ||
| 18 | + /tmp/kaldi-native-fbank-1.21.1.tar.gz | ||
| 19 | + /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.1.tar.gz | ||
| 20 | ) | 20 | ) |
| 21 | 21 | ||
| 22 | foreach(f IN LISTS possible_file_locations) | 22 | foreach(f IN LISTS possible_file_locations) |
| @@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS) | @@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS) | ||
| 186 | ) | 186 | ) |
| 187 | endif() | 187 | endif() |
| 188 | 188 | ||
| 189 | +list(APPEND sources | ||
| 190 | + offline-speech-denoiser-gtcrn-model-config.cc | ||
| 191 | + offline-speech-denoiser-gtcrn-model.cc | ||
| 192 | + offline-speech-denoiser-impl.cc | ||
| 193 | + offline-speech-denoiser-model-config.cc | ||
| 194 | + offline-speech-denoiser.cc | ||
| 195 | +) | ||
| 196 | + | ||
| 189 | if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | 197 | if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) |
| 190 | list(APPEND sources | 198 | list(APPEND sources |
| 191 | fast-clustering-config.cc | 199 | fast-clustering-config.cc |
| @@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 301 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) | 309 | add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) |
| 302 | add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) | 310 | add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) |
| 303 | add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) | 311 | add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) |
| 312 | + add_executable(sherpa-onnx-offline-denoiser sherpa-onnx-offline-denoiser.cc) | ||
| 304 | 313 | ||
| 305 | if(SHERPA_ONNX_ENABLE_TTS) | 314 | if(SHERPA_ONNX_ENABLE_TTS) |
| 306 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) | 315 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) |
| @@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 318 | sherpa-onnx-offline-language-identification | 327 | sherpa-onnx-offline-language-identification |
| 319 | sherpa-onnx-offline-parallel | 328 | sherpa-onnx-offline-parallel |
| 320 | sherpa-onnx-offline-punctuation | 329 | sherpa-onnx-offline-punctuation |
| 330 | + sherpa-onnx-offline-denoiser | ||
| 321 | sherpa-onnx-online-punctuation | 331 | sherpa-onnx-online-punctuation |
| 322 | ) | 332 | ) |
| 323 | if(SHERPA_ONNX_ENABLE_TTS) | 333 | if(SHERPA_ONNX_ENABLE_TTS) |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ | ||
| 7 | + | ||
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/istft.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/resample.h"
| 20 | + | ||
| 21 | +namespace sherpa_onnx { | ||
| 22 | + | ||
| 23 | +class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl { | ||
| 24 | + public: | ||
| 25 | + explicit OfflineSpeechDenoiserGtcrnImpl( | ||
| 26 | + const OfflineSpeechDenoiserConfig &config) | ||
| 27 | + : model_(config.model) {} | ||
| 28 | + | ||
| 29 | + template <typename Manager> | ||
| 30 | + OfflineSpeechDenoiserGtcrnImpl(Manager *mgr, | ||
| 31 | + const OfflineSpeechDenoiserConfig &config) | ||
| 32 | + : model_(mgr, config.model) {} | ||
| 33 | + | ||
| 34 | + DenoisedAudio Run(const float *samples, int32_t n, | ||
| 35 | + int32_t sample_rate) const override { | ||
| 36 | + SHERPA_ONNX_LOGE("n: %d, sample_rate: %d", n, sample_rate); | ||
| 37 | + const auto &meta = model_.GetMetaData(); | ||
| 38 | + | ||
| 39 | + std::vector<float> tmp; | ||
| 40 | + auto p = samples; | ||
| 41 | + | ||
| 42 | + if (sample_rate != meta.sample_rate) { | ||
| 43 | + SHERPA_ONNX_LOGE( | ||
| 44 | + "Creating a resampler:\n" | ||
| 45 | + " in_sample_rate: %d\n" | ||
| 46 | + " output_sample_rate: %d\n", | ||
| 47 | + sample_rate, meta.sample_rate); | ||
| 48 | + | ||
| 49 | + float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate); | ||
| 50 | + float lowpass_cutoff = 0.99 * 0.5 * min_freq; | ||
| 51 | + | ||
| 52 | + int32_t lowpass_filter_width = 6; | ||
| 53 | + auto resampler = std::make_unique<LinearResample>( | ||
| 54 | + sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width); | ||
| 55 | + resampler->Resample(samples, n, true, &tmp); | ||
| 56 | + p = tmp.data(); | ||
| 57 | + n = tmp.size(); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + knf::StftConfig stft_config; | ||
| 61 | + stft_config.n_fft = meta.n_fft; | ||
| 62 | + stft_config.hop_length = meta.hop_length; | ||
| 63 | + stft_config.win_length = meta.window_length; | ||
| 64 | + stft_config.window_type = meta.window_type; | ||
| 65 | + if (stft_config.window_type == "hann_sqrt") { | ||
| 66 | + auto window = knf::GetWindow("hann", stft_config.win_length); | ||
| 67 | + for (auto &w : window) { | ||
| 68 | + w = std::sqrt(w); | ||
| 69 | + } | ||
| 70 | + stft_config.window = std::move(window); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + knf::Stft stft(stft_config); | ||
| 74 | + knf::StftResult stft_result = stft.Compute(p, n); | ||
| 75 | + | ||
| 76 | + auto states = model_.GetInitStates(); | ||
| 77 | + OfflineSpeechDenoiserGtcrnModel::States next_states; | ||
| 78 | + | ||
| 79 | + knf::StftResult enhanced_stft_result; | ||
| 80 | + enhanced_stft_result.num_frames = stft_result.num_frames; | ||
| 81 | + for (int32_t i = 0; i < stft_result.num_frames; ++i) { | ||
| 82 | + auto p = Process(stft_result, i, std::move(states), &next_states); | ||
| 83 | + states = std::move(next_states); | ||
| 84 | + | ||
| 85 | + enhanced_stft_result.real.insert(enhanced_stft_result.real.end(), | ||
| 86 | + p.first.begin(), p.first.end()); | ||
| 87 | + enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(), | ||
| 88 | + p.second.begin(), p.second.end()); | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + knf::IStft istft(stft_config); | ||
| 92 | + | ||
| 93 | + DenoisedAudio denoised_audio; | ||
| 94 | + denoised_audio.sample_rate = meta.sample_rate; | ||
| 95 | + denoised_audio.samples = istft.Compute(enhanced_stft_result); | ||
| 96 | + return denoised_audio; | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + int32_t GetSampleRate() const override { | ||
| 100 | + return model_.GetMetaData().sample_rate; | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + private: | ||
| 104 | + std::pair<std::vector<float>, std::vector<float>> Process( | ||
| 105 | + const knf::StftResult &stft_result, int32_t frame_index, | ||
| 106 | + OfflineSpeechDenoiserGtcrnModel::States states, | ||
| 107 | + OfflineSpeechDenoiserGtcrnModel::States *next_states) const { | ||
| 108 | + const auto &meta = model_.GetMetaData(); | ||
| 109 | + int32_t n_fft = meta.n_fft; | ||
| 110 | + std::vector<float> x((n_fft / 2 + 1) * 2); | ||
| 111 | + | ||
| 112 | + const float *p_real = | ||
| 113 | + stft_result.real.data() + frame_index * (n_fft / 2 + 1); | ||
| 114 | + const float *p_imag = | ||
| 115 | + stft_result.imag.data() + frame_index * (n_fft / 2 + 1); | ||
| 116 | + | ||
| 117 | + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { | ||
| 118 | + x[2 * i] = p_real[i]; | ||
| 119 | + x[2 * i + 1] = p_imag[i]; | ||
| 120 | + } | ||
| 121 | + auto memory_info = | ||
| 122 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 123 | + | ||
| 124 | + std::array<int64_t, 4> x_shape{1, n_fft / 2 + 1, 1, 2}; | ||
| 125 | + Ort::Value x_tensor = Ort::Value::CreateTensor( | ||
| 126 | + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); | ||
| 127 | + | ||
| 128 | + Ort::Value output{nullptr}; | ||
| 129 | + std::tie(output, *next_states) = | ||
| 130 | + model_.Run(std::move(x_tensor), std::move(states)); | ||
| 131 | + | ||
| 132 | + std::vector<float> real(n_fft / 2 + 1); | ||
| 133 | + std::vector<float> imag(n_fft / 2 + 1); | ||
| 134 | + const auto *p = output.GetTensorData<float>(); | ||
| 135 | + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { | ||
| 136 | + real[i] = p[2 * i]; | ||
| 137 | + imag[i] = p[2 * i + 1]; | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + return {std::move(real), std::move(imag)}; | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + private: | ||
| 144 | + OfflineSpeechDenoiserGtcrnModel model_; | ||
| 145 | +}; | ||
| 146 | + | ||
| 147 | +} // namespace sherpa_onnx | ||
| 148 | + | ||
| 149 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void OfflineSpeechDenoiserGtcrnModelConfig::Register(ParseOptions *po) { | ||
| 15 | + po->Register("speech-denoiser-gtcrn-model", &model, | ||
| 16 | + "Path to the gtcrn model for speech denoising"); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +bool OfflineSpeechDenoiserGtcrnModelConfig::Validate() const { | ||
| 20 | + if (model.empty()) { | ||
| 21 | + SHERPA_ONNX_LOGE("Please provide --speech-denoiser-gtcrn-model"); | ||
| 22 | + return false; | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + if (!FileExists(model)) { | ||
| 26 | + SHERPA_ONNX_LOGE("gtcrn model file '%s' does not exist", model.c_str()); | ||
| 27 | + return false; | ||
| 28 | + } | ||
| 29 | + return true; | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +std::string OfflineSpeechDenoiserGtcrnModelConfig::ToString() const { | ||
| 33 | + std::ostringstream os; | ||
| 34 | + | ||
| 35 | + os << "OfflineSpeechDenoiserGtcrnModelConfig("; | ||
| 36 | + os << "model=\"" << model << "\")"; | ||
| 37 | + return os.str(); | ||
| 38 | +} | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineSpeechDenoiserGtcrnModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + void Register(ParseOptions *po); | ||
| 17 | + bool Validate() const; | ||
| 18 | + | ||
| 19 | + std::string ToString() const; | ||
| 20 | +}; | ||
| 21 | + | ||
| 22 | +} // namespace sherpa_onnx | ||
| 23 | + | ||
| 24 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ | ||
| 7 | + | ||
| 8 | +#include <cstdint> | ||
| 9 | +#include <string> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
// Values describing a GTCRN speech denoising model.
//
// NOTE(review): the original comment here referred to
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/kokoro/add-meta-data.py
// which looks like it was copied from the kokoro code; confirm which script
// writes the metadata of the gtcrn model.
struct OfflineSpeechDenoiserGtcrnModelMetaData {
  // Sample rate associated with the model.
  int32_t sample_rate{0};

  // Model version.
  int32_t version{1};

  // STFT parameters.
  int32_t n_fft{0};
  int32_t hop_length{0};
  int32_t window_length{0};

  // Name of the STFT window.
  std::string window_type;

  // Shapes of the three cache (state) tensors of the model.
  std::vector<int64_t> conv_cache_shape;
  std::vector<int64_t> tra_cache_shape;
  std::vector<int64_t> inter_cache_shape;
};
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 14 | +#include "sherpa-onnx/csrc/session.h" | ||
| 15 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
| 19 | +class OfflineSpeechDenoiserGtcrnModel::Impl { | ||
| 20 | + public: | ||
| 21 | + explicit Impl(const OfflineSpeechDenoiserModelConfig &config) | ||
| 22 | + : config_(config), | ||
| 23 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 24 | + sess_opts_(GetSessionOptions(config)), | ||
| 25 | + allocator_{} { | ||
| 26 | + { | ||
| 27 | + auto buf = ReadFile(config.gtcrn.model); | ||
| 28 | + Init(buf.data(), buf.size()); | ||
| 29 | + } | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + template <typename Manager> | ||
| 33 | + Impl(Manager *mgr, const OfflineSpeechDenoiserModelConfig &config) | ||
| 34 | + : config_(config), | ||
| 35 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 36 | + sess_opts_(GetSessionOptions(config)), | ||
| 37 | + allocator_{} { | ||
| 38 | + { | ||
| 39 | + auto buf = ReadFile(mgr, config.gtcrn.model); | ||
| 40 | + Init(buf.data(), buf.size()); | ||
| 41 | + } | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const { | ||
| 45 | + return meta_; | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + States GetInitStates() const { | ||
| 49 | + Ort::Value conv_cache = Ort::Value::CreateTensor<float>( | ||
| 50 | + allocator_, meta_.conv_cache_shape.data(), | ||
| 51 | + meta_.conv_cache_shape.size()); | ||
| 52 | + | ||
| 53 | + Ort::Value tra_cache = Ort::Value::CreateTensor<float>( | ||
| 54 | + allocator_, meta_.tra_cache_shape.data(), meta_.tra_cache_shape.size()); | ||
| 55 | + | ||
| 56 | + Ort::Value inter_cache = Ort::Value::CreateTensor<float>( | ||
| 57 | + allocator_, meta_.inter_cache_shape.data(), | ||
| 58 | + meta_.inter_cache_shape.size()); | ||
| 59 | + | ||
| 60 | + Fill<float>(&conv_cache, 0); | ||
| 61 | + Fill<float>(&tra_cache, 0); | ||
| 62 | + Fill<float>(&inter_cache, 0); | ||
| 63 | + | ||
| 64 | + std::vector<Ort::Value> states; | ||
| 65 | + | ||
| 66 | + states.reserve(3); | ||
| 67 | + states.push_back(std::move(conv_cache)); | ||
| 68 | + states.push_back(std::move(tra_cache)); | ||
| 69 | + states.push_back(std::move(inter_cache)); | ||
| 70 | + | ||
| 71 | + return states; | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + std::pair<Ort::Value, States> Run(Ort::Value x, States states) const { | ||
| 75 | + std::vector<Ort::Value> inputs; | ||
| 76 | + inputs.reserve(1 + states.size()); | ||
| 77 | + inputs.push_back(std::move(x)); | ||
| 78 | + for (auto &s : states) { | ||
| 79 | + inputs.push_back(std::move(s)); | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + auto out = | ||
| 83 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 84 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 85 | + | ||
| 86 | + std::vector<Ort::Value> next_states; | ||
| 87 | + next_states.reserve(out.size() - 1); | ||
| 88 | + for (int32_t k = 1; k < out.size(); ++k) { | ||
| 89 | + next_states.push_back(std::move(out[k])); | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + return {std::move(out[0]), std::move(next_states)}; | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + private: | ||
| 96 | + void Init(void *model_data, size_t model_data_length) { | ||
| 97 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 98 | + sess_opts_); | ||
| 99 | + | ||
| 100 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 101 | + | ||
| 102 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 103 | + | ||
| 104 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 105 | + if (config_.debug) { | ||
| 106 | + std::ostringstream os; | ||
| 107 | + os << "---gtcrn model---\n"; | ||
| 108 | + PrintModelMetadata(os, meta_data); | ||
| 109 | + | ||
| 110 | + os << "----------input names----------\n"; | ||
| 111 | + int32_t i = 0; | ||
| 112 | + for (const auto &s : input_names_) { | ||
| 113 | + os << i << " " << s << "\n"; | ||
| 114 | + ++i; | ||
| 115 | + } | ||
| 116 | + os << "----------output names----------\n"; | ||
| 117 | + i = 0; | ||
| 118 | + for (const auto &s : output_names_) { | ||
| 119 | + os << i << " " << s << "\n"; | ||
| 120 | + ++i; | ||
| 121 | + } | ||
| 122 | + | ||
| 123 | +#if __OHOS__ | ||
| 124 | + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); | ||
| 125 | +#else | ||
| 126 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 127 | +#endif | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 131 | + | ||
| 132 | + std::string model_type; | ||
| 133 | + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); | ||
| 134 | + if (model_type != "gtcrn") { | ||
| 135 | + SHERPA_ONNX_LOGE("Expect model type 'gtcrn'. Given: '%s'", | ||
| 136 | + model_type.c_str()); | ||
| 137 | + SHERPA_ONNX_EXIT(-1); | ||
| 138 | + } | ||
| 139 | + | ||
| 140 | + SHERPA_ONNX_READ_META_DATA(meta_.sample_rate, "sample_rate"); | ||
| 141 | + SHERPA_ONNX_READ_META_DATA(meta_.n_fft, "n_fft"); | ||
| 142 | + SHERPA_ONNX_READ_META_DATA(meta_.hop_length, "hop_length"); | ||
| 143 | + SHERPA_ONNX_READ_META_DATA(meta_.window_length, "window_length"); | ||
| 144 | + SHERPA_ONNX_READ_META_DATA_STR(meta_.window_type, "window_type"); | ||
| 145 | + SHERPA_ONNX_READ_META_DATA(meta_.version, "version"); | ||
| 146 | + | ||
| 147 | + SHERPA_ONNX_READ_META_DATA_VEC(meta_.conv_cache_shape, "conv_cache_shape"); | ||
| 148 | + SHERPA_ONNX_READ_META_DATA_VEC(meta_.tra_cache_shape, "tra_cache_shape"); | ||
| 149 | + SHERPA_ONNX_READ_META_DATA_VEC(meta_.inter_cache_shape, | ||
| 150 | + "inter_cache_shape"); | ||
| 151 | + } | ||
| 152 | + | ||
| 153 | + private: | ||
| 154 | + OfflineSpeechDenoiserModelConfig config_; | ||
| 155 | + OfflineSpeechDenoiserGtcrnModelMetaData meta_; | ||
| 156 | + | ||
| 157 | + Ort::Env env_; | ||
| 158 | + Ort::SessionOptions sess_opts_; | ||
| 159 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 160 | + | ||
| 161 | + std::unique_ptr<Ort::Session> sess_; | ||
| 162 | + | ||
| 163 | + std::vector<std::string> input_names_; | ||
| 164 | + std::vector<const char *> input_names_ptr_; | ||
| 165 | + | ||
| 166 | + std::vector<std::string> output_names_; | ||
| 167 | + std::vector<const char *> output_names_ptr_; | ||
| 168 | +}; | ||
| 169 | + | ||
| 170 | +OfflineSpeechDenoiserGtcrnModel::~OfflineSpeechDenoiserGtcrnModel() = default; | ||
| 171 | + | ||
| 172 | +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel( | ||
| 173 | + const OfflineSpeechDenoiserModelConfig &config) | ||
| 174 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 175 | + | ||
| 176 | +template <typename Manager> | ||
| 177 | +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel( | ||
| 178 | + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config) | ||
| 179 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 180 | + | ||
| 181 | +OfflineSpeechDenoiserGtcrnModel::States | ||
| 182 | +OfflineSpeechDenoiserGtcrnModel::GetInitStates() const { | ||
| 183 | + return impl_->GetInitStates(); | ||
| 184 | +} | ||
| 185 | + | ||
| 186 | +std::pair<Ort::Value, OfflineSpeechDenoiserGtcrnModel::States> | ||
| 187 | +OfflineSpeechDenoiserGtcrnModel::Run(Ort::Value x, States states) const { | ||
| 188 | + return impl_->Run(std::move(x), std::move(states)); | ||
| 189 | +} | ||
| 190 | + | ||
| 191 | +const OfflineSpeechDenoiserGtcrnModelMetaData & | ||
| 192 | +OfflineSpeechDenoiserGtcrnModel::GetMetaData() const { | ||
| 193 | + return impl_->GetMetaData(); | ||
| 194 | +} | ||
| 195 | + | ||
| 196 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | +#include <utility> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineSpeechDenoiserGtcrnModel { | ||
| 18 | + public: | ||
| 19 | + ~OfflineSpeechDenoiserGtcrnModel(); | ||
| 20 | + explicit OfflineSpeechDenoiserGtcrnModel( | ||
| 21 | + const OfflineSpeechDenoiserModelConfig &config); | ||
| 22 | + | ||
| 23 | + template <typename Manager> | ||
| 24 | + OfflineSpeechDenoiserGtcrnModel( | ||
| 25 | + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config); | ||
| 26 | + | ||
| 27 | + using States = std::vector<Ort::Value>; | ||
| 28 | + | ||
| 29 | + States GetInitStates() const; | ||
| 30 | + | ||
| 31 | + std::pair<Ort::Value, States> Run(Ort::Value x, States states) const; | ||
| 32 | + | ||
| 33 | + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const; | ||
| 34 | + | ||
| 35 | + private: | ||
| 36 | + class Impl; | ||
| 37 | + std::unique_ptr<Impl> impl_; | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx | ||
| 41 | + | ||
| 42 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h" | ||
| 5 | + | ||
| 6 | +#include <memory> | ||
| 7 | + | ||
| 8 | +#if __ANDROID_API__ >= 9 | ||
| 9 | +#include "android/asset_manager.h" | ||
| 10 | +#include "android/asset_manager_jni.h" | ||
| 11 | +#endif | ||
| 12 | + | ||
| 13 | +#if __OHOS__ | ||
| 14 | +#include "rawfile/raw_file_manager.h" | ||
| 15 | +#endif | ||
| 16 | + | ||
| 17 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 18 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create( | ||
| 23 | + const OfflineSpeechDenoiserConfig &config) { | ||
| 24 | + if (!config.model.gtcrn.model.empty()) { | ||
| 25 | + return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(config); | ||
| 26 | + } | ||
| 27 | + SHERPA_ONNX_LOGE("Please provide a speech denoising model."); | ||
| 28 | + return nullptr; | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +template <typename Manager> | ||
| 32 | +std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create( | ||
| 33 | + Manager *mgr, const OfflineSpeechDenoiserConfig &config) { | ||
| 34 | + if (!config.model.gtcrn.model.empty()) { | ||
| 35 | + return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(mgr, config); | ||
| 36 | + } | ||
| 37 | + SHERPA_ONNX_LOGE("Please provide a speech denoising model."); | ||
| 38 | + return nullptr; | ||
| 39 | +} | ||
| 40 | + | ||
// Explicit instantiations of the templated Create(): its definition lives in
// this .cc file, so each supported manager type is instantiated here.

// Android: assets are accessed through AAssetManager.
#if __ANDROID_API__ >= 9
template std::unique_ptr<OfflineSpeechDenoiserImpl>
OfflineSpeechDenoiserImpl::Create(AAssetManager *mgr,
                                  const OfflineSpeechDenoiserConfig &config);
#endif

// OHOS: resources are accessed through NativeResourceManager.
#if __OHOS__
template std::unique_ptr<OfflineSpeechDenoiserImpl>
OfflineSpeechDenoiserImpl::Create(NativeResourceManager *mgr,
                                  const OfflineSpeechDenoiserConfig &config);
#endif
| 52 | + | ||
| 53 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OfflineSpeechDenoiserImpl { | ||
| 15 | + public: | ||
| 16 | + virtual ~OfflineSpeechDenoiserImpl() = default; | ||
| 17 | + | ||
| 18 | + static std::unique_ptr<OfflineSpeechDenoiserImpl> Create( | ||
| 19 | + const OfflineSpeechDenoiserConfig &config); | ||
| 20 | + | ||
| 21 | + template <typename Manager> | ||
| 22 | + static std::unique_ptr<OfflineSpeechDenoiserImpl> Create( | ||
| 23 | + Manager *mgr, const OfflineSpeechDenoiserConfig &config); | ||
| 24 | + | ||
| 25 | + virtual DenoisedAudio Run(const float *samples, int32_t n, | ||
| 26 | + int32_t sample_rate) const = 0; | ||
| 27 | + | ||
| 28 | + virtual int32_t GetSampleRate() const = 0; | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +void OfflineSpeechDenoiserModelConfig::Register(ParseOptions *po) { | ||
| 12 | + gtcrn.Register(po); | ||
| 13 | + | ||
| 14 | + po->Register("num-threads", &num_threads, | ||
| 15 | + "Number of threads to run the neural network"); | ||
| 16 | + | ||
| 17 | + po->Register("debug", &debug, | ||
| 18 | + "true to print model information while loading it."); | ||
| 19 | + | ||
| 20 | + po->Register("provider", &provider, | ||
| 21 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +bool OfflineSpeechDenoiserModelConfig::Validate() const { | ||
| 25 | + return gtcrn.Validate(); | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +std::string OfflineSpeechDenoiserModelConfig::ToString() const { | ||
| 29 | + std::ostringstream os; | ||
| 30 | + | ||
| 31 | + os << "OfflineSpeechDenoiserModelConfig("; | ||
| 32 | + os << "gtcrn=" << gtcrn.ToString() << ", "; | ||
| 33 | + os << "num_threads=" << num_threads << ", "; | ||
| 34 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 35 | + os << "provider=\"" << provider << "\")"; | ||
| 36 | + | ||
| 37 | + return os.str(); | ||
| 38 | +} | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h" | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineSpeechDenoiserModelConfig { | ||
| 15 | + OfflineSpeechDenoiserGtcrnModelConfig gtcrn; | ||
| 16 | + | ||
| 17 | + int32_t num_threads = 1; | ||
| 18 | + bool debug = false; | ||
| 19 | + std::string provider = "cpu"; | ||
| 20 | + | ||
| 21 | + OfflineSpeechDenoiserModelConfig() = default; | ||
| 22 | + | ||
| 23 | + OfflineSpeechDenoiserModelConfig(OfflineSpeechDenoiserGtcrnModelConfig gtcrn, | ||
| 24 | + int32_t num_threads, bool debug, | ||
| 25 | + const std::string &provider) | ||
| 26 | + : gtcrn(gtcrn), | ||
| 27 | + num_threads(num_threads), | ||
| 28 | + debug(debug), | ||
| 29 | + provider(provider) {} | ||
| 30 | + | ||
| 31 | + void Register(ParseOptions *po); | ||
| 32 | + bool Validate() const; | ||
| 33 | + | ||
| 34 | + std::string ToString() const; | ||
| 35 | +}; | ||
| 36 | + | ||
| 37 | +} // namespace sherpa_onnx | ||
| 38 | + | ||
| 39 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-speech-denoiser.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h" | ||
| 8 | + | ||
| 9 | +#if __ANDROID_API__ >= 9 | ||
| 10 | +#include "android/asset_manager.h" | ||
| 11 | +#include "android/asset_manager_jni.h" | ||
| 12 | +#endif | ||
| 13 | + | ||
| 14 | +#if __OHOS__ | ||
| 15 | +#include "rawfile/raw_file_manager.h" | ||
| 16 | +#endif | ||
| 17 | + | ||
| 18 | +namespace sherpa_onnx { | ||
| 19 | + | ||
| 20 | +void OfflineSpeechDenoiserConfig::Register(ParseOptions *po) { | ||
| 21 | + model.Register(po); | ||
| 22 | +} | ||
| 23 | + | ||
| 24 | +bool OfflineSpeechDenoiserConfig::Validate() const { return model.Validate(); } | ||
| 25 | + | ||
| 26 | +std::string OfflineSpeechDenoiserConfig::ToString() const { | ||
| 27 | + std::ostringstream os; | ||
| 28 | + | ||
| 29 | + os << "OfflineSpeechDenoiserConfig("; | ||
| 30 | + os << "model=" << model.ToString() << ")"; | ||
| 31 | + return os.str(); | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +template <typename Manager> | ||
| 35 | +OfflineSpeechDenoiser::OfflineSpeechDenoiser( | ||
| 36 | + Manager *mgr, const OfflineSpeechDenoiserConfig &config) | ||
| 37 | + : impl_(OfflineSpeechDenoiserImpl::Create(mgr, config)) {} | ||
| 38 | + | ||
| 39 | +OfflineSpeechDenoiser::OfflineSpeechDenoiser( | ||
| 40 | + const OfflineSpeechDenoiserConfig &config) | ||
| 41 | + : impl_(OfflineSpeechDenoiserImpl::Create(config)) {} | ||
| 42 | + | ||
| 43 | +OfflineSpeechDenoiser::~OfflineSpeechDenoiser() = default; | ||
| 44 | + | ||
| 45 | +DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n, | ||
| 46 | + int32_t sample_rate) const { | ||
| 47 | + return impl_->Run(samples, n, sample_rate); | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +int32_t OfflineSpeechDenoiser::GetSampleRate() const { | ||
| 51 | + return impl_->GetSampleRate(); | ||
| 52 | +} | ||
| 53 | + | ||
// Explicit instantiations of the templated constructor: its definition lives
// in this .cc file, so each supported manager type is instantiated here.

// Android: assets are accessed through AAssetManager.
#if __ANDROID_API__ >= 9
template OfflineSpeechDenoiser::OfflineSpeechDenoiser(
    AAssetManager *mgr, const OfflineSpeechDenoiserConfig &config);
#endif

// OHOS: resources are accessed through NativeResourceManager.
#if __OHOS__
template OfflineSpeechDenoiser::OfflineSpeechDenoiser(
    NativeResourceManager *mgr, const OfflineSpeechDenoiserConfig &config);
#endif
| 63 | + | ||
| 64 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-speech-denoiser.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-speech-denoiser.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <string> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h" | ||
| 12 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
// Result of a denoising run.
struct DenoisedAudio {
  // The denoised audio samples.
  std::vector<float> samples;

  // Sample rate of `samples`. Defaults to 0 so that a default-constructed
  // object never carries an indeterminate value.
  int32_t sample_rate = 0;
};
| 20 | + | ||
| 21 | +struct OfflineSpeechDenoiserConfig { | ||
| 22 | + OfflineSpeechDenoiserModelConfig model; | ||
| 23 | + | ||
| 24 | + void Register(ParseOptions *po); | ||
| 25 | + bool Validate() const; | ||
| 26 | + | ||
| 27 | + std::string ToString() const; | ||
| 28 | +}; | ||
| 29 | + | ||
| 30 | +class OfflineSpeechDenoiserImpl; | ||
| 31 | + | ||
| 32 | +class OfflineSpeechDenoiser { | ||
| 33 | + public: | ||
| 34 | + explicit OfflineSpeechDenoiser(const OfflineSpeechDenoiserConfig &config); | ||
| 35 | + ~OfflineSpeechDenoiser(); | ||
| 36 | + | ||
| 37 | + template <typename Manager> | ||
| 38 | + OfflineSpeechDenoiser(Manager *mgr, | ||
| 39 | + const OfflineSpeechDenoiserConfig &config); | ||
| 40 | + | ||
| 41 | + /* | ||
| 42 | + * @param samples 1-D array of audio samples. Each sample is in the | ||
| 43 | + * range [-1, 1]. | ||
| 44 | + * @param n Number of samples | ||
| 45 | + * @param sample_rate Sample rate of the input samples | ||
| 46 | + * | ||
| 47 | + */ | ||
| 48 | + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const; | ||
| 49 | + | ||
| 50 | + /* | ||
| 51 | + * Return the sample rate of the denoised audio | ||
| 52 | + */ | ||
| 53 | + int32_t GetSampleRate() const; | ||
| 54 | + | ||
| 55 | + private: | ||
| 56 | + std::unique_ptr<OfflineSpeechDenoiserImpl> impl_; | ||
| 57 | +}; | ||
| 58 | + | ||
| 59 | +} // namespace sherpa_onnx | ||
| 60 | + | ||
| 61 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +#include <stdio.h> | ||
| 5 | + | ||
| 6 | +#include <chrono> // NOLINT | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/csrc/offline-speech-denoiser.h" | ||
| 9 | +#include "sherpa-onnx/csrc/wave-reader.h" | ||
| 10 | +#include "sherpa-onnx/csrc/wave-writer.h" | ||
| 11 | + | ||
| 12 | +int main(int32_t argc, char *argv[]) { | ||
| 13 | + const char *kUsageMessage = R"usage( | ||
| 14 | +Non-stremaing speech denoising with sherpa-onnx. | ||
| 15 | + | ||
| 16 | +Please visit | ||
| 17 | +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models | ||
| 18 | +to download models. | ||
| 19 | + | ||
| 20 | +Usage: | ||
| 21 | + | ||
| 22 | +(1) Use gtcrn models | ||
| 23 | + | ||
| 24 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx | ||
| 25 | +./bin/sherpa-onnx-offline-denoiser \ | ||
| 26 | + --speech-denoiser-gtcrn-model=gtcrn_simple.onnx \ | ||
| 27 | + --input-wav input.wav \ | ||
| 28 | + --output-wav output_16k.wav | ||
| 29 | +)usage"; | ||
| 30 | + | ||
| 31 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 32 | + sherpa_onnx::OfflineSpeechDenoiserConfig config; | ||
| 33 | + std::string input_wave; | ||
| 34 | + std::string output_wave; | ||
| 35 | + | ||
| 36 | + config.Register(&po); | ||
| 37 | + po.Register("input-wav", &input_wave, "Path to input wav."); | ||
| 38 | + po.Register("output-wav", &output_wave, "Path to output wav"); | ||
| 39 | + | ||
| 40 | + po.Read(argc, argv); | ||
| 41 | + if (po.NumArgs() != 0) { | ||
| 42 | + fprintf(stderr, "Please don't give positional arguments\n"); | ||
| 43 | + po.PrintUsage(); | ||
| 44 | + exit(EXIT_FAILURE); | ||
| 45 | + } | ||
| 46 | + fprintf(stderr, "%s\n", config.ToString().c_str()); | ||
| 47 | + | ||
| 48 | + if (input_wave.empty()) { | ||
| 49 | + fprintf(stderr, "Please provide --input-wav\n"); | ||
| 50 | + po.PrintUsage(); | ||
| 51 | + exit(EXIT_FAILURE); | ||
| 52 | + } | ||
| 53 | + | ||
| 54 | + if (output_wave.empty()) { | ||
| 55 | + fprintf(stderr, "Please provide --output-wav\n"); | ||
| 56 | + po.PrintUsage(); | ||
| 57 | + exit(EXIT_FAILURE); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + sherpa_onnx::OfflineSpeechDenoiser denoiser(config); | ||
| 61 | + int32_t sampling_rate = -1; | ||
| 62 | + bool is_ok = false; | ||
| 63 | + std::vector<float> samples = | ||
| 64 | + sherpa_onnx::ReadWave(input_wave, &sampling_rate, &is_ok); | ||
| 65 | + if (!is_ok) { | ||
| 66 | + fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str()); | ||
| 67 | + return -1; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + fprintf(stderr, "Started\n"); | ||
| 71 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 72 | + auto result = denoiser.Run(samples.data(), samples.size(), sampling_rate); | ||
| 73 | + const auto end = std::chrono::steady_clock::now(); | ||
| 74 | + | ||
| 75 | + float elapsed_seconds = | ||
| 76 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 77 | + .count() / | ||
| 78 | + 1000.; | ||
| 79 | + | ||
| 80 | + fprintf(stderr, "Done\n"); | ||
| 81 | + is_ok = sherpa_onnx::WriteWave(output_wave, result.sample_rate, | ||
| 82 | + result.samples.data(), result.samples.size()); | ||
| 83 | + if (is_ok) { | ||
| 84 | + fprintf(stderr, "Saved to %s\n", output_wave.c_str()); | ||
| 85 | + } else { | ||
| 86 | + fprintf(stderr, "Failed to save to %s\n", output_wave.c_str()); | ||
| 87 | + } | ||
| 88 | + | ||
| 89 | + float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 90 | + fprintf(stderr, "num threads: %d\n", config.model.num_threads); | ||
| 91 | + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||
| 92 | + float rtf = elapsed_seconds / duration; | ||
| 93 | + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||
| 94 | + elapsed_seconds, duration, rtf); | ||
| 95 | +} |
-
请 注册 或 登录 后发表评论