Fangjun Kuang
Committed by GitHub

Add C++ runtime for speech enhancement GTCRN models (#1977)

See also https://github.com/Xiaobin-Rong/gtcrn
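Below is a minimal usage sketch (not part of the diff) of the new C++ API, assuming the headers added in this commit are available and reusing the model and wav file names from the CI test script; the full command-line front end added by this commit is sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc, shown further down.
// Usage sketch only; file names follow the CI test script below.
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"

int main() {
  sherpa_onnx::OfflineSpeechDenoiserConfig config;
  config.model.gtcrn.model = "./gtcrn_simple.onnx";

  sherpa_onnx::OfflineSpeechDenoiser denoiser(config);

  int32_t sample_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("./speech_with_noise.wav", &sample_rate, &is_ok);
  if (!is_ok) {
    return -1;
  }

  // Run() resamples internally if the input rate differs from the model's.
  sherpa_onnx::DenoisedAudio denoised =
      denoiser.Run(samples.data(), samples.size(), sample_rate);

  return sherpa_onnx::WriteWave("./enhanced_speech_16k.wav",
                                denoised.sample_rate, denoised.samples.data(),
                                denoised.samples.size())
             ? 0
             : -1;
}
The new CI test script .github/scripts/test-offline-speech-denoiser.sh, shown next, exercises the sherpa-onnx-offline-denoiser binary end to end: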
#!/usr/bin/env bash
set -ex
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ -z "$EXE" ]; then
EXE=./build/bin/sherpa-onnx-offline-denoiser
fi
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run gtcrn"
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav
$EXE \
--debug=1 \
--speech-denoiser-gtcrn-model=./gtcrn_simple.onnx \
--input-wav=./speech_with_noise.wav \
--output-wav=./enhanced_speech_16k.wav
rm ./gtcrn_simple.onnx
... ...
... ... @@ -10,6 +10,7 @@ on:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-kws.sh'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-speech-denoiser.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
... ... @@ -31,6 +32,7 @@ on:
paths:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-kws.sh'
- '.github/scripts/test-offline-speech-denoiser.sh'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-online-paraformer.sh'
- '.github/scripts/test-offline-transducer.sh'
... ... @@ -203,6 +205,15 @@ jobs:
overwrite: true
file: sherpa-onnx-*.tar.bz2
- name: Test offline speech denoiser
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-denoiser
.github/scripts/test-offline-speech-denoiser.sh
- name: Test offline TTS
if: matrix.with_tts == 'ON'
shell: bash
... ... @@ -215,6 +226,11 @@ jobs:
du -h -d1 .
- uses: actions/upload-artifact@v4
with:
name: speech-denoiser-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: ./*speech*.wav
- uses: actions/upload-artifact@v4
if: matrix.with_tts == 'ON'
with:
name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
... ...
... ... @@ -7,6 +7,7 @@ on:
tags:
- 'v[0-9]+.[0-9]+.[0-9]+*'
paths:
- '.github/scripts/test-offline-speech-denoiser.sh'
- '.github/workflows/macos.yaml'
- '.github/scripts/test-kws.sh'
- '.github/scripts/test-online-transducer.sh'
... ... @@ -28,6 +29,7 @@ on:
branches:
- master
paths:
- '.github/scripts/test-offline-speech-denoiser.sh'
- '.github/workflows/macos.yaml'
- '.github/scripts/test-kws.sh'
- '.github/scripts/test-online-transducer.sh'
... ... @@ -160,6 +162,15 @@ jobs:
overwrite: true
file: sherpa-onnx-*osx-universal2*.tar.bz2
- name: Test offline speech denoiser
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-denoiser
.github/scripts/test-offline-speech-denoiser.sh
- name: Test offline TTS
if: matrix.with_tts == 'ON'
shell: bash
... ...
... ... @@ -12,9 +12,9 @@
|--------------------------------|---------------|--------------------------|
| ✔️ | ✔️ | ✔️ |
| Keyword spotting | Add punctuation |
|------------------|-----------------|
| ✔️ | ✔️ |
| Keyword spotting | Add punctuation | Speech enhancement |
|------------------|-----------------|--------------------|
| ✔️ | ✔️ | ✔️ |
### Supported platforms
... ... @@ -198,6 +198,7 @@ We also have spaces built using WebAssembly. They are listed below:
| Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]|
| Punctuation | [Address][punct-models] |
| Speaker segmentation | [Address][speaker-segmentation-models] |
| Speech enhancement | [Address][speech-enhancement-models] |
</details>
... ... @@ -442,3 +443,4 @@ sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/iss
[Moonshine tiny]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
[NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9
[NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/
[speech-enhancement-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
... ...
function(download_kaldi_native_fbank)
include(FetchContent)
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.20.0.tar.gz")
set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.20.0.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=c6195b3cf374eef824644061d3c04f6b2a9267ae554169cbaa9865c89c1fe4f9")
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.1.tar.gz")
set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.1.tar.gz")
set(kaldi_native_fbank_HASH "SHA256=37c1aa230b00fe062791d800d8fc50aa3de215918d3dce6440699e67275d859e")
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
... ... @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
# If you don't have access to the Internet,
# please pre-download kaldi-native-fbank
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-native-fbank-1.20.0.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.20.0.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.20.0.tar.gz
/tmp/kaldi-native-fbank-1.20.0.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.20.0.tar.gz
$ENV{HOME}/Downloads/kaldi-native-fbank-1.21.1.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.1.tar.gz
${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.1.tar.gz
/tmp/kaldi-native-fbank-1.21.1.tar.gz
/star-fj/fangjun/download/github/kaldi-native-fbank-1.21.1.tar.gz
)
foreach(f IN LISTS possible_file_locations)
... ...
... ... @@ -186,6 +186,14 @@ if(SHERPA_ONNX_ENABLE_TTS)
)
endif()
list(APPEND sources
offline-speech-denoiser-gtcrn-model-config.cc
offline-speech-denoiser-gtcrn-model.cc
offline-speech-denoiser-impl.cc
offline-speech-denoiser-model-config.cc
offline-speech-denoiser.cc
)
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
fast-clustering-config.cc
... ... @@ -301,6 +309,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
add_executable(sherpa-onnx-offline-denoiser sherpa-onnx-offline-denoiser.cc)
if(SHERPA_ONNX_ENABLE_TTS)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
... ... @@ -318,6 +327,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-parallel
sherpa-onnx-offline-punctuation
sherpa-onnx-offline-denoiser
sherpa-onnx-online-punctuation
)
if(SHERPA_ONNX_ENABLE_TTS)
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
#include <algorithm>
#include <array>
#include <cmath>
#include <memory>
#include <utility>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/istft.h"
#include "kaldi-native-fbank/csrc/stft.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/resample.h"
namespace sherpa_onnx {
class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl {
public:
explicit OfflineSpeechDenoiserGtcrnImpl(
const OfflineSpeechDenoiserConfig &config)
: model_(config.model) {}
template <typename Manager>
OfflineSpeechDenoiserGtcrnImpl(Manager *mgr,
const OfflineSpeechDenoiserConfig &config)
: model_(mgr, config.model) {}
DenoisedAudio Run(const float *samples, int32_t n,
int32_t sample_rate) const override {
SHERPA_ONNX_LOGE("n: %d, sample_rate: %d", n, sample_rate);
const auto &meta = model_.GetMetaData();
std::vector<float> tmp;
auto p = samples;
if (sample_rate != meta.sample_rate) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sample_rate, meta.sample_rate);
float min_freq = std::min<int32_t>(sample_rate, meta.sample_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
auto resampler = std::make_unique<LinearResample>(
sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width);
resampler->Resample(samples, n, true, &tmp);
p = tmp.data();
n = tmp.size();
}
knf::StftConfig stft_config;
stft_config.n_fft = meta.n_fft;
stft_config.hop_length = meta.hop_length;
stft_config.win_length = meta.window_length;
stft_config.window_type = meta.window_type;
if (stft_config.window_type == "hann_sqrt") {
auto window = knf::GetWindow("hann", stft_config.win_length);
for (auto &w : window) {
w = std::sqrt(w);
}
stft_config.window = std::move(window);
}
knf::Stft stft(stft_config);
knf::StftResult stft_result = stft.Compute(p, n);
auto states = model_.GetInitStates();
OfflineSpeechDenoiserGtcrnModel::States next_states;
knf::StftResult enhanced_stft_result;
enhanced_stft_result.num_frames = stft_result.num_frames;
for (int32_t i = 0; i < stft_result.num_frames; ++i) {
auto p = Process(stft_result, i, std::move(states), &next_states);
states = std::move(next_states);
enhanced_stft_result.real.insert(enhanced_stft_result.real.end(),
p.first.begin(), p.first.end());
enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(),
p.second.begin(), p.second.end());
}
knf::IStft istft(stft_config);
DenoisedAudio denoised_audio;
denoised_audio.sample_rate = meta.sample_rate;
denoised_audio.samples = istft.Compute(enhanced_stft_result);
return denoised_audio;
}
int32_t GetSampleRate() const override {
return model_.GetMetaData().sample_rate;
}
private:
std::pair<std::vector<float>, std::vector<float>> Process(
const knf::StftResult &stft_result, int32_t frame_index,
OfflineSpeechDenoiserGtcrnModel::States states,
OfflineSpeechDenoiserGtcrnModel::States *next_states) const {
const auto &meta = model_.GetMetaData();
int32_t n_fft = meta.n_fft;
std::vector<float> x((n_fft / 2 + 1) * 2);
const float *p_real =
stft_result.real.data() + frame_index * (n_fft / 2 + 1);
const float *p_imag =
stft_result.imag.data() + frame_index * (n_fft / 2 + 1);
for (int32_t i = 0; i < n_fft / 2 + 1; ++i) {
x[2 * i] = p_real[i];
x[2 * i + 1] = p_imag[i];
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 4> x_shape{1, n_fft / 2 + 1, 1, 2};
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value output{nullptr};
std::tie(output, *next_states) =
model_.Run(std::move(x_tensor), std::move(states));
std::vector<float> real(n_fft / 2 + 1);
std::vector<float> imag(n_fft / 2 + 1);
const auto *p = output.GetTensorData<float>();
for (int32_t i = 0; i < n_fft / 2 + 1; ++i) {
real[i] = p[2 * i];
imag[i] = p[2 * i + 1];
}
return {std::move(real), std::move(imag)};
}
private:
OfflineSpeechDenoiserGtcrnModel model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineSpeechDenoiserGtcrnModelConfig::Register(ParseOptions *po) {
po->Register("speech-denoiser-gtcrn-model", &model,
"Path to the gtcrn model for speech denoising");
}
bool OfflineSpeechDenoiserGtcrnModelConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("Please provide --speech-denoiser-gtcrn-model");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("gtcrn model file '%s' does not exist", model.c_str());
return false;
}
return true;
}
std::string OfflineSpeechDenoiserGtcrnModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineSpeechDenoiserGtcrnModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineSpeechDenoiserGtcrnModelConfig {
std::string model;
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
#include <vector>
namespace sherpa_onnx {
// These fields are read from the metadata attached to the gtcrn ONNX model.
// See OfflineSpeechDenoiserGtcrnModel::Impl::Init() in
// offline-speech-denoiser-gtcrn-model.cc
struct OfflineSpeechDenoiserGtcrnModelMetaData {
int32_t sample_rate = 0;
int32_t version = 1;
int32_t n_fft = 0;
int32_t hop_length = 0;
int32_t window_length = 0;
std::string window_type;
std::vector<int64_t> conv_cache_shape;
std::vector<int64_t> tra_cache_shape;
std::vector<int64_t> inter_cache_shape;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineSpeechDenoiserGtcrnModel::Impl {
public:
explicit Impl(const OfflineSpeechDenoiserModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.gtcrn.model);
Init(buf.data(), buf.size());
}
}
template <typename Manager>
Impl(Manager *mgr, const OfflineSpeechDenoiserModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.gtcrn.model);
Init(buf.data(), buf.size());
}
}
const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const {
return meta_;
}
States GetInitStates() const {
Ort::Value conv_cache = Ort::Value::CreateTensor<float>(
allocator_, meta_.conv_cache_shape.data(),
meta_.conv_cache_shape.size());
Ort::Value tra_cache = Ort::Value::CreateTensor<float>(
allocator_, meta_.tra_cache_shape.data(), meta_.tra_cache_shape.size());
Ort::Value inter_cache = Ort::Value::CreateTensor<float>(
allocator_, meta_.inter_cache_shape.data(),
meta_.inter_cache_shape.size());
Fill<float>(&conv_cache, 0);
Fill<float>(&tra_cache, 0);
Fill<float>(&inter_cache, 0);
std::vector<Ort::Value> states;
states.reserve(3);
states.push_back(std::move(conv_cache));
states.push_back(std::move(tra_cache));
states.push_back(std::move(inter_cache));
return states;
}
std::pair<Ort::Value, States> Run(Ort::Value x, States states) const {
std::vector<Ort::Value> inputs;
inputs.reserve(1 + states.size());
inputs.push_back(std::move(x));
for (auto &s : states) {
inputs.push_back(std::move(s));
}
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
std::vector<Ort::Value> next_states;
next_states.reserve(out.size() - 1);
for (int32_t k = 1; k < out.size(); ++k) {
next_states.push_back(std::move(out[k]));
}
return {std::move(out[0]), std::move(next_states)};
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
os << "---gtcrn model---\n";
PrintModelMetadata(os, meta_data);
os << "----------input names----------\n";
int32_t i = 0;
for (const auto &s : input_names_) {
os << i << " " << s << "\n";
++i;
}
os << "----------output names----------\n";
i = 0;
for (const auto &s : output_names_) {
os << i << " " << s << "\n";
++i;
}
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
std::string model_type;
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
if (model_type != "gtcrn") {
SHERPA_ONNX_LOGE("Expect model type 'gtcrn'. Given: '%s'",
model_type.c_str());
SHERPA_ONNX_EXIT(-1);
}
SHERPA_ONNX_READ_META_DATA(meta_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA(meta_.n_fft, "n_fft");
SHERPA_ONNX_READ_META_DATA(meta_.hop_length, "hop_length");
SHERPA_ONNX_READ_META_DATA(meta_.window_length, "window_length");
SHERPA_ONNX_READ_META_DATA_STR(meta_.window_type, "window_type");
SHERPA_ONNX_READ_META_DATA(meta_.version, "version");
SHERPA_ONNX_READ_META_DATA_VEC(meta_.conv_cache_shape, "conv_cache_shape");
SHERPA_ONNX_READ_META_DATA_VEC(meta_.tra_cache_shape, "tra_cache_shape");
SHERPA_ONNX_READ_META_DATA_VEC(meta_.inter_cache_shape,
"inter_cache_shape");
}
private:
OfflineSpeechDenoiserModelConfig config_;
OfflineSpeechDenoiserGtcrnModelMetaData meta_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
};
OfflineSpeechDenoiserGtcrnModel::~OfflineSpeechDenoiserGtcrnModel() = default;
OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel(
const OfflineSpeechDenoiserModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel(
Manager *mgr, const OfflineSpeechDenoiserModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
OfflineSpeechDenoiserGtcrnModel::States
OfflineSpeechDenoiserGtcrnModel::GetInitStates() const {
return impl_->GetInitStates();
}
std::pair<Ort::Value, OfflineSpeechDenoiserGtcrnModel::States>
OfflineSpeechDenoiserGtcrnModel::Run(Ort::Value x, States states) const {
return impl_->Run(std::move(x), std::move(states));
}
const OfflineSpeechDenoiserGtcrnModelMetaData &
OfflineSpeechDenoiserGtcrnModel::GetMetaData() const {
return impl_->GetMetaData();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
namespace sherpa_onnx {
class OfflineSpeechDenoiserGtcrnModel {
public:
~OfflineSpeechDenoiserGtcrnModel();
explicit OfflineSpeechDenoiserGtcrnModel(
const OfflineSpeechDenoiserModelConfig &config);
template <typename Manager>
OfflineSpeechDenoiserGtcrnModel(
Manager *mgr, const OfflineSpeechDenoiserModelConfig &config);
using States = std::vector<Ort::Value>;
States GetInitStates() const;
std::pair<Ort::Value, States> Run(Ort::Value x, States states) const;
const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-impl.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-impl.h"
namespace sherpa_onnx {
std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create(
const OfflineSpeechDenoiserConfig &config) {
if (!config.model.gtcrn.model.empty()) {
return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(config);
}
SHERPA_ONNX_LOGE("Please provide a speech denoising model.");
return nullptr;
}
template <typename Manager>
std::unique_ptr<OfflineSpeechDenoiserImpl> OfflineSpeechDenoiserImpl::Create(
Manager *mgr, const OfflineSpeechDenoiserConfig &config) {
if (!config.model.gtcrn.model.empty()) {
return std::make_unique<OfflineSpeechDenoiserGtcrnImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please provide a speech denoising model.");
return nullptr;
}
#if __ANDROID_API__ >= 9
template std::unique_ptr<OfflineSpeechDenoiserImpl>
OfflineSpeechDenoiserImpl::Create(AAssetManager *mgr,
const OfflineSpeechDenoiserConfig &config);
#endif
#if __OHOS__
template std::unique_ptr<OfflineSpeechDenoiserImpl>
OfflineSpeechDenoiserImpl::Create(NativeResourceManager *mgr,
const OfflineSpeechDenoiserConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
#include <memory>
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
namespace sherpa_onnx {
class OfflineSpeechDenoiserImpl {
public:
virtual ~OfflineSpeechDenoiserImpl() = default;
static std::unique_ptr<OfflineSpeechDenoiserImpl> Create(
const OfflineSpeechDenoiserConfig &config);
template <typename Manager>
static std::unique_ptr<OfflineSpeechDenoiserImpl> Create(
Manager *mgr, const OfflineSpeechDenoiserConfig &config);
virtual DenoisedAudio Run(const float *samples, int32_t n,
int32_t sample_rate) const = 0;
virtual int32_t GetSampleRate() const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
#include <string>
namespace sherpa_onnx {
void OfflineSpeechDenoiserModelConfig::Register(ParseOptions *po) {
gtcrn.Register(po);
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool OfflineSpeechDenoiserModelConfig::Validate() const {
return gtcrn.Validate();
}
std::string OfflineSpeechDenoiserModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineSpeechDenoiserModelConfig(";
os << "gtcrn=" << gtcrn.ToString() << ", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speech-denoiser-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineSpeechDenoiserModelConfig {
OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
int32_t num_threads = 1;
bool debug = false;
std::string provider = "cpu";
OfflineSpeechDenoiserModelConfig() = default;
OfflineSpeechDenoiserModelConfig(OfflineSpeechDenoiserGtcrnModelConfig gtcrn,
int32_t num_threads, bool debug,
const std::string &provider)
: gtcrn(gtcrn),
num_threads(num_threads),
debug(debug),
provider(provider) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-speech-denoiser.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/offline-speech-denoiser-impl.h"
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
namespace sherpa_onnx {
void OfflineSpeechDenoiserConfig::Register(ParseOptions *po) {
model.Register(po);
}
bool OfflineSpeechDenoiserConfig::Validate() const { return model.Validate(); }
std::string OfflineSpeechDenoiserConfig::ToString() const {
std::ostringstream os;
os << "OfflineSpeechDenoiserConfig(";
os << "model=" << model.ToString() << ")";
return os.str();
}
template <typename Manager>
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
Manager *mgr, const OfflineSpeechDenoiserConfig &config)
: impl_(OfflineSpeechDenoiserImpl::Create(mgr, config)) {}
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
const OfflineSpeechDenoiserConfig &config)
: impl_(OfflineSpeechDenoiserImpl::Create(config)) {}
OfflineSpeechDenoiser::~OfflineSpeechDenoiser() = default;
DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
int32_t sample_rate) const {
return impl_->Run(samples, n, sample_rate);
}
int32_t OfflineSpeechDenoiser::GetSampleRate() const {
return impl_->GetSampleRate();
}
#if __ANDROID_API__ >= 9
template OfflineSpeechDenoiser::OfflineSpeechDenoiser(
AAssetManager *mgr, const OfflineSpeechDenoiserConfig &config);
#endif
#if __OHOS__
template OfflineSpeechDenoiser::OfflineSpeechDenoiser(
NativeResourceManager *mgr, const OfflineSpeechDenoiserConfig &config);
#endif
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speech-denoiser.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct DenoisedAudio {
std::vector<float> samples;
int32_t sample_rate;
};
struct OfflineSpeechDenoiserConfig {
OfflineSpeechDenoiserModelConfig model;
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class OfflineSpeechDenoiserImpl;
class OfflineSpeechDenoiser {
public:
explicit OfflineSpeechDenoiser(const OfflineSpeechDenoiserConfig &config);
~OfflineSpeechDenoiser();
template <typename Manager>
OfflineSpeechDenoiser(Manager *mgr,
const OfflineSpeechDenoiserConfig &config);
/*
* @param samples 1-D array of audio samples. Each sample is in the
* range [-1, 1].
* @param n Number of samples
* @param sample_rate Sample rate of the input samples
*
* @return Return the denoised audio together with its sample rate, which
* may differ from the input sample rate. See the usage sketch after
* this header.
*/
DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
/*
* Return the sample rate of the denoised audio
*/
int32_t GetSampleRate() const;
private:
std::unique_ptr<OfflineSpeechDenoiserImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_
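For the interface above, a short sketch with synthetic input, assuming a gtcrn model file is available at the placeholder path; it shows that Run() accepts samples at an arbitrary rate and returns audio at the model's own rate, so out.sample_rate equals GetSampleRate().
// Usage sketch only; the model path is a placeholder.
#include <cmath>
#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/offline-speech-denoiser.h"

int main() {
  sherpa_onnx::OfflineSpeechDenoiserConfig config;
  config.model.gtcrn.model = "./gtcrn_simple.onnx";

  sherpa_onnx::OfflineSpeechDenoiser denoiser(config);

  // One second of a 440 Hz tone at 48 kHz, scaled into [-1, 1].
  int32_t input_rate = 48000;
  std::vector<float> samples(input_rate);
  for (int32_t i = 0; i != input_rate; ++i) {
    samples[i] = 0.5f * std::sin(2.0f * 3.1415926f * 440.0f * i / input_rate);
  }

  sherpa_onnx::DenoisedAudio out =
      denoiser.Run(samples.data(), samples.size(), input_rate);

  // The denoised audio is at the model's sample rate, not the input rate.
  return out.sample_rate == denoiser.GetSampleRate() ? 0 : 1;
}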
... ...
// sherpa-onnx/csrc/offline-tts-kokoro-model-metadata.h
// sherpa-onnx/csrc/offline-tts-kokoro-model-meta-data.h
//
// Copyright (c) 2025 Xiaomi Corporation
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-denoiser.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Non-streaming speech denoising with sherpa-onnx.
Please visit
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
to download models.
Usage:
(1) Use gtcrn models
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
./bin/sherpa-onnx-offline-denoiser \
--speech-denoiser-gtcrn-model=gtcrn_simple.onnx \
--input-wav=input.wav \
--output-wav=output_16k.wav
)usage";
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OfflineSpeechDenoiserConfig config;
std::string input_wave;
std::string output_wave;
config.Register(&po);
po.Register("input-wav", &input_wave, "Path to input wav.");
po.Register("output-wav", &output_wave, "Path to output wav");
po.Read(argc, argv);
if (po.NumArgs() != 0) {
fprintf(stderr, "Please don't give positional arguments\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
fprintf(stderr, "%s\n", config.ToString().c_str());
if (input_wave.empty()) {
fprintf(stderr, "Please provide --input-wav\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
if (output_wave.empty()) {
fprintf(stderr, "Please provide --output-wav\n");
po.PrintUsage();
exit(EXIT_FAILURE);
}
sherpa_onnx::OfflineSpeechDenoiser denoiser(config);
int32_t sampling_rate = -1;
bool is_ok = false;
std::vector<float> samples =
sherpa_onnx::ReadWave(input_wave, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str());
return -1;
}
fprintf(stderr, "Started\n");
const auto begin = std::chrono::steady_clock::now();
auto result = denoiser.Run(samples.data(), samples.size(), sampling_rate);
const auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
fprintf(stderr, "Done\n");
is_ok = sherpa_onnx::WriteWave(output_wave, result.sample_rate,
result.samples.data(), result.samples.size());
if (is_ok) {
fprintf(stderr, "Saved to %s\n", output_wave.c_str());
} else {
fprintf(stderr, "Failed to save to %s\n", output_wave.c_str());
}
float duration = samples.size() / static_cast<float>(sampling_rate);
fprintf(stderr, "num threads: %d\n", config.model.num_threads);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
}
... ...