Fangjun Kuang
Committed by GitHub

Add Python API for speech enhancement GTCRN models (#1978)

... ... @@ -8,6 +8,13 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
# Offline speech enhancement (GTCRN): download the model and a noisy test
# wave into the current directory, then run the Python example against them.
log "test offline speech enhancement (GTCRN)"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav
python3 ./python-api-examples/offline-speech-enhancement-gtcrn.py
# List the wav files so the log shows the input and whatever output the
# example produced (presumably ./enhanced_16k.wav -- confirm against the script).
ls -lh *.wav
log "test offline zipformer (byte-level bpe, Chinese+English)"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-zh-en-2023-11-22.tar.bz2
tar xvf sherpa-onnx-zipformer-zh-en-2023-11-22.tar.bz2
... ...
#!/usr/bin/env python3
"""
This file shows how to use the speech enhancement API.
Please download the files used by this script from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
Example:
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav
"""
import time
from pathlib import Path
from typing import Tuple
import numpy as np
import sherpa_onnx
import soundfile as sf
def create_speech_denoiser():
    """Build an offline speech denoiser backed by the local GTCRN model.

    The model is looked up at ``./gtcrn_simple.onnx``, relative to the
    current working directory.

    Returns:
        The ``sherpa_onnx.OfflineSpeechDenoiser`` created from the config
        assembled here (1 thread, ``cpu`` provider, debug disabled).

    Raises:
        ValueError: If the model file does not exist, or if the assembled
            configuration does not pass ``validate()``.
    """
    model_filename = "./gtcrn_simple.onnx"
    model_exists = Path(model_filename).is_file()
    if not model_exists:
        raise ValueError(
            "Please first download a model from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models"
        )

    # Assemble the nested configuration from the innermost piece outwards.
    gtcrn_config = sherpa_onnx.OfflineSpeechDenoiserGtcrnModelConfig(
        model=model_filename
    )
    model_config = sherpa_onnx.OfflineSpeechDenoiserModelConfig(
        gtcrn=gtcrn_config,
        debug=False,
        num_threads=1,
        provider="cpu",
    )
    denoiser_config = sherpa_onnx.OfflineSpeechDenoiserConfig(model=model_config)

    if denoiser_config.validate():
        return sherpa_onnx.OfflineSpeechDenoiser(denoiser_config)

    # Show the rejected configuration before failing, to ease debugging.
    print(denoiser_config)
    raise ValueError("Errors in config. Please check previous error logs")
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read an audio file and return its first channel and sample rate.

    Args:
        filename: Path of the audio file, passed to ``soundfile.read``.

    Returns:
        A tuple ``(samples, sample_rate)``. ``samples`` is column 0 of the
        data read as a 2-D ``float32`` array, made contiguous in memory.
    """
    frames, sample_rate = sf.read(filename, always_2d=True, dtype="float32")

    # Only the first channel is kept; any further channels are discarded.
    first_channel = frames[:, 0]

    # Slicing a column yields a strided view, so copy it into a contiguous
    # buffer before handing it on.
    return np.ascontiguousarray(first_channel), sample_rate
def main():
    """Denoise ``./speech_with_noise.wav`` and report the real-time factor.

    The denoised audio is written to ``./enhanced_16k.wav`` using the sample
    rate reported by the denoiser result. Timing covers only the denoiser
    call, not model creation or file I/O.

    Raises:
        ValueError: If the test wave is missing or contains no samples.
            ``create_speech_denoiser`` may also raise ``ValueError``.
    """
    sd = create_speech_denoiser()

    test_wave = "./speech_with_noise.wav"
    if not Path(test_wave).is_file():
        raise ValueError(
            f"{test_wave} does not exist. You can download it from "
            "https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models"
        )

    samples, sample_rate = load_audio(test_wave)
    if len(samples) == 0:
        # Zero-length audio has zero duration, which would otherwise surface
        # as an unhelpful ZeroDivisionError when computing the RTF below.
        raise ValueError(f"{test_wave} contains no audio samples")

    # perf_counter() is monotonic and high resolution, so the measured
    # interval cannot be skewed by system clock adjustments.
    start = time.perf_counter()
    denoised = sd(samples, sample_rate)
    elapsed_seconds = time.perf_counter() - start

    audio_duration = len(samples) / sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    sf.write("./enhanced_16k.wav", denoised.samples, denoised.sample_rate)
    print("Saved to ./enhanced_16k.wav")
    print(f"Elapsed seconds: {elapsed_seconds:.3f}")
    print(f"Audio duration in seconds: {audio_duration:.3f}")
    print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


if __name__ == "__main__":
    main()
... ...
... ... @@ -12,6 +12,7 @@ namespace sherpa_onnx {
struct OfflineSpeechDenoiserGtcrnModelConfig {
std::string model;
OfflineSpeechDenoiserGtcrnModelConfig() = default;
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -18,6 +18,9 @@ set(srcs
offline-punctuation.cc
offline-recognizer.cc
offline-sense-voice-model-config.cc
offline-speech-denoiser-gtcrn-model-config.cc
offline-speech-denoiser-model-config.cc
offline-speech-denoiser.cc
offline-stream.cc
offline-tdnn-model-config.cc
offline-transducer-model-config.cc
... ...
// sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.h"
#include <string>
#include "sherpa-onnx/csrc/offline-speech-denoiser-gtcrn-model-config.h"
namespace sherpa_onnx {
// Registers OfflineSpeechDenoiserGtcrnModelConfig on the Python module `m`.
//
// The Python class exposes a read/write `model` attribute, a `validate()`
// method and `__str__`, each forwarding to the C++ struct.
void PybindOfflineSpeechDenoiserGtcrnModelConfig(py::module *m) {
  using Config = OfflineSpeechDenoiserGtcrnModelConfig;

  py::class_<Config> binding(*m, "OfflineSpeechDenoiserGtcrnModelConfig");

  // `model` is optional on the Python side and defaults to an empty string.
  binding.def(py::init<const std::string &>(), py::arg("model") = "");
  binding.def_readwrite("model", &Config::model);
  binding.def("validate", &Config::Validate);
  binding.def("__str__", &Config::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the Python class OfflineSpeechDenoiserGtcrnModelConfig on the
// module pointed to by `m`.
void PybindOfflineSpeechDenoiserGtcrnModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_
... ...
// sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.h"
#include <string>
#include "sherpa-onnx/csrc/offline-speech-denoiser-model-config.h"
#include "sherpa-onnx/python/csrc/offline-speech-denoiser-gtcrn-model-config.h"
namespace sherpa_onnx {
// Registers OfflineSpeechDenoiserModelConfig on the Python module `m`,
// after first registering the GTCRN model config it contains.
void PybindOfflineSpeechDenoiserModelConfig(py::module *m) {
  // The nested config type is registered first because it is used below as
  // a default argument value of the constructor binding.
  PybindOfflineSpeechDenoiserGtcrnModelConfig(m);

  using Config = OfflineSpeechDenoiserModelConfig;
  using GtcrnConfig = OfflineSpeechDenoiserGtcrnModelConfig;

  py::class_<Config> binding(*m, "OfflineSpeechDenoiserModelConfig");

  // Constructors: a no-argument one, and one where every field is an
  // optional keyword argument.
  binding.def(py::init<>());
  binding.def(
      py::init<const GtcrnConfig &, int32_t, bool, const std::string &>(),
      py::arg("gtcrn") = GtcrnConfig{}, py::arg("num_threads") = 1,
      py::arg("debug") = false, py::arg("provider") = "cpu");

  // Fields are exposed as read/write attributes.
  binding.def_readwrite("gtcrn", &Config::gtcrn);
  binding.def_readwrite("num_threads", &Config::num_threads);
  binding.def_readwrite("debug", &Config::debug);
  binding.def_readwrite("provider", &Config::provider);

  binding.def("validate", &Config::Validate);
  binding.def("__str__", &Config::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the Python class OfflineSpeechDenoiserModelConfig on the module
// pointed to by `m`; the implementation also registers the GTCRN model
// config class it depends on.
void PybindOfflineSpeechDenoiserModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_
... ...
// sherpa-onnx/python/csrc/offline-speech-denoiser.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speech-denoiser.h"
#include <vector>
#include "sherpa-onnx/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/python/csrc/offline-speech-denoiser-model-config.h"
namespace sherpa_onnx {
// Registers OfflineSpeechDenoiserConfig on the Python module `m`, after
// first registering the model config it wraps.
void PybindOfflineSpeechDenoiserConfig(py::module *m) {
  // The model config type is registered first because it is used below as
  // the default value of the `model` constructor argument.
  PybindOfflineSpeechDenoiserModelConfig(m);

  using Config = OfflineSpeechDenoiserConfig;
  using ModelConfig = OfflineSpeechDenoiserModelConfig;

  py::class_<Config> binding(*m, "OfflineSpeechDenoiserConfig");
  binding.def(py::init<>());
  binding.def(py::init<const ModelConfig &>(),
              py::arg("model") = ModelConfig{});
  binding.def_readwrite("model", &Config::model);
  binding.def("validate", &Config::Validate);
  binding.def("__str__", &Config::ToString);
}
// Registers DenoisedAudio on the Python module `m` as a class with two
// read-only properties, `sample_rate` and `samples`. No constructor is
// bound, so instances are only obtained as return values from C++.
void PybindDenoisedAudio(py::module *m) {
  using Audio = DenoisedAudio;

  py::class_<Audio> binding(*m, "DenoisedAudio");

  binding.def_property_readonly(
      "sample_rate", [](const Audio &audio) { return audio.sample_rate; });

  // The getter returns `samples` by value, i.e. a copy on every access.
  binding.def_property_readonly(
      "samples", [](const Audio &audio) { return audio.samples; });
}
// Registers OfflineSpeechDenoiser on the Python module `m`, together with
// the config classes and the DenoisedAudio result type it depends on.
//
// Python API:
//   OfflineSpeechDenoiser(config)
//   denoiser(samples, sample_rate)      -> DenoisedAudio
//   denoiser.run(samples, sample_rate)  -> DenoisedAudio (same as __call__)
//   denoiser.sample_rate                (read-only property)
void PybindOfflineSpeechDenoiser(py::module *m) {
  PybindOfflineSpeechDenoiserConfig(m);
  PybindDenoisedAudio(m);

  using PyClass = OfflineSpeechDenoiser;

  // Single implementation shared by `__call__` and `run`, so the two entry
  // points cannot drift apart.
  auto run = [](const PyClass &self, const std::vector<float> &samples,
                int32_t sample_rate) {
    return self.Run(samples.data(), samples.size(), sample_rate);
  };

  // The GIL is released around construction and inference so that other
  // Python threads can make progress while the native code runs.
  py::class_<PyClass>(*m, "OfflineSpeechDenoiser")
      .def(py::init<const OfflineSpeechDenoiserConfig &>(), py::arg("config"),
           py::call_guard<py::gil_scoped_release>())
      // Arguments are named so callers may pass them by keyword; positional
      // calls keep working unchanged.
      .def("__call__", run, py::arg("samples"), py::arg("sample_rate"),
           py::call_guard<py::gil_scoped_release>())
      .def("run", run, py::arg("samples"), py::arg("sample_rate"),
           py::call_guard<py::gil_scoped_release>())
      .def_property_readonly("sample_rate", &PyClass::GetSampleRate);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-speech-denoiser.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
// Registers the Python class OfflineSpeechDenoiser on the module pointed to
// by `m`; the implementation also registers its config classes and the
// DenoisedAudio result type.
void PybindOfflineSpeechDenoiser(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_
... ...
... ... @@ -16,6 +16,7 @@
#include "sherpa-onnx/python/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-punctuation.h"
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
#include "sherpa-onnx/python/csrc/offline-speech-denoiser.h"
#include "sherpa-onnx/python/csrc/offline-stream.h"
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/online-lm-config.h"
... ... @@ -87,6 +88,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
#endif
PybindAlsa(&m);
PybindOfflineSpeechDenoiser(&m);
}
} // namespace sherpa_onnx
... ...
... ... @@ -5,6 +5,7 @@ from _sherpa_onnx import (
AudioTaggingConfig,
AudioTaggingModelConfig,
CircularBuffer,
DenoisedAudio,
Display,
FastClustering,
FastClusteringConfig,
... ... @@ -17,6 +18,10 @@ from _sherpa_onnx import (
OfflineSpeakerDiarizationSegment,
OfflineSpeakerSegmentationModelConfig,
OfflineSpeakerSegmentationPyannoteModelConfig,
OfflineSpeechDenoiser,
OfflineSpeechDenoiserConfig,
OfflineSpeechDenoiserGtcrnModelConfig,
OfflineSpeechDenoiserModelConfig,
OfflineStream,
OfflineTts,
OfflineTtsConfig,
... ...