Fangjun Kuang
Committed by GitHub

Add CXX API for speech enhancement GTCRN models (#1986)

@@ -108,6 +108,8 @@ jobs: @@ -108,6 +108,8 @@ jobs:
108 cp -v inp_16k.wav denoised-wavs 108 cp -v inp_16k.wav denoised-wavs
109 cp -v enhanced_16k.wav denoised-wavs 109 cp -v enhanced_16k.wav denoised-wavs
110 110
  111 + rm $name
  112 +
111 - uses: actions/upload-artifact@v4 113 - uses: actions/upload-artifact@v4
112 with: 114 with:
113 name: denoised-wavs-${{ matrix.os }} 115 name: denoised-wavs-${{ matrix.os }}
@@ -81,6 +81,44 @@ jobs: @@ -81,6 +81,44 @@ jobs:
81 otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib 81 otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib
82 fi 82 fi
83 83
  84 + - name: Test Speech Enhancement (GTCRN)
  85 + shell: bash
  86 + run: |
  87 + name=speech-enhancement-gtcrn-cxx-api
  88 + g++ -std=c++17 -o $name ./cxx-api-examples/$name.cc \
  89 + -I ./build/install/include \
  90 + -L ./build/install/lib/ \
  91 + -l sherpa-onnx-cxx-api \
  92 + -l sherpa-onnx-c-api \
  93 + -l onnxruntime
  94 +
  95 + ls -lh $name
  96 +
  97 + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH
  98 + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH
  99 +
  100 + if [[ ${{ matrix.os }} == ubuntu-latest || ${{ matrix.os }} == ubuntu-22.04-arm ]]; then
  101 + ldd ./$name
  102 + echo "----"
  103 + readelf -d ./$name
  104 + fi
  105 +
  106 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
  107 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav
  108 +
  109 + ./$name
  110 +
  111 + mkdir denoised-wavs
  112 + cp -v inp_16k.wav denoised-wavs
  113 + cp -v enhanced_16k.wav denoised-wavs
  114 +
  115 + rm $name
  116 +
  117 + - uses: actions/upload-artifact@v4
  118 + with:
  119 + name: denoised-wavs-cxx-${{ matrix.os }}
  120 + path: ./denoised-wavs/*.wav
  121 +
84 - name: Test FireRedAsr 122 - name: Test FireRedAsr
85 shell: bash 123 shell: bash
86 run: | 124 run: |
@@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR})
3 add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) 3 add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc)
4 target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) 4 target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api)
5 5
  6 +add_executable(speech-enhancement-gtcrn-cxx-api ./speech-enhancement-gtcrn-cxx-api.cc)
  7 +target_link_libraries(speech-enhancement-gtcrn-cxx-api sherpa-onnx-cxx-api)
  8 +
6 add_executable(kws-cxx-api ./kws-cxx-api.cc) 9 add_executable(kws-cxx-api ./kws-cxx-api.cc)
7 target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api) 10 target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api)
8 11
1 -// cxx-api-examples/kokoro-tts-zh-en-cxx-api.c 1 +// cxx-api-examples/kokoro-tts-zh-en-cxx-api.cc
2 // 2 //
3 // Copyright (c) 2025 Xiaomi Corporation 3 // Copyright (c) 2025 Xiaomi Corporation
4 4
1 -// cxx-api-examples/matcha-tts-en-cxx-api.c 1 +// cxx-api-examples/matcha-tts-en-cxx-api.cc
2 // 2 //
3 // Copyright (c) 2025 Xiaomi Corporation 3 // Copyright (c) 2025 Xiaomi Corporation
4 4
1 -// cxx-api-examples/matcha-tts-zh-cxx-api.c 1 +// cxx-api-examples/matcha-tts-zh-cxx-api.cc
2 // 2 //
3 // Copyright (c) 2025 Xiaomi Corporation 3 // Copyright (c) 2025 Xiaomi Corporation
4 4
  1 +// cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +//
  5 +// We assume you have pre-downloaded the model
  6 +// from
  7 +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
  8 +//
  9 +//
  10 +// An example command to download
  11 +// clang-format off
  12 +/*
  13 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
  14 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav
  15 +*/
  16 +// clang-format on
  17 +#include <chrono> // NOLINT
  18 +#include <iostream>
  19 +#include <string>
  20 +
  21 +#include "sherpa-onnx/c-api/cxx-api.h"
  22 +
  23 +int32_t main() {
  24 + using namespace sherpa_onnx::cxx; // NOLINT
  25 +
  26 + OfflineSpeechDenoiserConfig config;
  27 + std::string wav_filename = "./inp_16k.wav";
  28 + std::string out_wave_filename = "./enhanced_16k.wav";
  29 +
  30 + config.model.gtcrn.model = "./gtcrn_simple.onnx";
  31 +
  32 + auto sd = OfflineSpeechDenoiser::Create(config);
  33 + if (!sd.Get()) {
  34 + std::cerr << "Please check your config\n";
  35 + return -1;
  36 + }
  37 +
  38 + Wave wave = ReadWave(wav_filename);
  39 + if (wave.samples.empty()) {
  40 + std::cerr << "Failed to read: '" << wav_filename << "'\n";
  41 + return -1;
  42 + }
  43 +
  44 + std::cout << "Started\n";
  45 + const auto begin = std::chrono::steady_clock::now();
  46 + auto denoised =
  47 + sd.Run(wave.samples.data(), wave.samples.size(), wave.sample_rate);
  48 + const auto end = std::chrono::steady_clock::now();
  49 + std::cout << "Done\n";
  50 +
  51 + WriteWave(out_wave_filename, {denoised.samples, denoised.sample_rate});
  52 +
  53 + const float elapsed_seconds =
  54 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  55 + .count() /
  56 + 1000.;
  57 + float duration = wave.samples.size() / static_cast<float>(wave.sample_rate);
  58 + float rtf = elapsed_seconds / duration;
  59 +
  60 + std::cout << "Saved to " << out_wave_filename << "\n";
  61 + printf("Duration: %.3fs\n", duration);
  62 + printf("Elapsed seconds: %.3fs\n", elapsed_seconds);
  63 + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds,
  64 + duration, rtf);
  65 +}
@@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const { @@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const {
513 SherpaOnnxResetKeywordStream(p_, s->Get()); 513 SherpaOnnxResetKeywordStream(p_, s->Get());
514 } 514 }
515 515
  516 +// ============================================================
  517 +// For Offline Speech Enhancement
  518 +// ============================================================
  519 +
  520 +OfflineSpeechDenoiser OfflineSpeechDenoiser::Create(
  521 + const OfflineSpeechDenoiserConfig &config) {
  522 + struct SherpaOnnxOfflineSpeechDenoiserConfig c;
  523 + memset(&c, 0, sizeof(c));
  524 +
  525 + c.model.gtcrn.model = config.model.gtcrn.model.c_str();
  526 +
  527 + c.model.num_threads = config.model.num_threads;
  528 + c.model.provider = config.model.provider.c_str();
  529 + c.model.debug = config.model.debug;
  530 +
  531 + auto p = SherpaOnnxCreateOfflineSpeechDenoiser(&c);
  532 +
  533 + return OfflineSpeechDenoiser(p);
  534 +}
  535 +
  536 +void OfflineSpeechDenoiser::Destroy(
  537 + const SherpaOnnxOfflineSpeechDenoiser *p) const {
  538 + SherpaOnnxDestroyOfflineSpeechDenoiser(p);
  539 +}
  540 +
  541 +OfflineSpeechDenoiser::OfflineSpeechDenoiser(
  542 + const SherpaOnnxOfflineSpeechDenoiser *p)
  543 + : MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser>(p) {}
  544 +
  545 +DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n,
  546 + int32_t sample_rate) const {
  547 + auto audio = SherpaOnnxOfflineSpeechDenoiserRun(p_, samples, n, sample_rate);
  548 +
  549 + DenoisedAudio ans;
  550 + ans.samples = {audio->samples, audio->samples + audio->n};
  551 + ans.sample_rate = audio->sample_rate;
  552 + SherpaOnnxDestroyDenoisedAudio(audio);
  553 +
  554 + return ans;
  555 +}
  556 +
  557 +int32_t OfflineSpeechDenoiser::GetSampleRate() const {
  558 + return SherpaOnnxOfflineSpeechDenoiserGetSampleRate(p_);
  559 +}
  560 +
516 } // namespace sherpa_onnx::cxx 561 } // namespace sherpa_onnx::cxx
@@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter @@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter
464 explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p); 464 explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p);
465 }; 465 };
466 466
  467 +struct OfflineSpeechDenoiserGtcrnModelConfig {
  468 + std::string model;
  469 +};
  470 +
  471 +struct OfflineSpeechDenoiserModelConfig {
  472 + OfflineSpeechDenoiserGtcrnModelConfig gtcrn;
  473 + int32_t num_threads = 1;
  474 + int32_t debug = false;
  475 + std::string provider = "cpu";
  476 +};
  477 +
  478 +struct OfflineSpeechDenoiserConfig {
  479 + OfflineSpeechDenoiserModelConfig model;
  480 +};
  481 +
  482 +struct DenoisedAudio {
  483 + std::vector<float> samples; // in the range [-1, 1]
  484 + int32_t sample_rate;
  485 +};
  486 +
  487 +class SHERPA_ONNX_API OfflineSpeechDenoiser
  488 + : public MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser> {
  489 + public:
  490 + static OfflineSpeechDenoiser Create(
  491 + const OfflineSpeechDenoiserConfig &config);
  492 +
  493 + void Destroy(const SherpaOnnxOfflineSpeechDenoiser *p) const;
  494 +
  495 + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const;
  496 +
  497 + int32_t GetSampleRate() const;
  498 +
  499 + private:
  500 + explicit OfflineSpeechDenoiser(const SherpaOnnxOfflineSpeechDenoiser *p);
  501 +};
  502 +
467 } // namespace sherpa_onnx::cxx 503 } // namespace sherpa_onnx::cxx
468 504
469 #endif // SHERPA_ONNX_C_API_CXX_API_H_ 505 #endif // SHERPA_ONNX_C_API_CXX_API_H_