Fangjun Kuang
Committed by GitHub

Add Python API for speech enhancement GTCRN models (#1978)

@@ -8,6 +8,13 @@ log() { @@ -8,6 +8,13 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test offline speech enhancement (GTCRN)"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
  14 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav
  15 +python3 ./python-api-examples/offline-speech-enhancement-gtcrn.py
  16 +ls -lh *.wav
  17 +
11 log "test offline zipformer (byte-level bpe, Chinese+English)" 18 log "test offline zipformer (byte-level bpe, Chinese+English)"
12 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-zh-en-2023-11-22.tar.bz2 19 curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-zh-en-2023-11-22.tar.bz2
13 tar xvf sherpa-onnx-zipformer-zh-en-2023-11-22.tar.bz2 20 tar xvf sherpa-onnx-zipformer-zh-en-2023-11-22.tar.bz2
  1 +#!/usr/bin/env python3
  2 +
  3 +"""
  4 +This file shows how to use the speech enhancement API.
  5 +
  6 +Please download the files used by this script from
  7 +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
  8 +
  9 +Example:
  10 +
  11 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
  12 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav
  13 +"""
  14 +
  15 +import time
  16 +from pathlib import Path
  17 +from typing import Tuple
  18 +
  19 +import numpy as np
  20 +import sherpa_onnx
  21 +import soundfile as sf
  22 +
  23 +
def create_speech_denoiser():
    """Build an offline speech denoiser backed by the GTCRN model.

    The model is looked up at ``./gtcrn_simple.onnx`` relative to the current
    working directory. The denoiser is configured for the ``cpu`` provider
    with a single thread and debug output disabled.

    Returns:
      A ``sherpa_onnx.OfflineSpeechDenoiser`` built from the config.

    Raises:
      ValueError: If the model file does not exist, or if the assembled
        config fails ``validate()`` (the config is printed first).
    """
    model_filename = "./gtcrn_simple.onnx"
    model_is_present = Path(model_filename).is_file()
    if not model_is_present:
        raise ValueError(
            "Please first download a model from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models"
        )

    # Assemble the config from the innermost piece outwards.
    gtcrn_config = sherpa_onnx.OfflineSpeechDenoiserGtcrnModelConfig(
        model=model_filename
    )
    model_config = sherpa_onnx.OfflineSpeechDenoiserModelConfig(
        gtcrn=gtcrn_config,
        debug=False,
        num_threads=1,
        provider="cpu",
    )
    config = sherpa_onnx.OfflineSpeechDenoiserConfig(model=model_config)

    if config.validate():
        return sherpa_onnx.OfflineSpeechDenoiser(config)

    # Show the rejected config so the user can relate it to the error logs.
    print(config)
    raise ValueError("Errors in config. Please check previous error logs")
  46 +
  47 +
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return the samples of its first channel.

    Args:
      filename: Path of the audio file, passed to ``soundfile.read``.

    Returns:
      A tuple ``(samples, sample_rate)``. ``samples`` is a contiguous
      float32 array holding only the first channel; ``sample_rate`` is the
      value reported by ``soundfile.read``.
    """
    audio, sample_rate = sf.read(filename, always_2d=True, dtype="float32")

    # Keep only the first channel. The column slice is a strided view, so
    # copy it into contiguous memory before returning it.
    first_channel = audio[:, 0]
    return np.ascontiguousarray(first_channel), sample_rate
  57 +
  58 +
def main():
    """Denoise ``./speech_with_noise.wav`` and report the real-time factor.

    The enhanced audio is written to ``./enhanced_16k.wav`` using the sample
    rate reported by the denoiser. The elapsed time, the audio duration and
    their ratio (RTF) are printed afterwards.

    Raises:
      ValueError: If the test wave is missing or contains no samples, or if
        ``create_speech_denoiser()`` rejects the model or config.
    """
    sd = create_speech_denoiser()
    test_wave = "./speech_with_noise.wav"
    if not Path(test_wave).is_file():
        raise ValueError(
            f"{test_wave} does not exist. You can download it from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models"
        )

    samples, sample_rate = load_audio(test_wave)
    if len(samples) == 0:
        # An empty file has zero duration, which would otherwise surface as
        # a ZeroDivisionError when the RTF is computed below.
        raise ValueError(f"{test_wave} contains no audio samples")

    # perf_counter() is monotonic; time.time() can jump if the system clock
    # is adjusted while the model is running.
    start = time.perf_counter()
    denoised = sd(samples, sample_rate)
    elapsed_seconds = time.perf_counter() - start

    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    sf.write("./enhanced_16k.wav", denoised.samples, denoised.sample_rate)
    print("Saved to ./enhanced_16k.wav")
    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
  83 +
  84 +
  85 +if __name__ == "__main__":
  86 + main()
@@ -12,6 +12,7 @@ namespace sherpa_onnx { @@ -12,6 +12,7 @@ namespace sherpa_onnx {
12 12
13 struct OfflineSpeechDenoiserGtcrnModelConfig { 13 struct OfflineSpeechDenoiserGtcrnModelConfig {
14 std::string model; 14 std::string model;
  15 + OfflineSpeechDenoiserGtcrnModelConfig() = default;
15 16
16 void Register(ParseOptions *po); 17 void Register(ParseOptions *po);
17 bool Validate() const; 18 bool Validate() const;
@@ -18,6 +18,9 @@ set(srcs @@ -18,6 +18,9 @@ set(srcs
18 offline-punctuation.cc 18 offline-punctuation.cc
19 offline-recognizer.cc 19 offline-recognizer.cc
20 offline-sense-voice-model-config.cc 20 offline-sense-voice-model-config.cc
  21 + offline-speech-denoiser-gtcrn-model-config.cc
  22 + offline-speech-denoiser-model-config.cc
  23 + offline-speech-denoiser.cc
21 offline-stream.cc 24 offline-stream.cc
22 offline-tdnn-model-config.cc 25 offline-tdnn-model-config.cc
23 offline-transducer-model-config.cc 26 offline-transducer-model-config.cc
  1 +// sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOfflineSpeechDenoiserGtcrnModelConfig(py::module *m) {
  14 + using PyClass = OfflineSpeechDenoiserGtcrnModelConfig;
  15 + py::class_<PyClass>(*m, "OfflineSpeechDenoiserGtcrnModelConfig")
  16 + .def(py::init<const std::string &>(), py::arg("model") = "")
  17 + .def_readwrite("model", &PyClass::model)
  18 + .def("validate", &PyClass::Validate)
  19 + .def("__str__", &PyClass::ToString);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSpeechDenoiserGtcrnModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
  10 +#include "sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineSpeechDenoiserModelConfig(py::module *m) {
  15 + PybindOfflineSpeechDenoiserGtcrnModelConfig(m);
  16 +
  17 + using PyClass = OfflineSpeechDenoiserModelConfig;
  18 + py::class_<PyClass>(*m, "OfflineSpeechDenoiserModelConfig")
  19 + .def(py::init<>())
  20 + .def(py::init<const OfflineSpeechDenoiserGtcrnModelConfig &, int32_t,
  21 + bool, const std::string &>(),
  22 + py::arg("gtcrn") = OfflineSpeechDenoiserGtcrnModelConfig{},
  23 + py::arg("num_threads") = 1, py::arg("debug") = false,
  24 + py::arg("provider") = "cpu")
  25 + .def_readwrite("gtcrn", &PyClass::gtcrn)
  26 + .def_readwrite("num_threads", &PyClass::num_threads)
  27 + .def_readwrite("debug", &PyClass::debug)
  28 + .def_readwrite("provider", &PyClass::provider)
  29 + .def("validate", &PyClass::Validate)
  30 + .def("__str__", &PyClass::ToString);
  31 +}
  32 +
  33 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSpeechDenoiserModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/python/csrc/offline-speech-denoiser.cc
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-speech-denoiser.h"
  6 +
  7 +#include <vector>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
  10 +#include "sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOfflineSpeechDenoiserConfig(py::module *m) {
  15 + PybindOfflineSpeechDenoiserModelConfig(m);
  16 +
  17 + using PyClass = OfflineSpeechDenoiserConfig;
  18 +
  19 + py::class_<PyClass>(*m, "OfflineSpeechDenoiserConfig")
  20 + .def(py::init<>())
  21 + .def(py::init<const OfflineSpeechDenoiserModelConfig &>(),
  22 + py::arg("model") = OfflineSpeechDenoiserModelConfig{})
  23 + .def_readwrite("model", &PyClass::model)
  24 + .def("validate", &PyClass::Validate)
  25 + .def("__str__", &PyClass::ToString);
  26 +}
  27 +
  28 +void PybindDenoisedAudio(py::module *m) {
  29 + using PyClass = DenoisedAudio;
  30 + py::class_<PyClass>(*m, "DenoisedAudio")
  31 + .def_property_readonly(
  32 + "sample_rate", [](const PyClass &self) { return self.sample_rate; })
  33 + .def_property_readonly("samples",
  34 + [](const PyClass &self) { return self.samples; });
  35 +}
  36 +
  37 +void PybindOfflineSpeechDenoiser(py::module *m) {
  38 + PybindOfflineSpeechDenoiserConfig(m);
  39 + PybindDenoisedAudio(m);
  40 + using PyClass = OfflineSpeechDenoiser;
  41 + py::class_<PyClass>(*m, "OfflineSpeechDenoiser")
  42 + .def(py::init<const OfflineSpeechDenoiserConfig &>(), py::arg("config"),
  43 + py::call_guard<py::gil_scoped_release>())
  44 + .def(
  45 + "__call__",
  46 + [](const PyClass &self, const std::vector<float> &samples,
  47 + int32_t sample_rate) {
  48 + return self.Run(samples.data(), samples.size(), sample_rate);
  49 + },
  50 + py::call_guard<py::gil_scoped_release>())
  51 + .def(
  52 + "run",
  53 + [](const PyClass &self, const std::vector<float> &samples,
  54 + int32_t sample_rate) {
  55 + return self.Run(samples.data(), samples.size(), sample_rate);
  56 + },
  57 + py::call_guard<py::gil_scoped_release>())
  58 + .def_property_readonly("sample_rate", &PyClass::GetSampleRate);
  59 +}
  60 +
  61 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/offline-speech-denoiser.h
  2 +//
  3 +// Copyright (c) 2025 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineSpeechDenoiser(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 #include "sherpa-onnx/python/csrc/offline-model-config.h" 16 #include "sherpa-onnx/python/csrc/offline-model-config.h"
17 #include "sherpa-onnx/python/csrc/offline-punctuation.h" 17 #include "sherpa-onnx/python/csrc/offline-punctuation.h"
18 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 18 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
  19 +#include "sherpa-onnx/python/csrc/offline-speech-denoiser.h"
19 #include "sherpa-onnx/python/csrc/offline-stream.h" 20 #include "sherpa-onnx/python/csrc/offline-stream.h"
20 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" 21 #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
21 #include "sherpa-onnx/python/csrc/online-lm-config.h" 22 #include "sherpa-onnx/python/csrc/online-lm-config.h"
@@ -87,6 +88,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -87,6 +88,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
87 #endif 88 #endif
88 89
89 PybindAlsa(&m); 90 PybindAlsa(&m);
  91 + PybindOfflineSpeechDenoiser(&m);
90 } 92 }
91 93
92 } // namespace sherpa_onnx 94 } // namespace sherpa_onnx
@@ -5,6 +5,7 @@ from _sherpa_onnx import ( @@ -5,6 +5,7 @@ from _sherpa_onnx import (
5 AudioTaggingConfig, 5 AudioTaggingConfig,
6 AudioTaggingModelConfig, 6 AudioTaggingModelConfig,
7 CircularBuffer, 7 CircularBuffer,
  8 + DenoisedAudio,
8 Display, 9 Display,
9 FastClustering, 10 FastClustering,
10 FastClusteringConfig, 11 FastClusteringConfig,
@@ -17,6 +18,10 @@ from _sherpa_onnx import ( @@ -17,6 +18,10 @@ from _sherpa_onnx import (
17 OfflineSpeakerDiarizationSegment, 18 OfflineSpeakerDiarizationSegment,
18 OfflineSpeakerSegmentationModelConfig, 19 OfflineSpeakerSegmentationModelConfig,
19 OfflineSpeakerSegmentationPyannoteModelConfig, 20 OfflineSpeakerSegmentationPyannoteModelConfig,
  21 + OfflineSpeechDenoiser,
  22 + OfflineSpeechDenoiserConfig,
  23 + OfflineSpeechDenoiserGtcrnModelConfig,
  24 + OfflineSpeechDenoiserModelConfig,
20 OfflineStream, 25 OfflineStream,
21 OfflineTts, 26 OfflineTts,
22 OfflineTtsConfig, 27 OfflineTtsConfig,