Fangjun Kuang
Committed by GitHub

Add CXX API for speech enhancement GTCRN models (#1986)

... ... @@ -108,6 +108,8 @@ jobs:
cp -v inp_16k.wav denoised-wavs
cp -v enhanced_16k.wav denoised-wavs
rm $name
- uses: actions/upload-artifact@v4
with:
name: denoised-wavs-${{ matrix.os }}
... ...
... ... @@ -81,6 +81,44 @@ jobs:
otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
fi
- name: Test Speech Enhancement (GTCRN)
shell: bash
run: |
name=speech-enhancement-gtcrn-cxx-api
g++ -std=c++17 -o $name ./cxx-api-examples/$name.cc \
-I ./build/install/include \
-L ./build/install/lib/ \
-l sherpa-onnx-cxx-api \
-l sherpa-onnx-c-api \
-l onnxruntime
ls -lh $name
export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
if [[ ${{ matrix.os }} == ubuntu-latest || ${{ matrix.os }} == ubuntu-22.04-arm ]]; then
ldd ./$name
echo "----"
readelf -d ./$name
fi
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav
./$name
mkdir denoised-wavs
cp -v inp_16k.wav denoised-wavs
cp -v enhanced_16k.wav denoised-wavs
rm $name
- uses: actions/upload-artifact@v4
with:
name: denoised-wavs-cxx-${{ matrix.os }}
path: ./denoised-wavs/*.wav
- name: Test FireRedAsr
shell: bash
run: |
... ...
... ... @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
add_executable(speech-enhancement-gtcrn-cxx-api ./speech-enhancement-gtcrn-cxx-api.cc)
target_link_libraries(speech-enhancement-gtcrn-cxx-api sherpa-onnx-cxx-api)
add_executable(kws-cxx-api ./kws-cxx-api.cc)
target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
... ...
// cxx-api-examples/kokoro-tts-zh-en-cxx-api.c
// cxx-api-examples/kokoro-tts-zh-en-cxx-api.cc
//
// Copyright (c) 2025 Xiaomi Corporation
... ...
// cxx-api-examples/matcha-tts-en-cxx-api.c
// cxx-api-examples/matcha-tts-en-cxx-api.cc
//
// Copyright (c) 2025 Xiaomi Corporation
... ...
// cxx-api-examples/matcha-tts-zh-cxx-api.c
// cxx-api-examples/matcha-tts-zh-cxx-api.cc
//
// Copyright (c) 2025 Xiaomi Corporation
... ...
// cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc
//
// Copyright (c) 2025 Xiaomi Corporation
//
// We assume you have pre-downloaded model
// from
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
//
//
// An example command to download
// clang-format off
/*
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav
*/
// clang-format on
#include <chrono> // NOLINT
#include <iostream>
#include <string>
#include "sherpa-onnx/c-api/cxx-api.h"
int32_t main() {
using namespace sherpa_onnx::cxx; // NOLINT
OfflineSpeechDenoiserConfig config;
std::string wav_filename = "./inp_16k.wav";
std::string out_wave_filename = "./enhanced_16k.wav";
config.model.gtcrn.model = "./gtcrn_simple.onnx";
auto sd = OfflineSpeechDenoiser::Create(config);
if (!sd.Get()) {
std::cerr << "Please check your config\n";
return -1;
}
Wave wave = ReadWave(wav_filename);
if (wave.samples.empty()) {
std::cerr << "Failed to read: '" << wav_filename << "'\n";
return -1;
}
std::cout << "Started\n";
const auto begin = std::chrono::steady_clock::now();
auto denoised =
sd.Run(wave.samples.data(), wave.samples.size(), wave.sample_rate);
const auto end = std::chrono::steady_clock::now();
std::cout << "Done\n";
WriteWave(out_wave_filename, {denoised.samples, denoised.sample_rate});
const float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
float duration = wave.samples.size() / static_cast<float>(wave.sample_rate);
float rtf = elapsed_seconds / duration;
std::cout << "Saved to " << out_wave_filename << "\n";
printf("Duration: %.3fs\n", duration);
printf("Elapsed seconds: %.3fs\n", elapsed_seconds);
printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds,
duration, rtf);
}
... ...
... ... @@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const {
SherpaOnnxResetKeywordStream(p_, s->Get());
}
// ============================================================
// For Offline Speech Enhancement
// ============================================================
OfflineSpeechDenoiser OfflineSpeechDenoiser::Create(
const OfflineSpeechDenoiserConfig &config) {
struct SherpaOnnxOfflineSpeechDenoiserConfig c;
memset(&c, 0, sizeof(c));
c.model.gtcrn.model = config.model.gtcrn.model.c_str();
c.model.num_threads = config.model.num_threads;
c.model.provider = config.model.provider.c_str();
c.model.debug = config.model.debug;
auto p = SherpaOnnxCreateOfflineSpeechDenoiser(&c);
return OfflineSpeechDenoiser(p);
}
void OfflineSpeechDenoiser::Destroy(
const SherpaOnnxOfflineSpeechDenoiser *p) const {
SherpaOnnxDestroyOfflineSpeechDenoiser(p);
}
OfflineSpeechDenoiser::OfflineSpeechDenoiser(
const SherpaOnnxOfflineSpeechDenoiser *p)
: MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser>(p) {}
DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
int32_t sample_rate) const {
auto audio = SherpaOnnxOfflineSpeechDenoiserRun(p_, samples, n, sample_rate);
DenoisedAudio ans;
ans.samples = {audio->samples, audio->samples + audio->n};
ans.sample_rate = audio->sample_rate;
SherpaOnnxDestroyDenoisedAudio(audio);
return ans;
}
int32_t OfflineSpeechDenoiser::GetSampleRate() const {
return SherpaOnnxOfflineSpeechDenoiserGetSampleRate(p_);
}
} // namespace sherpa_onnx::cxx
... ...
... ... @@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter
explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
};
struct OfflineSpeechDenoiserGtcrnModelConfig {
std::string model;
};
struct OfflineSpeechDenoiserModelConfig {
OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
int32_t num_threads = 1;
int32_t debug = false;
std::string provider = "cpu";
};
struct OfflineSpeechDenoiserConfig {
OfflineSpeechDenoiserModelConfig model;
};
struct DenoisedAudio {
std::vector<float> samples; // in the range [-1, 1]
int32_t sample_rate;
};
class SHERPA_ONNX_API OfflineSpeechDenoiser
: public MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser> {
public:
static OfflineSpeechDenoiser Create(
const OfflineSpeechDenoiserConfig &config);
void Destroy(const SherpaOnnxOfflineSpeechDenoiser *p) const;
DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
int32_t GetSampleRate() const;
private:
explicit OfflineSpeechDenoiser(const SherpaOnnxOfflineSpeechDenoiser *p);
};
} // namespace sherpa_onnx::cxx
#endif // SHERPA_ONNX_C_API_CXX_API_H_
... ...