Fangjun Kuang
Committed by GitHub

Add Go API for speech enhancement GTCRN models (#1991)

... ... @@ -70,6 +70,13 @@ jobs:
run: |
gcc --version
- name: Test speech enhancement (GTCRN)
if: matrix.os != 'windows-latest'
shell: bash
run: |
cd go-api-examples/speech-enhancement-gtcrn/
./run.sh
- name: Test Keyword spotting
if: matrix.os != 'windows-latest'
shell: bash
... ...
... ... @@ -132,6 +132,15 @@ jobs:
name: ${{ matrix.os }}-libs
path: to-upload/
- name: Test speech enhancement (GTCRN)
shell: bash
run: |
cd scripts/go/_internal/speech-enhancement-gtcrn/
./run.sh
ls -lh
- name: Test audio tagging
shell: bash
run: |
... ...
module speech-enhancement-gtcrn
go 1.17
... ...
package main
import (
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
"log"
)
func main() {
log.SetFlags(log.LstdFlags | log.Lmicroseconds)
config := sherpa.OfflineSpeechDenoiserConfig{}
// Please download the models from
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
config.Model.Gtcrn.Model = "./gtcrn_simple.onnx"
config.Model.NumThreads = 1
config.Model.Debug = 1
sd := sherpa.NewOfflineSpeechDenoiser(&config)
defer sherpa.DeleteOfflineSpeechDenoiser(sd)
wave_filename := "./inp_16k.wav"
wave := sherpa.ReadWave(wave_filename)
if wave == nil {
log.Printf("Failed to read %v\n", wave_filename)
return
}
log.Println("Started")
audio := sd.Run(wave.Samples, wave.SampleRate)
log.Println("Done!")
filename := "./enhanced-16k.wav"
ok := audio.Save(filename)
if !ok {
log.Fatalf("Failed to write", filename)
} else {
log.Println("Saved to ", filename)
}
}
... ...
#!/usr/bin/env bash
set -ex
if [ ! -f ./gtcrn_simple.onnx ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx
fi
if [ ! -f ./inp_16k.wav ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav
fi
go mod tidy
go build
./speech-enhancement-gtcrn
... ...
/// Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
using System;
using System.Runtime.InteropServices;
namespace SherpaOnnx
... ...
module speech-enhancement-gtcrn
go 1.17
replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
... ...
../../../../go-api-examples/speech-enhancement-gtcrn/main.go
\ No newline at end of file
... ...
../../../../go-api-examples/speech-enhancement-gtcrn/run.sh
\ No newline at end of file
... ...
... ... @@ -959,7 +959,6 @@ func (tts *OfflineTts) Generate(text string, sid int, speed float32) *GeneratedA
// see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo
// :n:n means 0:n:n, means low:high:capacity
samples := unsafe.Slice(audio.samples, n)
// copy(ans.Samples, samples)
for i := 0; i < n; i++ {
ans.Samples[i] = float32(samples[i])
}
... ... @@ -1840,3 +1839,88 @@ func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent
}
return result
}
type OfflineSpeechDenoiserGtcrnModelConfig struct {
Model string
}
type OfflineSpeechDenoiserModelConfig struct {
Gtcrn OfflineSpeechDenoiserGtcrnModelConfig
NumThreads int32
Debug int32
Provider string
}
type OfflineSpeechDenoiserConfig struct {
Model OfflineSpeechDenoiserModelConfig
}
type OfflineSpeechDenoiser struct {
impl *C.struct_SherpaOnnxOfflineSpeechDenoiser
}
type DenoisedAudio struct {
// Normalized samples in the range [-1, 1]
Samples []float32
SampleRate int
}
// Free the internal pointer inside the OfflineSpeechDenoiser to avoid memory leak.
func DeleteOfflineSpeechDenoiser(sd *OfflineSpeechDenoiser) {
C.SherpaOnnxDestroyOfflineSpeechDenoiser(sd.impl)
sd.impl = nil
}
// The user is responsible to invoke [DeleteOfflineSpeechDenoiser]() to free
// the returned tts to avoid memory leak
func NewOfflineSpeechDenoiser(config *OfflineSpeechDenoiserConfig) *OfflineSpeechDenoiser {
c := C.struct_SherpaOnnxOfflineSpeechDenoiserConfig{}
c.model.gtcrn.model = C.CString(config.Model.Gtcrn.Model)
defer C.free(unsafe.Pointer(c.model.gtcrn.model))
c.model.num_threads = C.int(config.Model.NumThreads)
c.model.debug = C.int(config.Model.Debug)
c.model.provider = C.CString(config.Model.Provider)
defer C.free(unsafe.Pointer(c.model.provider))
impl := C.SherpaOnnxCreateOfflineSpeechDenoiser(&c)
if impl == nil {
return nil
}
sd := &OfflineSpeechDenoiser{}
sd.impl = impl
return sd
}
func (sd *OfflineSpeechDenoiser) Run(samples []float32, sampleRate int) *DenoisedAudio {
audio := C.SherpaOnnxOfflineSpeechDenoiserRun(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)), C.int(sampleRate))
defer C.SherpaOnnxDestroyDenoisedAudio(audio)
ans := &DenoisedAudio{}
ans.SampleRate = int(audio.sample_rate)
n := int(audio.n)
ans.Samples = make([]float32, n)
denoisedSamples := unsafe.Slice(audio.samples, n)
for i := 0; i < n; i++ {
ans.Samples[i] = float32(denoisedSamples[i])
}
return ans
}
func (audio *DenoisedAudio) Save(filename string) bool {
s := C.CString(filename)
defer C.free(unsafe.Pointer(s))
ok := int(C.SherpaOnnxWriteWave((*C.float)(&audio.Samples[0]), C.int(len(audio.Samples)), C.int(audio.SampleRate), s))
return ok == 1
}
func (sd *OfflineSpeechDenoiser) SampleRate() int {
return int(C.SherpaOnnxOfflineSpeechDenoiserGetSampleRate(sd.impl))
}
... ...