Fangjun Kuang
Committed by GitHub

Add Go API for speech enhancement GTCRN models (#1991)

@@ -70,6 +70,13 @@ jobs: @@ -70,6 +70,13 @@ jobs:
70 run: | 70 run: |
71 gcc --version 71 gcc --version
72 72
  73 + - name: Test speech enhancement (GTCRN)
  74 + if: matrix.os != 'windows-latest'
  75 + shell: bash
  76 + run: |
  77 + cd go-api-examples/speech-enhancement-gtcrn/
  78 + ./run.sh
  79 +
73 - name: Test Keyword spotting 80 - name: Test Keyword spotting
74 if: matrix.os != 'windows-latest' 81 if: matrix.os != 'windows-latest'
75 shell: bash 82 shell: bash
@@ -132,6 +132,15 @@ jobs: @@ -132,6 +132,15 @@ jobs:
132 name: ${{ matrix.os }}-libs 132 name: ${{ matrix.os }}-libs
133 path: to-upload/ 133 path: to-upload/
134 134
  135 + - name: Test speech enhancement (GTCRN)
  136 + shell: bash
  137 + run: |
  138 + cd scripts/go/_internal/speech-enhancement-gtcrn/
  139 +
  140 + ./run.sh
  141 +
  142 + ls -lh
  143 +
135 - name: Test audio tagging 144 - name: Test audio tagging
136 shell: bash 145 shell: bash
137 run: | 146 run: |
  1 +module speech-enhancement-gtcrn
  2 +
  3 +go 1.17
  4 +
  1 +package main
  2 +
  3 +import (
  4 + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
  5 + "log"
  6 +)
  7 +
  8 +func main() {
  9 + log.SetFlags(log.LstdFlags | log.Lmicroseconds)
  10 +
  11 + config := sherpa.OfflineSpeechDenoiserConfig{}
  12 +
  13 + // Please download the models from
  14 + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models
  15 +
  16 + config.Model.Gtcrn.Model = "./gtcrn_simple.onnx"
  17 + config.Model.NumThreads = 1
  18 + config.Model.Debug = 1
  19 +
  20 + sd := sherpa.NewOfflineSpeechDenoiser(&config)
  21 + defer sherpa.DeleteOfflineSpeechDenoiser(sd)
  22 +
  23 + wave_filename := "./inp_16k.wav"
  24 +
  25 + wave := sherpa.ReadWave(wave_filename)
  26 + if wave == nil {
  27 + log.Printf("Failed to read %v\n", wave_filename)
  28 + return
  29 + }
  30 +
  31 + log.Println("Started")
  32 + audio := sd.Run(wave.Samples, wave.SampleRate)
  33 + log.Println("Done!")
  34 +
  35 + filename := "./enhanced-16k.wav"
  36 + ok := audio.Save(filename)
  37 + if !ok {
  38 + log.Fatalf("Failed to write", filename)
  39 + } else {
  40 + log.Println("Saved to ", filename)
  41 + }
  42 +
  43 +}
#!/usr/bin/env bash
# Download the GTCRN model and a test wave file if they are missing, then
# build and run the speech enhancement example in the current directory.
set -ex

base_url=https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models

for f in gtcrn_simple.onnx inp_16k.wav; do
  if [ ! -f "./$f" ]; then
    # -f (--fail): exit non-zero on an HTTP error instead of saving the error
    # page under the asset's name, which the existence check above would
    # then accept on every later run.
    curl -fSL -O "$base_url/$f"
  fi
done

go mod tidy
go build

./speech-enhancement-gtcrn
1 /// Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) 1 /// Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
2 2
  3 +using System;
3 using System.Runtime.InteropServices; 4 using System.Runtime.InteropServices;
4 5
5 namespace SherpaOnnx 6 namespace SherpaOnnx
  1 +module speech-enhancement-gtcrn
  2 +
  3 +go 1.17
  4 +
  5 +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
  1 +../../../../go-api-examples/speech-enhancement-gtcrn/main.go
  1 +../../../../go-api-examples/speech-enhancement-gtcrn/run.sh
@@ -959,7 +959,6 @@ func (tts *OfflineTts) Generate(text string, sid int, speed float32) *GeneratedA @@ -959,7 +959,6 @@ func (tts *OfflineTts) Generate(text string, sid int, speed float32) *GeneratedA
959 // see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo 959 // see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo
960 // :n:n means 0:n:n, means low:high:capacity 960 // :n:n means 0:n:n, means low:high:capacity
961 samples := unsafe.Slice(audio.samples, n) 961 samples := unsafe.Slice(audio.samples, n)
962 - // copy(ans.Samples, samples)  
963 for i := 0; i < n; i++ { 962 for i := 0; i < n; i++ {
964 ans.Samples[i] = float32(samples[i]) 963 ans.Samples[i] = float32(samples[i])
965 } 964 }
@@ -1840,3 +1839,88 @@ func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent @@ -1840,3 +1839,88 @@ func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent
1840 } 1839 }
1841 return result 1840 return result
1842 } 1841 }
  1842 +
  1843 +type OfflineSpeechDenoiserGtcrnModelConfig struct {
  1844 + Model string
  1845 +}
  1846 +
  1847 +type OfflineSpeechDenoiserModelConfig struct {
  1848 + Gtcrn OfflineSpeechDenoiserGtcrnModelConfig
  1849 + NumThreads int32
  1850 + Debug int32
  1851 + Provider string
  1852 +}
  1853 +
  1854 +type OfflineSpeechDenoiserConfig struct {
  1855 + Model OfflineSpeechDenoiserModelConfig
  1856 +}
  1857 +
  1858 +type OfflineSpeechDenoiser struct {
  1859 + impl *C.struct_SherpaOnnxOfflineSpeechDenoiser
  1860 +}
  1861 +
  1862 +type DenoisedAudio struct {
  1863 + // Normalized samples in the range [-1, 1]
  1864 + Samples []float32
  1865 +
  1866 + SampleRate int
  1867 +}
  1868 +
  1869 +// Free the internal pointer inside the OfflineSpeechDenoiser to avoid memory leak.
  1870 +func DeleteOfflineSpeechDenoiser(sd *OfflineSpeechDenoiser) {
  1871 + C.SherpaOnnxDestroyOfflineSpeechDenoiser(sd.impl)
  1872 + sd.impl = nil
  1873 +}
  1874 +
  1875 +// The user is responsible to invoke [DeleteOfflineSpeechDenoiser]() to free
  1876 +// the returned tts to avoid memory leak
  1877 +func NewOfflineSpeechDenoiser(config *OfflineSpeechDenoiserConfig) *OfflineSpeechDenoiser {
  1878 + c := C.struct_SherpaOnnxOfflineSpeechDenoiserConfig{}
  1879 + c.model.gtcrn.model = C.CString(config.Model.Gtcrn.Model)
  1880 + defer C.free(unsafe.Pointer(c.model.gtcrn.model))
  1881 +
  1882 + c.model.num_threads = C.int(config.Model.NumThreads)
  1883 + c.model.debug = C.int(config.Model.Debug)
  1884 +
  1885 + c.model.provider = C.CString(config.Model.Provider)
  1886 + defer C.free(unsafe.Pointer(c.model.provider))
  1887 +
  1888 + impl := C.SherpaOnnxCreateOfflineSpeechDenoiser(&c)
  1889 + if impl == nil {
  1890 + return nil
  1891 + }
  1892 +
  1893 + sd := &OfflineSpeechDenoiser{}
  1894 + sd.impl = impl
  1895 + return sd
  1896 +}
  1897 +
  1898 +func (sd *OfflineSpeechDenoiser) Run(samples []float32, sampleRate int) *DenoisedAudio {
  1899 + audio := C.SherpaOnnxOfflineSpeechDenoiserRun(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)), C.int(sampleRate))
  1900 + defer C.SherpaOnnxDestroyDenoisedAudio(audio)
  1901 +
  1902 + ans := &DenoisedAudio{}
  1903 + ans.SampleRate = int(audio.sample_rate)
  1904 + n := int(audio.n)
  1905 + ans.Samples = make([]float32, n)
  1906 +
  1907 + denoisedSamples := unsafe.Slice(audio.samples, n)
  1908 + for i := 0; i < n; i++ {
  1909 + ans.Samples[i] = float32(denoisedSamples[i])
  1910 + }
  1911 +
  1912 + return ans
  1913 +}
  1914 +
  1915 +func (audio *DenoisedAudio) Save(filename string) bool {
  1916 + s := C.CString(filename)
  1917 + defer C.free(unsafe.Pointer(s))
  1918 +
  1919 + ok := int(C.SherpaOnnxWriteWave((*C.float)(&audio.Samples[0]), C.int(len(audio.Samples)), C.int(audio.SampleRate), s))
  1920 +
  1921 + return ok == 1
  1922 +}
  1923 +
  1924 +func (sd *OfflineSpeechDenoiser) SampleRate() int {
  1925 + return int(C.SherpaOnnxOfflineSpeechDenoiserGetSampleRate(sd.impl))
  1926 +}