Committed by
GitHub
Add Go API for speech enhancement GTCRN models (#1991)
正在显示
11 个修改的文件
包含
172 行增加
和
1 行删除
| @@ -70,6 +70,13 @@ jobs: | @@ -70,6 +70,13 @@ jobs: | ||
| 70 | run: | | 70 | run: | |
| 71 | gcc --version | 71 | gcc --version |
| 72 | 72 | ||
| 73 | + - name: Test speech enhancement (GTCRN) | ||
| 74 | + if: matrix.os != 'windows-latest' | ||
| 75 | + shell: bash | ||
| 76 | + run: | | ||
| 77 | + cd go-api-examples/speech-enhancement-gtcrn/ | ||
| 78 | + ./run.sh | ||
| 79 | + | ||
| 73 | - name: Test Keyword spotting | 80 | - name: Test Keyword spotting |
| 74 | if: matrix.os != 'windows-latest' | 81 | if: matrix.os != 'windows-latest' |
| 75 | shell: bash | 82 | shell: bash |
| @@ -132,6 +132,15 @@ jobs: | @@ -132,6 +132,15 @@ jobs: | ||
| 132 | name: ${{ matrix.os }}-libs | 132 | name: ${{ matrix.os }}-libs |
| 133 | path: to-upload/ | 133 | path: to-upload/ |
| 134 | 134 | ||
| 135 | + - name: Test speech enhancement (GTCRN) | ||
| 136 | + shell: bash | ||
| 137 | + run: | | ||
| 138 | + cd scripts/go/_internal/speech-enhancement-gtcrn/ | ||
| 139 | + | ||
| 140 | + ./run.sh | ||
| 141 | + | ||
| 142 | + ls -lh | ||
| 143 | + | ||
| 135 | - name: Test audio tagging | 144 | - name: Test audio tagging |
| 136 | shell: bash | 145 | shell: bash |
| 137 | run: | | 146 | run: | |
| 1 | +package main | ||
| 2 | + | ||
| 3 | +import ( | ||
| 4 | + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" | ||
| 5 | + "log" | ||
| 6 | +) | ||
| 7 | + | ||
| 8 | +func main() { | ||
| 9 | + log.SetFlags(log.LstdFlags | log.Lmicroseconds) | ||
| 10 | + | ||
| 11 | + config := sherpa.OfflineSpeechDenoiserConfig{} | ||
| 12 | + | ||
| 13 | + // Please download the models from | ||
| 14 | + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models | ||
| 15 | + | ||
| 16 | + config.Model.Gtcrn.Model = "./gtcrn_simple.onnx" | ||
| 17 | + config.Model.NumThreads = 1 | ||
| 18 | + config.Model.Debug = 1 | ||
| 19 | + | ||
| 20 | + sd := sherpa.NewOfflineSpeechDenoiser(&config) | ||
| 21 | + defer sherpa.DeleteOfflineSpeechDenoiser(sd) | ||
| 22 | + | ||
| 23 | + wave_filename := "./inp_16k.wav" | ||
| 24 | + | ||
| 25 | + wave := sherpa.ReadWave(wave_filename) | ||
| 26 | + if wave == nil { | ||
| 27 | + log.Printf("Failed to read %v\n", wave_filename) | ||
| 28 | + return | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + log.Println("Started") | ||
| 32 | + audio := sd.Run(wave.Samples, wave.SampleRate) | ||
| 33 | + log.Println("Done!") | ||
| 34 | + | ||
| 35 | + filename := "./enhanced-16k.wav" | ||
| 36 | + ok := audio.Save(filename) | ||
| 37 | + if !ok { | ||
| 38 | + log.Fatalf("Failed to write", filename) | ||
| 39 | + } else { | ||
| 40 | + log.Println("Saved to ", filename) | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | +} |
| 1 | +#!/usr/bin/env bash | ||
| 2 | +set -ex | ||
| 3 | + | ||
| 4 | +if [ ! -f ./gtcrn_simple.onnx ]; then | ||
| 5 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx | ||
| 6 | +fi | ||
| 7 | + | ||
| 8 | +if [ ! -f ./inp_16k.wav ]; then | ||
| 9 | + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav | ||
| 10 | +fi | ||
| 11 | + | ||
| 12 | +go mod tidy | ||
| 13 | +go build | ||
| 14 | + | ||
| 15 | +./speech-enhancement-gtcrn |
| 1 | +speech-enhancement-gtcrn |
| 1 | +../../../../go-api-examples/speech-enhancement-gtcrn/main.go |
| 1 | +../../../../go-api-examples/speech-enhancement-gtcrn/run.sh |
| @@ -959,7 +959,6 @@ func (tts *OfflineTts) Generate(text string, sid int, speed float32) *GeneratedA | @@ -959,7 +959,6 @@ func (tts *OfflineTts) Generate(text string, sid int, speed float32) *GeneratedA | ||
| 959 | // see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo | 959 | // see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo |
| 960 | // :n:n means 0:n:n, means low:high:capacity | 960 | // :n:n means 0:n:n, means low:high:capacity |
| 961 | samples := unsafe.Slice(audio.samples, n) | 961 | samples := unsafe.Slice(audio.samples, n) |
| 962 | - // copy(ans.Samples, samples) | ||
| 963 | for i := 0; i < n; i++ { | 962 | for i := 0; i < n; i++ { |
| 964 | ans.Samples[i] = float32(samples[i]) | 963 | ans.Samples[i] = float32(samples[i]) |
| 965 | } | 964 | } |
| @@ -1840,3 +1839,88 @@ func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent | @@ -1840,3 +1839,88 @@ func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent | ||
| 1840 | } | 1839 | } |
| 1841 | return result | 1840 | return result |
| 1842 | } | 1841 | } |
| 1842 | + | ||
| 1843 | +type OfflineSpeechDenoiserGtcrnModelConfig struct { | ||
| 1844 | + Model string | ||
| 1845 | +} | ||
| 1846 | + | ||
| 1847 | +type OfflineSpeechDenoiserModelConfig struct { | ||
| 1848 | + Gtcrn OfflineSpeechDenoiserGtcrnModelConfig | ||
| 1849 | + NumThreads int32 | ||
| 1850 | + Debug int32 | ||
| 1851 | + Provider string | ||
| 1852 | +} | ||
| 1853 | + | ||
| 1854 | +type OfflineSpeechDenoiserConfig struct { | ||
| 1855 | + Model OfflineSpeechDenoiserModelConfig | ||
| 1856 | +} | ||
| 1857 | + | ||
| 1858 | +type OfflineSpeechDenoiser struct { | ||
| 1859 | + impl *C.struct_SherpaOnnxOfflineSpeechDenoiser | ||
| 1860 | +} | ||
| 1861 | + | ||
| 1862 | +type DenoisedAudio struct { | ||
| 1863 | + // Normalized samples in the range [-1, 1] | ||
| 1864 | + Samples []float32 | ||
| 1865 | + | ||
| 1866 | + SampleRate int | ||
| 1867 | +} | ||
| 1868 | + | ||
| 1869 | +// Free the internal pointer inside the OfflineSpeechDenoiser to avoid memory leak. | ||
| 1870 | +func DeleteOfflineSpeechDenoiser(sd *OfflineSpeechDenoiser) { | ||
| 1871 | + C.SherpaOnnxDestroyOfflineSpeechDenoiser(sd.impl) | ||
| 1872 | + sd.impl = nil | ||
| 1873 | +} | ||
| 1874 | + | ||
| 1875 | +// The user is responsible to invoke [DeleteOfflineSpeechDenoiser]() to free | ||
| 1876 | +// the returned tts to avoid memory leak | ||
| 1877 | +func NewOfflineSpeechDenoiser(config *OfflineSpeechDenoiserConfig) *OfflineSpeechDenoiser { | ||
| 1878 | + c := C.struct_SherpaOnnxOfflineSpeechDenoiserConfig{} | ||
| 1879 | + c.model.gtcrn.model = C.CString(config.Model.Gtcrn.Model) | ||
| 1880 | + defer C.free(unsafe.Pointer(c.model.gtcrn.model)) | ||
| 1881 | + | ||
| 1882 | + c.model.num_threads = C.int(config.Model.NumThreads) | ||
| 1883 | + c.model.debug = C.int(config.Model.Debug) | ||
| 1884 | + | ||
| 1885 | + c.model.provider = C.CString(config.Model.Provider) | ||
| 1886 | + defer C.free(unsafe.Pointer(c.model.provider)) | ||
| 1887 | + | ||
| 1888 | + impl := C.SherpaOnnxCreateOfflineSpeechDenoiser(&c) | ||
| 1889 | + if impl == nil { | ||
| 1890 | + return nil | ||
| 1891 | + } | ||
| 1892 | + | ||
| 1893 | + sd := &OfflineSpeechDenoiser{} | ||
| 1894 | + sd.impl = impl | ||
| 1895 | + return sd | ||
| 1896 | +} | ||
| 1897 | + | ||
| 1898 | +func (sd *OfflineSpeechDenoiser) Run(samples []float32, sampleRate int) *DenoisedAudio { | ||
| 1899 | + audio := C.SherpaOnnxOfflineSpeechDenoiserRun(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)), C.int(sampleRate)) | ||
| 1900 | + defer C.SherpaOnnxDestroyDenoisedAudio(audio) | ||
| 1901 | + | ||
| 1902 | + ans := &DenoisedAudio{} | ||
| 1903 | + ans.SampleRate = int(audio.sample_rate) | ||
| 1904 | + n := int(audio.n) | ||
| 1905 | + ans.Samples = make([]float32, n) | ||
| 1906 | + | ||
| 1907 | + denoisedSamples := unsafe.Slice(audio.samples, n) | ||
| 1908 | + for i := 0; i < n; i++ { | ||
| 1909 | + ans.Samples[i] = float32(denoisedSamples[i]) | ||
| 1910 | + } | ||
| 1911 | + | ||
| 1912 | + return ans | ||
| 1913 | +} | ||
| 1914 | + | ||
| 1915 | +func (audio *DenoisedAudio) Save(filename string) bool { | ||
| 1916 | + s := C.CString(filename) | ||
| 1917 | + defer C.free(unsafe.Pointer(s)) | ||
| 1918 | + | ||
| 1919 | + ok := int(C.SherpaOnnxWriteWave((*C.float)(&audio.Samples[0]), C.int(len(audio.Samples)), C.int(audio.SampleRate), s)) | ||
| 1920 | + | ||
| 1921 | + return ok == 1 | ||
| 1922 | +} | ||
| 1923 | + | ||
| 1924 | +func (sd *OfflineSpeechDenoiser) SampleRate() int { | ||
| 1925 | + return int(C.SherpaOnnxOfflineSpeechDenoiserGetSampleRate(sd.impl)) | ||
| 1926 | +} |
-
请 注册 或 登录 后发表评论