Committed by
GitHub
Add CXX API for speech enhancement GTCRN models (#1986)
正在显示
9 个修改的文件
包含
192 行增加
和
3 行删除
| @@ -108,6 +108,8 @@ jobs: | @@ -108,6 +108,8 @@ jobs: | ||
| 108 | cp -v inp_16k.wav denoised-wavs | 108 | cp -v inp_16k.wav denoised-wavs |
| 109 | cp -v enhanced_16k.wav denoised-wavs | 109 | cp -v enhanced_16k.wav denoised-wavs |
| 110 | 110 | ||
| 111 | + rm $name | ||
| 112 | + | ||
| 111 | - uses: actions/upload-artifact@v4 | 113 | - uses: actions/upload-artifact@v4 |
| 112 | with: | 114 | with: |
| 113 | name: denoised-wavs-${{ matrix.os }} | 115 | name: denoised-wavs-${{ matrix.os }} |
| @@ -81,6 +81,44 @@ jobs: | @@ -81,6 +81,44 @@ jobs: | ||
| 81 | otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib | 81 | otool -L ./install/lib/libsherpa-onnx-cxx-api.dylib |
| 82 | fi | 82 | fi |
| 83 | 83 | ||
| 84 | + - name: Test Speech Enhancement (GTCRN) | ||
| 85 | + shell: bash | ||
| 86 | + run: | | ||
| 87 | + name=speech-enhancement-gtcrn-cxx-api | ||
| 88 | + g++ -std=c++17 -o $name ./cxx-api-examples/$name.cc \ | ||
| 89 | + -I ./build/install/include \ | ||
| 90 | + -L ./build/install/lib/ \ | ||
| 91 | + -l sherpa-onnx-cxx-api \ | ||
| 92 | + -l sherpa-onnx-c-api \ | ||
| 93 | + -l onnxruntime | ||
| 94 | + | ||
| 95 | + ls -lh $name | ||
| 96 | + | ||
| 97 | + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH | ||
| 98 | + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH | ||
| 99 | + | ||
| 100 | + if [[ ${{ matrix.os }} == ubuntu-latest || ${{ matrix.os }} == ubuntu-22.04-arm ]]; then | ||
| 101 | + ldd ./$name | ||
| 102 | + echo "----" | ||
| 103 | + readelf -d ./$name | ||
| 104 | + fi | ||
| 105 | + | ||
| 106 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx | ||
| 107 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav | ||
| 108 | + | ||
| 109 | + ./$name | ||
| 110 | + | ||
| 111 | + mkdir denoised-wavs | ||
| 112 | + cp -v inp_16k.wav denoised-wavs | ||
| 113 | + cp -v enhanced_16k.wav denoised-wavs | ||
| 114 | + | ||
| 115 | + rm $name | ||
| 116 | + | ||
| 117 | + - uses: actions/upload-artifact@v4 | ||
| 118 | + with: | ||
| 119 | + name: denoised-wavs-cxx-${{ matrix.os }} | ||
| 120 | + path: ./denoised-wavs/*.wav | ||
| 121 | + | ||
| 84 | - name: Test FireRedAsr | 122 | - name: Test FireRedAsr |
| 85 | shell: bash | 123 | shell: bash |
| 86 | run: | | 124 | run: | |
| @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) | @@ -3,6 +3,9 @@ include_directories(${CMAKE_SOURCE_DIR}) | ||
| 3 | add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) | 3 | add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) |
| 4 | target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) | 4 | target_link_libraries(streaming-zipformer-cxx-api sherpa-onnx-cxx-api) |
| 5 | 5 | ||
| 6 | +add_executable(speech-enhancement-gtcrn-cxx-api ./speech-enhancement-gtcrn-cxx-api.cc) | ||
| 7 | +target_link_libraries(speech-enhancement-gtcrn-cxx-api sherpa-onnx-cxx-api) | ||
| 8 | + | ||
| 6 | add_executable(kws-cxx-api ./kws-cxx-api.cc) | 9 | add_executable(kws-cxx-api ./kws-cxx-api.cc) |
| 7 | target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api) | 10 | target_link_libraries(kws-cxx-api sherpa-onnx-cxx-api) |
| 8 | 11 |
| 1 | +// cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2025 Xiaomi Corporation | ||
| 4 | +// | ||
| 5 | +// We assume you have pre-downloaded model | ||
| 6 | +// from | ||
| 7 | +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models | ||
| 8 | +// | ||
| 9 | +// | ||
| 10 | +// An example command to download | ||
| 11 | +// clang-format off | ||
| 12 | +/* | ||
| 13 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx | ||
| 14 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav | ||
| 15 | +*/ | ||
| 16 | +// clang-format on | ||
| 17 | +#include <chrono> // NOLINT | ||
| 18 | +#include <iostream> | ||
| 19 | +#include <string> | ||
| 20 | + | ||
| 21 | +#include "sherpa-onnx/c-api/cxx-api.h" | ||
| 22 | + | ||
| 23 | +int32_t main() { | ||
| 24 | + using namespace sherpa_onnx::cxx; // NOLINT | ||
| 25 | + | ||
| 26 | + OfflineSpeechDenoiserConfig config; | ||
| 27 | + std::string wav_filename = "./inp_16k.wav"; | ||
| 28 | + std::string out_wave_filename = "./enhanced_16k.wav"; | ||
| 29 | + | ||
| 30 | + config.model.gtcrn.model = "./gtcrn_simple.onnx"; | ||
| 31 | + | ||
| 32 | + auto sd = OfflineSpeechDenoiser::Create(config); | ||
| 33 | + if (!sd.Get()) { | ||
| 34 | + std::cerr << "Please check your config\n"; | ||
| 35 | + return -1; | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + Wave wave = ReadWave(wav_filename); | ||
| 39 | + if (wave.samples.empty()) { | ||
| 40 | + std::cerr << "Failed to read: '" << wav_filename << "'\n"; | ||
| 41 | + return -1; | ||
| 42 | + } | ||
| 43 | + | ||
| 44 | + std::cout << "Started\n"; | ||
| 45 | + const auto begin = std::chrono::steady_clock::now(); | ||
| 46 | + auto denoised = | ||
| 47 | + sd.Run(wave.samples.data(), wave.samples.size(), wave.sample_rate); | ||
| 48 | + const auto end = std::chrono::steady_clock::now(); | ||
| 49 | + std::cout << "Done\n"; | ||
| 50 | + | ||
| 51 | + WriteWave(out_wave_filename, {denoised.samples, denoised.sample_rate}); | ||
| 52 | + | ||
| 53 | + const float elapsed_seconds = | ||
| 54 | + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
| 55 | + .count() / | ||
| 56 | + 1000.; | ||
| 57 | + float duration = wave.samples.size() / static_cast<float>(wave.sample_rate); | ||
| 58 | + float rtf = elapsed_seconds / duration; | ||
| 59 | + | ||
| 60 | + std::cout << "Saved to " << out_wave_filename << "\n"; | ||
| 61 | + printf("Duration: %.3fs\n", duration); | ||
| 62 | + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); | ||
| 63 | + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, | ||
| 64 | + duration, rtf); | ||
| 65 | +} |
| @@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const { | @@ -513,4 +513,49 @@ void KeywordSpotter::Reset(const OnlineStream *s) const { | ||
| 513 | SherpaOnnxResetKeywordStream(p_, s->Get()); | 513 | SherpaOnnxResetKeywordStream(p_, s->Get()); |
| 514 | } | 514 | } |
| 515 | 515 | ||
| 516 | +// ============================================================ | ||
| 517 | +// For Offline Speech Enhancement | ||
| 518 | +// ============================================================ | ||
| 519 | + | ||
| 520 | +OfflineSpeechDenoiser OfflineSpeechDenoiser::Create( | ||
| 521 | + const OfflineSpeechDenoiserConfig &config) { | ||
| 522 | + struct SherpaOnnxOfflineSpeechDenoiserConfig c; | ||
| 523 | + memset(&c, 0, sizeof(c)); | ||
| 524 | + | ||
| 525 | + c.model.gtcrn.model = config.model.gtcrn.model.c_str(); | ||
| 526 | + | ||
| 527 | + c.model.num_threads = config.model.num_threads; | ||
| 528 | + c.model.provider = config.model.provider.c_str(); | ||
| 529 | + c.model.debug = config.model.debug; | ||
| 530 | + | ||
| 531 | + auto p = SherpaOnnxCreateOfflineSpeechDenoiser(&c); | ||
| 532 | + | ||
| 533 | + return OfflineSpeechDenoiser(p); | ||
| 534 | +} | ||
| 535 | + | ||
| 536 | +void OfflineSpeechDenoiser::Destroy( | ||
| 537 | + const SherpaOnnxOfflineSpeechDenoiser *p) const { | ||
| 538 | + SherpaOnnxDestroyOfflineSpeechDenoiser(p); | ||
| 539 | +} | ||
| 540 | + | ||
| 541 | +OfflineSpeechDenoiser::OfflineSpeechDenoiser( | ||
| 542 | + const SherpaOnnxOfflineSpeechDenoiser *p) | ||
| 543 | + : MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser>(p) {} | ||
| 544 | + | ||
| 545 | +DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n, | ||
| 546 | + int32_t sample_rate) const { | ||
| 547 | + auto audio = SherpaOnnxOfflineSpeechDenoiserRun(p_, samples, n, sample_rate); | ||
| 548 | + | ||
| 549 | + DenoisedAudio ans; | ||
| 550 | + ans.samples = {audio->samples, audio->samples + audio->n}; | ||
| 551 | + ans.sample_rate = audio->sample_rate; | ||
| 552 | + SherpaOnnxDestroyDenoisedAudio(audio); | ||
| 553 | + | ||
| 554 | + return ans; | ||
| 555 | +} | ||
| 556 | + | ||
| 557 | +int32_t OfflineSpeechDenoiser::GetSampleRate() const { | ||
| 558 | + return SherpaOnnxOfflineSpeechDenoiserGetSampleRate(p_); | ||
| 559 | +} | ||
| 560 | + | ||
| 516 | } // namespace sherpa_onnx::cxx | 561 | } // namespace sherpa_onnx::cxx |
| @@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter | @@ -464,6 +464,42 @@ class SHERPA_ONNX_API KeywordSpotter | ||
| 464 | explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p); | 464 | explicit KeywordSpotter(const SherpaOnnxKeywordSpotter *p); |
| 465 | }; | 465 | }; |
| 466 | 466 | ||
| 467 | +struct OfflineSpeechDenoiserGtcrnModelConfig { | ||
| 468 | + std::string model; | ||
| 469 | +}; | ||
| 470 | + | ||
| 471 | +struct OfflineSpeechDenoiserModelConfig { | ||
| 472 | + OfflineSpeechDenoiserGtcrnModelConfig gtcrn; | ||
| 473 | + int32_t num_threads = 1; | ||
| 474 | + int32_t debug = false; | ||
| 475 | + std::string provider = "cpu"; | ||
| 476 | +}; | ||
| 477 | + | ||
| 478 | +struct OfflineSpeechDenoiserConfig { | ||
| 479 | + OfflineSpeechDenoiserModelConfig model; | ||
| 480 | +}; | ||
| 481 | + | ||
| 482 | +struct DenoisedAudio { | ||
| 483 | + std::vector<float> samples; // in the range [-1, 1] | ||
| 484 | + int32_t sample_rate; | ||
| 485 | +}; | ||
| 486 | + | ||
| 487 | +class SHERPA_ONNX_API OfflineSpeechDenoiser | ||
| 488 | + : public MoveOnly<OfflineSpeechDenoiser, SherpaOnnxOfflineSpeechDenoiser> { | ||
| 489 | + public: | ||
| 490 | + static OfflineSpeechDenoiser Create( | ||
| 491 | + const OfflineSpeechDenoiserConfig &config); | ||
| 492 | + | ||
| 493 | + void Destroy(const SherpaOnnxOfflineSpeechDenoiser *p) const; | ||
| 494 | + | ||
| 495 | + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const; | ||
| 496 | + | ||
| 497 | + int32_t GetSampleRate() const; | ||
| 498 | + | ||
| 499 | + private: | ||
| 500 | + explicit OfflineSpeechDenoiser(const SherpaOnnxOfflineSpeechDenoiser *p); | ||
| 501 | +}; | ||
| 502 | + | ||
| 467 | } // namespace sherpa_onnx::cxx | 503 | } // namespace sherpa_onnx::cxx |
| 468 | 504 | ||
| 469 | #endif // SHERPA_ONNX_C_API_CXX_API_H_ | 505 | #endif // SHERPA_ONNX_C_API_CXX_API_H_ |
-
请 注册 或 登录 后发表评论