正在显示 39 个修改的文件,包含 1652 行增加和 108 行删除
.github/scripts/test-speaker-diarization.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -ex | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 | ||
| 17 | +tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 | ||
| 18 | +rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 | ||
| 19 | + | ||
| 20 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx | ||
| 21 | + | ||
| 22 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav | ||
| 23 | + | ||
| 24 | +log "specify number of clusters" | ||
| 25 | +$EXE \ | ||
| 26 | + --clustering.num-clusters=4 \ | ||
| 27 | + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ | ||
| 28 | + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ | ||
| 29 | + ./0-four-speakers-zh.wav | ||
| 30 | + | ||
| 31 | +log "specify threshold for clustering" | ||
| 32 | + | ||
| 33 | +$EXE \ | ||
| 34 | + --clustering.cluster-threshold=0.90 \ | ||
| 35 | + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ | ||
| 36 | + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ | ||
| 37 | + ./0-four-speakers-zh.wav | ||
| 38 | + | ||
| 39 | +rm -rf sherpa-onnx-pyannote-* | ||
| 40 | +rm -fv *.onnx | ||
| 41 | +rm -fv *.wav |
| @@ -29,7 +29,7 @@ jobs: | @@ -29,7 +29,7 @@ jobs: | ||
| 29 | - name: Install pyannote | 29 | - name: Install pyannote |
| 30 | shell: bash | 30 | shell: bash |
| 31 | run: | | 31 | run: | |
| 32 | - pip install pyannote.audio onnx onnxruntime | 32 | + pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3 |
| 33 | 33 | ||
| 34 | - name: Run | 34 | - name: Run |
| 35 | shell: bash | 35 | shell: bash |
| @@ -18,6 +18,7 @@ on: | @@ -18,6 +18,7 @@ on: | ||
| 18 | - '.github/scripts/test-audio-tagging.sh' | 18 | - '.github/scripts/test-audio-tagging.sh' |
| 19 | - '.github/scripts/test-offline-punctuation.sh' | 19 | - '.github/scripts/test-offline-punctuation.sh' |
| 20 | - '.github/scripts/test-online-punctuation.sh' | 20 | - '.github/scripts/test-online-punctuation.sh' |
| 21 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 21 | - 'CMakeLists.txt' | 22 | - 'CMakeLists.txt' |
| 22 | - 'cmake/**' | 23 | - 'cmake/**' |
| 23 | - 'sherpa-onnx/csrc/*' | 24 | - 'sherpa-onnx/csrc/*' |
| @@ -38,6 +39,7 @@ on: | @@ -38,6 +39,7 @@ on: | ||
| 38 | - '.github/scripts/test-audio-tagging.sh' | 39 | - '.github/scripts/test-audio-tagging.sh' |
| 39 | - '.github/scripts/test-offline-punctuation.sh' | 40 | - '.github/scripts/test-offline-punctuation.sh' |
| 40 | - '.github/scripts/test-online-punctuation.sh' | 41 | - '.github/scripts/test-online-punctuation.sh' |
| 42 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 41 | - 'CMakeLists.txt' | 43 | - 'CMakeLists.txt' |
| 42 | - 'cmake/**' | 44 | - 'cmake/**' |
| 43 | - 'sherpa-onnx/csrc/*' | 45 | - 'sherpa-onnx/csrc/*' |
| @@ -143,6 +145,15 @@ jobs: | @@ -143,6 +145,15 @@ jobs: | ||
| 143 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | 145 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} |
| 144 | path: install/* | 146 | path: install/* |
| 145 | 147 | ||
| 148 | + - name: Test offline speaker diarization | ||
| 149 | + shell: bash | ||
| 150 | + run: | | ||
| 151 | + du -h -d1 . | ||
| 152 | + export PATH=$PWD/build/bin:$PATH | ||
| 153 | + export EXE=sherpa-onnx-offline-speaker-diarization | ||
| 154 | + | ||
| 155 | + .github/scripts/test-speaker-diarization.sh | ||
| 156 | + | ||
| 146 | - name: Test offline transducer | 157 | - name: Test offline transducer |
| 147 | shell: bash | 158 | shell: bash |
| 148 | run: | | 159 | run: | |
| @@ -18,6 +18,7 @@ on: | @@ -18,6 +18,7 @@ on: | ||
| 18 | - '.github/scripts/test-audio-tagging.sh' | 18 | - '.github/scripts/test-audio-tagging.sh' |
| 19 | - '.github/scripts/test-offline-punctuation.sh' | 19 | - '.github/scripts/test-offline-punctuation.sh' |
| 20 | - '.github/scripts/test-online-punctuation.sh' | 20 | - '.github/scripts/test-online-punctuation.sh' |
| 21 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 21 | - 'CMakeLists.txt' | 22 | - 'CMakeLists.txt' |
| 22 | - 'cmake/**' | 23 | - 'cmake/**' |
| 23 | - 'sherpa-onnx/csrc/*' | 24 | - 'sherpa-onnx/csrc/*' |
| @@ -37,6 +38,7 @@ on: | @@ -37,6 +38,7 @@ on: | ||
| 37 | - '.github/scripts/test-audio-tagging.sh' | 38 | - '.github/scripts/test-audio-tagging.sh' |
| 38 | - '.github/scripts/test-offline-punctuation.sh' | 39 | - '.github/scripts/test-offline-punctuation.sh' |
| 39 | - '.github/scripts/test-online-punctuation.sh' | 40 | - '.github/scripts/test-online-punctuation.sh' |
| 41 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 40 | - 'CMakeLists.txt' | 42 | - 'CMakeLists.txt' |
| 41 | - 'cmake/**' | 43 | - 'cmake/**' |
| 42 | - 'sherpa-onnx/csrc/*' | 44 | - 'sherpa-onnx/csrc/*' |
| @@ -115,6 +117,15 @@ jobs: | @@ -115,6 +117,15 @@ jobs: | ||
| 115 | otool -L build/bin/sherpa-onnx | 117 | otool -L build/bin/sherpa-onnx |
| 116 | otool -l build/bin/sherpa-onnx | 118 | otool -l build/bin/sherpa-onnx |
| 117 | 119 | ||
| 120 | + - name: Test offline speaker diarization | ||
| 121 | + shell: bash | ||
| 122 | + run: | | ||
| 123 | + du -h -d1 . | ||
| 124 | + export PATH=$PWD/build/bin:$PATH | ||
| 125 | + export EXE=sherpa-onnx-offline-speaker-diarization | ||
| 126 | + | ||
| 127 | + .github/scripts/test-speaker-diarization.sh | ||
| 128 | + | ||
| 118 | - name: Test offline transducer | 129 | - name: Test offline transducer |
| 119 | shell: bash | 130 | shell: bash |
| 120 | run: | | 131 | run: | |
| @@ -67,7 +67,7 @@ jobs: | @@ -67,7 +67,7 @@ jobs: | ||
| 67 | curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin | 67 | curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin |
| 68 | 68 | ||
| 69 | test_wavs=( | 69 | test_wavs=( |
| 70 | - 0-two-speakers-zh.wav | 70 | + 0-four-speakers-zh.wav |
| 71 | 1-two-speakers-en.wav | 71 | 1-two-speakers-en.wav |
| 72 | 2-two-speakers-en.wav | 72 | 2-two-speakers-en.wav |
| 73 | 3-two-speakers-en.wav | 73 | 3-two-speakers-en.wav |
| @@ -17,6 +17,7 @@ on: | @@ -17,6 +17,7 @@ on: | ||
| 17 | - '.github/scripts/test-audio-tagging.sh' | 17 | - '.github/scripts/test-audio-tagging.sh' |
| 18 | - '.github/scripts/test-offline-punctuation.sh' | 18 | - '.github/scripts/test-offline-punctuation.sh' |
| 19 | - '.github/scripts/test-online-punctuation.sh' | 19 | - '.github/scripts/test-online-punctuation.sh' |
| 20 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 20 | - 'CMakeLists.txt' | 21 | - 'CMakeLists.txt' |
| 21 | - 'cmake/**' | 22 | - 'cmake/**' |
| 22 | - 'sherpa-onnx/csrc/*' | 23 | - 'sherpa-onnx/csrc/*' |
| @@ -34,6 +35,7 @@ on: | @@ -34,6 +35,7 @@ on: | ||
| 34 | - '.github/scripts/test-audio-tagging.sh' | 35 | - '.github/scripts/test-audio-tagging.sh' |
| 35 | - '.github/scripts/test-offline-punctuation.sh' | 36 | - '.github/scripts/test-offline-punctuation.sh' |
| 36 | - '.github/scripts/test-online-punctuation.sh' | 37 | - '.github/scripts/test-online-punctuation.sh' |
| 38 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 37 | - 'CMakeLists.txt' | 39 | - 'CMakeLists.txt' |
| 38 | - 'cmake/**' | 40 | - 'cmake/**' |
| 39 | - 'sherpa-onnx/csrc/*' | 41 | - 'sherpa-onnx/csrc/*' |
| @@ -87,6 +89,15 @@ jobs: | @@ -87,6 +89,15 @@ jobs: | ||
| 87 | name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} | 89 | name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} |
| 88 | path: build/install/* | 90 | path: build/install/* |
| 89 | 91 | ||
| 92 | + - name: Test offline speaker diarization | ||
| 93 | + shell: bash | ||
| 94 | + run: | | ||
| 95 | + du -h -d1 . | ||
| 96 | + export PATH=$PWD/build/bin:$PATH | ||
| 97 | + export EXE=sherpa-onnx-offline-speaker-diarization.exe | ||
| 98 | + | ||
| 99 | + .github/scripts/test-speaker-diarization.sh | ||
| 100 | + | ||
| 90 | - name: Test online punctuation | 101 | - name: Test online punctuation |
| 91 | shell: bash | 102 | shell: bash |
| 92 | run: | | 103 | run: | |
| @@ -17,6 +17,7 @@ on: | @@ -17,6 +17,7 @@ on: | ||
| 17 | - '.github/scripts/test-audio-tagging.sh' | 17 | - '.github/scripts/test-audio-tagging.sh' |
| 18 | - '.github/scripts/test-offline-punctuation.sh' | 18 | - '.github/scripts/test-offline-punctuation.sh' |
| 19 | - '.github/scripts/test-online-punctuation.sh' | 19 | - '.github/scripts/test-online-punctuation.sh' |
| 20 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 20 | - 'CMakeLists.txt' | 21 | - 'CMakeLists.txt' |
| 21 | - 'cmake/**' | 22 | - 'cmake/**' |
| 22 | - 'sherpa-onnx/csrc/*' | 23 | - 'sherpa-onnx/csrc/*' |
| @@ -34,6 +35,7 @@ on: | @@ -34,6 +35,7 @@ on: | ||
| 34 | - '.github/scripts/test-audio-tagging.sh' | 35 | - '.github/scripts/test-audio-tagging.sh' |
| 35 | - '.github/scripts/test-offline-punctuation.sh' | 36 | - '.github/scripts/test-offline-punctuation.sh' |
| 36 | - '.github/scripts/test-online-punctuation.sh' | 37 | - '.github/scripts/test-online-punctuation.sh' |
| 38 | + - '.github/scripts/test-speaker-diarization.sh' | ||
| 37 | - 'CMakeLists.txt' | 39 | - 'CMakeLists.txt' |
| 38 | - 'cmake/**' | 40 | - 'cmake/**' |
| 39 | - 'sherpa-onnx/csrc/*' | 41 | - 'sherpa-onnx/csrc/*' |
| @@ -87,6 +89,15 @@ jobs: | @@ -87,6 +89,15 @@ jobs: | ||
| 87 | name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} | 89 | name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} |
| 88 | path: build/install/* | 90 | path: build/install/* |
| 89 | 91 | ||
| 92 | + - name: Test offline speaker diarization | ||
| 93 | + shell: bash | ||
| 94 | + run: | | ||
| 95 | + du -h -d1 . | ||
| 96 | + export PATH=$PWD/build/bin:$PATH | ||
| 97 | + export EXE=sherpa-onnx-offline-speaker-diarization.exe | ||
| 98 | + | ||
| 99 | + .github/scripts/test-speaker-diarization.sh | ||
| 100 | + | ||
| 90 | - name: Test online punctuation | 101 | - name: Test online punctuation |
| 91 | shell: bash | 102 | shell: bash |
| 92 | run: | | 103 | run: | |
| @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 36 | fprintf(stderr, "Memory error\n"); | 36 | fprintf(stderr, "Memory error\n"); |
| 37 | return -1; | 37 | return -1; |
| 38 | } | 38 | } |
| 39 | - size_t read_bytes = fread(*buffer_out, 1, size, file); | 39 | + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); |
| 40 | if (read_bytes != size) { | 40 | if (read_bytes != size) { |
| 41 | printf("Errors occured in reading the file %s\n", filename); | 41 | printf("Errors occured in reading the file %s\n", filename); |
| 42 | free((void *)*buffer_out); | 42 | free((void *)*buffer_out); |
| @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 36 | fprintf(stderr, "Memory error\n"); | 36 | fprintf(stderr, "Memory error\n"); |
| 37 | return -1; | 37 | return -1; |
| 38 | } | 38 | } |
| 39 | - size_t read_bytes = fread(*buffer_out, 1, size, file); | 39 | + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); |
| 40 | if (read_bytes != size) { | 40 | if (read_bytes != size) { |
| 41 | printf("Errors occured in reading the file %s\n", filename); | 41 | printf("Errors occured in reading the file %s\n", filename); |
| 42 | free((void *)*buffer_out); | 42 | free((void *)*buffer_out); |
| @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 36 | fprintf(stderr, "Memory error\n"); | 36 | fprintf(stderr, "Memory error\n"); |
| 37 | return -1; | 37 | return -1; |
| 38 | } | 38 | } |
| 39 | - size_t read_bytes = fread(*buffer_out, 1, size, file); | 39 | + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); |
| 40 | if (read_bytes != size) { | 40 | if (read_bytes != size) { |
| 41 | printf("Errors occured in reading the file %s\n", filename); | 41 | printf("Errors occured in reading the file %s\n", filename); |
| 42 | free((void *)*buffer_out); | 42 | free((void *)*buffer_out); |
| @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { | ||
| 36 | fprintf(stderr, "Memory error\n"); | 36 | fprintf(stderr, "Memory error\n"); |
| 37 | return -1; | 37 | return -1; |
| 38 | } | 38 | } |
| 39 | - size_t read_bytes = fread(*buffer_out, 1, size, file); | 39 | + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); |
| 40 | if (read_bytes != size) { | 40 | if (read_bytes != size) { |
| 41 | printf("Errors occured in reading the file %s\n", filename); | 41 | printf("Errors occured in reading the file %s\n", filename); |
| 42 | free((void *)*buffer_out); | 42 | free((void *)*buffer_out); |
| @@ -55,6 +55,7 @@ def get_binaries(): | @@ -55,6 +55,7 @@ def get_binaries(): | ||
| 55 | "sherpa-onnx-offline-audio-tagging", | 55 | "sherpa-onnx-offline-audio-tagging", |
| 56 | "sherpa-onnx-offline-language-identification", | 56 | "sherpa-onnx-offline-language-identification", |
| 57 | "sherpa-onnx-offline-punctuation", | 57 | "sherpa-onnx-offline-punctuation", |
| 58 | + "sherpa-onnx-offline-speaker-diarization", | ||
| 58 | "sherpa-onnx-offline-tts", | 59 | "sherpa-onnx-offline-tts", |
| 59 | "sherpa-onnx-offline-tts-play", | 60 | "sherpa-onnx-offline-tts-play", |
| 60 | "sherpa-onnx-offline-websocket-server", | 61 | "sherpa-onnx-offline-websocket-server", |
| @@ -3,12 +3,9 @@ | @@ -3,12 +3,9 @@ | ||
| 3 | Please download test wave files from | 3 | Please download test wave files from |
| 4 | https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models | 4 | https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models |
| 5 | 5 | ||
| 6 | -## 0-two-speakers-zh.wav | 6 | +## 0-four-speakers-zh.wav |
| 7 | 7 | ||
| 8 | -This file is from | ||
| 9 | -https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0 | ||
| 10 | - | ||
| 11 | -Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`. | 8 | +It is recorded by @csukuangfj |
| 12 | 9 | ||
| 13 | ## 1-two-speakers-en.wav | 10 | ## 1-two-speakers-en.wav |
| 14 | 11 | ||
| @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav` | @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav` | ||
| 40 | 37 | ||
| 41 | 38 | ||
| 42 | ```bash | 39 | ```bash |
| 43 | -sox ML16091-Audio.mp3 3-two-speakers-en.wav | 40 | +sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav |
| 44 | ``` | 41 | ``` |
| @@ -72,7 +72,7 @@ def main(): | @@ -72,7 +72,7 @@ def main(): | ||
| 72 | model.receptive_field.duration * 16000 | 72 | model.receptive_field.duration * 16000 |
| 73 | ) | 73 | ) |
| 74 | 74 | ||
| 75 | - opset_version = 18 | 75 | + opset_version = 13 |
| 76 | 76 | ||
| 77 | filename = "model.onnx" | 77 | filename = "model.onnx" |
| 78 | torch.onnx.export( | 78 | torch.onnx.export( |
| @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | ||
| 164 | list(APPEND sources | 164 | list(APPEND sources |
| 165 | fast-clustering-config.cc | 165 | fast-clustering-config.cc |
| 166 | fast-clustering.cc | 166 | fast-clustering.cc |
| 167 | + offline-speaker-diarization-impl.cc | ||
| 168 | + offline-speaker-diarization-result.cc | ||
| 169 | + offline-speaker-diarization.cc | ||
| 170 | + offline-speaker-segmentation-model-config.cc | ||
| 171 | + offline-speaker-segmentation-pyannote-model-config.cc | ||
| 172 | + offline-speaker-segmentation-pyannote-model.cc | ||
| 167 | ) | 173 | ) |
| 168 | endif() | 174 | endif() |
| 169 | 175 | ||
| @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 260 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) | 266 | add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) |
| 261 | endif() | 267 | endif() |
| 262 | 268 | ||
| 269 | + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | ||
| 270 | + add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc) | ||
| 271 | + endif() | ||
| 272 | + | ||
| 263 | set(main_exes | 273 | set(main_exes |
| 264 | sherpa-onnx | 274 | sherpa-onnx |
| 265 | sherpa-onnx-keyword-spotter | 275 | sherpa-onnx-keyword-spotter |
| @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY) | @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY) | ||
| 276 | ) | 286 | ) |
| 277 | endif() | 287 | endif() |
| 278 | 288 | ||
| 289 | + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) | ||
| 290 | + list(APPEND main_exes | ||
| 291 | + sherpa-onnx-offline-speaker-diarization | ||
| 292 | + ) | ||
| 293 | + endif() | ||
| 294 | + | ||
| 279 | foreach(exe IN LISTS main_exes) | 295 | foreach(exe IN LISTS main_exes) |
| 280 | target_link_libraries(${exe} sherpa-onnx-core) | 296 | target_link_libraries(${exe} sherpa-onnx-core) |
| 281 | endforeach() | 297 | endforeach() |
| @@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const { | @@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const { | ||
| 21 | } | 21 | } |
| 22 | 22 | ||
| 23 | void FastClusteringConfig::Register(ParseOptions *po) { | 23 | void FastClusteringConfig::Register(ParseOptions *po) { |
| 24 | - std::string prefix = "ctc"; | ||
| 25 | - ParseOptions p(prefix, po); | ||
| 26 | - | ||
| 27 | - p.Register("num-clusters", &num_clusters, | ||
| 28 | - "Number of cluster. If greater than 0, then --cluster-thresold is " | ||
| 29 | - "ignored. Please provide it if you know the actual number of " | ||
| 30 | - "clusters in advance."); | ||
| 31 | - | ||
| 32 | - p.Register("cluster-threshold", &threshold, | ||
| 33 | - "If --num-clusters is not specified, then it specifies the " | ||
| 34 | - "distance threshold for clustering. smaller value -> more " | ||
| 35 | - "clusters. larger value -> fewer clusters"); | 24 | + po->Register( |
| 25 | + "num-clusters", &num_clusters, | ||
| 26 | + "Number of cluster. If greater than 0, then cluster threshold is " | ||
| 27 | + "ignored. Please provide it if you know the actual number of " | ||
| 28 | + "clusters in advance."); | ||
| 29 | + | ||
| 30 | + po->Register("cluster-threshold", &threshold, | ||
| 31 | + "If num_clusters is not specified, then it specifies the " | ||
| 32 | + "distance threshold for clustering. smaller value -> more " | ||
| 33 | + "clusters. larger value -> fewer clusters"); | ||
| 36 | } | 34 | } |
| 37 | 35 | ||
| 38 | bool FastClusteringConfig::Validate() const { | 36 | bool FastClusteringConfig::Validate() const { |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_MACROS_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_MACROS_H_ |
| 6 | #define SHERPA_ONNX_CSRC_MACROS_H_ | 6 | #define SHERPA_ONNX_CSRC_MACROS_H_ |
| 7 | #include <stdio.h> | 7 | #include <stdio.h> |
| 8 | +#include <stdlib.h> | ||
| 8 | 9 | ||
| 9 | #if __ANDROID_API__ >= 8 | 10 | #if __ANDROID_API__ >= 8 |
| 10 | #include "android/log.h" | 11 | #include "android/log.h" |
| @@ -169,4 +170,6 @@ | @@ -169,4 +170,6 @@ | ||
| 169 | } \ | 170 | } \ |
| 170 | } while (0) | 171 | } while (0) |
| 171 | 172 | ||
| 173 | +#define SHERPA_ONNX_EXIT(code) exit(code) | ||
| 174 | + | ||
| 172 | #endif // SHERPA_ONNX_CSRC_MACROS_H_ | 175 | #endif // SHERPA_ONNX_CSRC_MACROS_H_ |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <utility> | 9 | #include <utility> |
| 10 | 10 | ||
| 11 | #include "sherpa-onnx/csrc/macros.h" | 11 | #include "sherpa-onnx/csrc/macros.h" |
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | #include "sherpa-onnx/csrc/session.h" | 13 | #include "sherpa-onnx/csrc/session.h" |
| 13 | #include "sherpa-onnx/csrc/text-utils.h" | 14 | #include "sherpa-onnx/csrc/text-utils.h" |
| 14 | 15 |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | +#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +std::unique_ptr<OfflineSpeakerDiarizationImpl> | ||
| 15 | +OfflineSpeakerDiarizationImpl::Create( | ||
| 16 | + const OfflineSpeakerDiarizationConfig &config) { | ||
| 17 | + if (!config.segmentation.pyannote.model.empty()) { | ||
| 18 | + return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config); | ||
| 19 | + } | ||
| 20 | + | ||
| 21 | + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model."); | ||
| 22 | + | ||
| 23 | + return nullptr; | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <functional> | ||
| 9 | +#include <memory> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OfflineSpeakerDiarizationImpl { | ||
| 15 | + public: | ||
| 16 | + static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create( | ||
| 17 | + const OfflineSpeakerDiarizationConfig &config); | ||
| 18 | + | ||
| 19 | + virtual ~OfflineSpeakerDiarizationImpl() = default; | ||
| 20 | + | ||
| 21 | + virtual int32_t SampleRate() const = 0; | ||
| 22 | + | ||
| 23 | + virtual OfflineSpeakerDiarizationResult Process( | ||
| 24 | + const float *audio, int32_t n, | ||
| 25 | + OfflineSpeakerDiarizationProgressCallback callback = nullptr, | ||
| 26 | + void *callback_arg = nullptr) const = 0; | ||
| 27 | +}; | ||
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <unordered_map> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "Eigen/Dense" | ||
| 13 | +#include "sherpa-onnx/csrc/fast-clustering.h" | ||
| 14 | +#include "sherpa-onnx/csrc/math.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" | ||
| 16 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" | ||
| 17 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +namespace { // NOLINT | ||
| 22 | + | ||
| 23 | +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41 | ||
| 24 | +template <class T> | ||
| 25 | +inline void hash_combine(std::size_t *seed, const T &v) { // NOLINT | ||
| 26 | + std::hash<T> hasher; | ||
| 27 | + *seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); // NOLINT | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47 | ||
| 31 | +struct PairHash { | ||
| 32 | + template <class T1, class T2> | ||
| 33 | + std::size_t operator()(const std::pair<T1, T2> &pair) const { | ||
| 34 | + std::size_t result = 0; | ||
| 35 | + hash_combine(&result, pair.first); | ||
| 36 | + hash_combine(&result, pair.second); | ||
| 37 | + return result; | ||
| 38 | + } | ||
| 39 | +}; | ||
| 40 | +} // namespace | ||
| 41 | + | ||
| 42 | +using Matrix2D = | ||
| 43 | + Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; | ||
| 44 | + | ||
| 45 | +using Matrix2DInt32 = | ||
| 46 | + Eigen::Matrix<int32_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; | ||
| 47 | + | ||
| 48 | +using FloatRowVector = Eigen::Matrix<float, 1, Eigen::Dynamic>; | ||
| 49 | +using Int32RowVector = Eigen::Matrix<int32_t, 1, Eigen::Dynamic>; | ||
| 50 | + | ||
| 51 | +using Int32Pair = std::pair<int32_t, int32_t>; | ||
| 52 | + | ||
| 53 | +class OfflineSpeakerDiarizationPyannoteImpl | ||
| 54 | + : public OfflineSpeakerDiarizationImpl { | ||
| 55 | + public: | ||
| 56 | + ~OfflineSpeakerDiarizationPyannoteImpl() override = default; | ||
| 57 | + | ||
| 58 | + explicit OfflineSpeakerDiarizationPyannoteImpl( | ||
| 59 | + const OfflineSpeakerDiarizationConfig &config) | ||
| 60 | + : config_(config), | ||
| 61 | + segmentation_model_(config_.segmentation), | ||
| 62 | + embedding_extractor_(config_.embedding), | ||
| 63 | + clustering_(config_.clustering) { | ||
| 64 | + Init(); | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + int32_t SampleRate() const override { | ||
| 68 | + const auto &meta_data = segmentation_model_.GetModelMetaData(); | ||
| 69 | + | ||
| 70 | + return meta_data.sample_rate; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + OfflineSpeakerDiarizationResult Process( | ||
| 74 | + const float *audio, int32_t n, | ||
| 75 | + OfflineSpeakerDiarizationProgressCallback callback = nullptr, | ||
| 76 | + void *callback_arg = nullptr) const override { | ||
| 77 | + std::vector<Matrix2D> segmentations = RunSpeakerSegmentationModel(audio, n); | ||
| 78 | + // segmentations[i] is for chunk_i | ||
| 79 | + // Each matrix is of shape (num_frames, num_powerset_classes) | ||
| 80 | + if (segmentations.empty()) { | ||
| 81 | + return {}; | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + std::vector<Matrix2DInt32> labels; | ||
| 85 | + labels.reserve(segmentations.size()); | ||
| 86 | + | ||
| 87 | + for (const auto &m : segmentations) { | ||
| 88 | + labels.push_back(ToMultiLabel(m)); | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + segmentations.clear(); | ||
| 92 | + | ||
| 93 | + // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) | ||
| 94 | + | ||
| 95 | + // speaker count per frame | ||
| 96 | + Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels); | ||
| 97 | + | ||
| 98 | + if (speakers_per_frame.maxCoeff() == 0) { | ||
| 99 | + SHERPA_ONNX_LOGE("No speakers found in the audio samples"); | ||
| 100 | + return {}; | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); | ||
| 104 | + Matrix2D embeddings = | ||
| 105 | + ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, | ||
| 106 | + callback, callback_arg); | ||
| 107 | + | ||
| 108 | + std::vector<int32_t> cluster_labels = clustering_.Cluster( | ||
| 109 | + &embeddings(0, 0), embeddings.rows(), embeddings.cols()); | ||
| 110 | + | ||
| 111 | + int32_t max_cluster_index = | ||
| 112 | + *std::max_element(cluster_labels.begin(), cluster_labels.end()); | ||
| 113 | + | ||
| 114 | + auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster( | ||
| 115 | + chunk_speaker_samples_list_pair.first, cluster_labels); | ||
| 116 | + | ||
| 117 | + auto new_labels = | ||
| 118 | + ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster); | ||
| 119 | + | ||
| 120 | + Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n); | ||
| 121 | + | ||
| 122 | + Matrix2DInt32 final_labels = | ||
| 123 | + FinalizeLabels(speaker_count, speakers_per_frame); | ||
| 124 | + | ||
| 125 | + auto result = ComputeResult(final_labels); | ||
| 126 | + | ||
| 127 | + return result; | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + private: | ||
| 131 | + void Init() { InitPowersetMapping(); } | ||
| 132 | + | ||
| 133 | + // see also | ||
| 134 | + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68 | ||
| 135 | + void InitPowersetMapping() { | ||
| 136 | + const auto &meta_data = segmentation_model_.GetModelMetaData(); | ||
| 137 | + int32_t num_classes = meta_data.num_classes; | ||
| 138 | + int32_t powerset_max_classes = meta_data.powerset_max_classes; | ||
| 139 | + int32_t num_speakers = meta_data.num_speakers; | ||
| 140 | + | ||
| 141 | + powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers); | ||
| 142 | + powerset_mapping_.setZero(); | ||
| 143 | + | ||
| 144 | + int32_t k = 1; | ||
| 145 | + for (int32_t i = 1; i <= powerset_max_classes; ++i) { | ||
| 146 | + if (i == 1) { | ||
| 147 | + for (int32_t j = 0; j != num_speakers; ++j, ++k) { | ||
| 148 | + powerset_mapping_(k, j) = 1; | ||
| 149 | + } | ||
| 150 | + } else if (i == 2) { | ||
| 151 | + for (int32_t j = 0; j != num_speakers; ++j) { | ||
| 152 | + for (int32_t m = j + 1; m < num_speakers; ++m, ++k) { | ||
| 153 | + powerset_mapping_(k, j) = 1; | ||
| 154 | + powerset_mapping_(k, m) = 1; | ||
| 155 | + } | ||
| 156 | + } | ||
| 157 | + } else { | ||
| 158 | + SHERPA_ONNX_LOGE( | ||
| 159 | + "powerset_max_classes = %d is currently not supported!", i); | ||
| 160 | + SHERPA_ONNX_EXIT(-1); | ||
| 161 | + } | ||
| 162 | + } | ||
| 163 | + } | ||
| 164 | + | ||
  // Run the segmentation model over the whole audio by sliding a fixed-size
  // window (window_size samples) with a stride of window_shift samples.
  //
  // @param audio Pointer to n mono float samples.
  // @param n Number of samples in audio.
  // @return One matrix per window/chunk, each being the raw model output for
  //         that chunk. Returns an empty vector if n <= 0.
  std::vector<Matrix2D> RunSpeakerSegmentationModel(const float *audio,
                                                    int32_t n) const {
    std::vector<Matrix2D> ans;

    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;

    if (n <= 0) {
      SHERPA_ONNX_LOGE(
          "number of audio samples is %d (<= 0). Please provide a positive "
          "number",
          n);
      return {};
    }

    if (n <= window_size) {
      // Audio is shorter than one window: zero-pad it to window_size and
      // process it as a single chunk.
      std::vector<float> buf(window_size);
      // NOTE: buf is zero initialized by default

      std::copy(audio, audio + n, buf.data());

      Matrix2D m = ProcessChunk(buf.data());

      ans.push_back(std::move(m));

      return ans;
    }

    int32_t num_chunks = (n - window_size) / window_shift + 1;
    // Samples remaining after the last full window form a final,
    // zero-padded chunk.
    bool has_last_chunk = (n - window_size) % window_shift > 0;

    ans.reserve(num_chunks + has_last_chunk);

    const float *p = audio;

    for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) {
      Matrix2D m = ProcessChunk(p);

      ans.push_back(std::move(m));
    }

    if (has_last_chunk) {
      // Zero-pad the tail so it also spans exactly window_size samples.
      std::vector<float> buf(window_size);
      std::copy(p, audio + n, buf.data());

      Matrix2D m = ProcessChunk(buf.data());

      ans.push_back(std::move(m));
    }

    return ans;
  }
| 218 | + | ||
  // Run the segmentation model on a single chunk of exactly window_size
  // samples.
  //
  // @param p Pointer to window_size float samples. The data is not copied;
  //          the Ort tensor merely wraps it, so p must stay valid until the
  //          forward pass returns.
  // @return The model output for this chunk, reshaped to
  //         (out_shape[1], out_shape[2]).
  Matrix2D ProcessChunk(const float *p) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // Input tensor of shape (batch=1, channel=1, window_size). CreateTensor
    // does not take ownership, hence the const_cast on the borrowed buffer.
    std::array<int64_t, 3> shape = {1, 1, window_size};

    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, const_cast<float *>(p),
                                 window_size, shape.data(), shape.size());

    Ort::Value out = segmentation_model_.Forward(std::move(x));
    std::vector<int64_t> out_shape = out.GetTensorTypeAndShapeInfo().GetShape();
    Matrix2D m(out_shape[1], out_shape[2]);
    std::copy(out.GetTensorData<float>(), out.GetTensorData<float>() + m.size(),
              &m(0, 0));
    return m;
  }
| 239 | + | ||
  // Convert per-frame powerset scores into a multi-label matrix.
  //
  // @param m Model output of shape (num_frames, num_powerset_classes).
  // @return Matrix of shape (num_frames, powerset_mapping_.cols()) where the
  //         row for frame i is the powerset_mapping_ row of the argmax class
  //         of that frame.
  Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const {
    int32_t num_rows = m.rows();
    Matrix2DInt32 ans(num_rows, powerset_mapping_.cols());

    // Written by maxCoeff() below before it is ever read.
    std::ptrdiff_t col_id;

    for (int32_t i = 0; i != num_rows; ++i) {
      // argmax over the powerset classes of frame i
      m.row(i).maxCoeff(&col_id);
      ans.row(i) = powerset_mapping_.row(col_id);
    }

    return ans;
  }
| 253 | + | ||
  // See also
  // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122
  //
  // Estimate how many speakers are active at each output frame by averaging
  // the per-chunk active-speaker counts over all chunks covering the frame.
  //
  // @param labels labels[i] is the multi-label matrix of chunk i with shape
  //               (num_frames_per_chunk, num_speakers).
  // @return A row vector with one rounded speaker count per frame.
  Int32RowVector ComputeSpeakersPerFrame(
      const std::vector<Matrix2DInt32> &labels) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;
    int32_t receptive_field_shift = meta_data.receptive_field_shift;

    int32_t num_chunks = labels.size();

    // Total number of output frames spanned by all chunks.
    int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
                             receptive_field_shift +
                         1;

    FloatRowVector count(num_frames);
    FloatRowVector weight(num_frames);
    count.setZero();
    weight.setZero();

    for (int32_t i = 0; i != num_chunks; ++i) {
      // First output frame of chunk i (rounded to nearest integer).
      int32_t start =
          static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;

      auto seq = Eigen::seqN(start, labels[i].rows());

      // Sum of active speakers per frame contributed by this chunk.
      count(seq).array() += labels[i].rowwise().sum().array().cast<float>();

      // How many chunks cover each frame.
      weight(seq).array() += 1;
    }

    // Average over the covering chunks and round to nearest integer; the
    // 1e-12 term guards against division by zero on frames no chunk covered.
    return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast<int32_t>();
  }
| 287 | + | ||
| 288 | + // ans.first: a list of (chunk_id, speaker_id) | ||
| 289 | + // ans.second: a list of list of (start_sample_index, end_sample_index) | ||
| 290 | + // | ||
| 291 | + // ans.first[i] corresponds to ans.second[i] | ||
| 292 | + std::pair<std::vector<Int32Pair>, std::vector<std::vector<Int32Pair>>> | ||
| 293 | + GetChunkSpeakerSampleIndexes(const std::vector<Matrix2DInt32> &labels) const { | ||
| 294 | + auto new_labels = ExcludeOverlap(labels); | ||
| 295 | + | ||
| 296 | + std::vector<Int32Pair> chunk_speaker_list; | ||
| 297 | + std::vector<std::vector<Int32Pair>> samples_index_list; | ||
| 298 | + | ||
| 299 | + const auto &meta_data = segmentation_model_.GetModelMetaData(); | ||
| 300 | + int32_t window_size = meta_data.window_size; | ||
| 301 | + int32_t window_shift = meta_data.window_shift; | ||
| 302 | + int32_t receptive_field_shift = meta_data.receptive_field_shift; | ||
| 303 | + int32_t num_speakers = meta_data.num_speakers; | ||
| 304 | + | ||
| 305 | + int32_t chunk_index = 0; | ||
| 306 | + for (const auto &label : new_labels) { | ||
| 307 | + Matrix2DInt32 tmp = label.transpose(); | ||
| 308 | + // tmp: (num_speakers, num_frames) | ||
| 309 | + int32_t num_frames = tmp.cols(); | ||
| 310 | + | ||
| 311 | + int32_t sample_offset = chunk_index * window_shift; | ||
| 312 | + | ||
| 313 | + for (int32_t speaker_index = 0; speaker_index != num_speakers; | ||
| 314 | + ++speaker_index) { | ||
| 315 | + auto d = tmp.row(speaker_index); | ||
| 316 | + if (d.sum() < 10) { | ||
| 317 | + // skip segments less than 10 frames | ||
| 318 | + continue; | ||
| 319 | + } | ||
| 320 | + | ||
| 321 | + Int32Pair this_chunk_speaker = {chunk_index, speaker_index}; | ||
| 322 | + std::vector<Int32Pair> this_speaker_samples; | ||
| 323 | + | ||
| 324 | + bool is_active = false; | ||
| 325 | + int32_t start_index; | ||
| 326 | + | ||
| 327 | + for (int32_t k = 0; k != num_frames; ++k) { | ||
| 328 | + if (d[k] != 0) { | ||
| 329 | + if (!is_active) { | ||
| 330 | + is_active = true; | ||
| 331 | + start_index = k; | ||
| 332 | + } | ||
| 333 | + } else if (is_active) { | ||
| 334 | + is_active = false; | ||
| 335 | + | ||
| 336 | + int32_t start_samples = | ||
| 337 | + static_cast<float>(start_index) / num_frames * window_size + | ||
| 338 | + sample_offset; | ||
| 339 | + int32_t end_samples = | ||
| 340 | + static_cast<float>(k) / num_frames * window_size + | ||
| 341 | + sample_offset; | ||
| 342 | + | ||
| 343 | + this_speaker_samples.emplace_back(start_samples, end_samples); | ||
| 344 | + } | ||
| 345 | + } | ||
| 346 | + | ||
| 347 | + if (is_active) { | ||
| 348 | + int32_t start_samples = | ||
| 349 | + static_cast<float>(start_index) / num_frames * window_size + | ||
| 350 | + sample_offset; | ||
| 351 | + int32_t end_samples = | ||
| 352 | + static_cast<float>(num_frames - 1) / num_frames * window_size + | ||
| 353 | + sample_offset; | ||
| 354 | + this_speaker_samples.emplace_back(start_samples, end_samples); | ||
| 355 | + } | ||
| 356 | + | ||
| 357 | + chunk_speaker_list.push_back(std::move(this_chunk_speaker)); | ||
| 358 | + samples_index_list.push_back(std::move(this_speaker_samples)); | ||
| 359 | + } // for (int32_t speaker_index = 0; | ||
| 360 | + chunk_index += 1; | ||
| 361 | + } // for (const auto &label : new_labels) | ||
| 362 | + | ||
| 363 | + return {chunk_speaker_list, samples_index_list}; | ||
| 364 | + } | ||
| 365 | + | ||
  // If there are multiple speakers at a frame, then this frame is excluded.
  //
  // @param labels labels[i] is the multi-label matrix of chunk i with shape
  //               (num_frames, num_speakers).
  // @return Matrices of the same shapes where rows (frames) with two or more
  //         active speakers have been zeroed out.
  std::vector<Matrix2DInt32> ExcludeOverlap(
      const std::vector<Matrix2DInt32> &labels) const {
    int32_t num_chunks = labels.size();
    std::vector<Matrix2DInt32> ans;
    ans.reserve(num_chunks);

    for (const auto &label : labels) {
      Matrix2DInt32 new_label(label.rows(), label.cols());
      new_label.setZero();
      // v[i] is the number of active speakers at frame i.
      Int32RowVector v = label.rowwise().sum();

      for (int32_t i = 0; i != v.cols(); ++i) {
        if (v[i] < 2) {
          // Keep only frames with at most one active speaker.
          new_label.row(i) = label.row(i);
        }
      }

      ans.push_back(std::move(new_label));
    }

    return ans;
  }
| 389 | + | ||
  /**
   * @param audio Pointer to the whole audio (n samples).
   * @param n Number of samples in audio.
   * @param sample_indexes[i] contains the sample segment start and end indexes
   *                          for the i-th (chunk, speaker) pair
   * @param callback Optional progress callback invoked once per processed
   *                 (chunk, speaker) pair. Its return value is ignored here.
   * @param callback_arg Opaque pointer forwarded to callback.
   * @return Return a matrix of shape (sample_indexes.size(), embedding_dim)
   *         where ans.row[i] contains the embedding for the
   *         i-th (chunk, speaker) pair
   */
  Matrix2D ComputeEmbeddings(
      const float *audio, int32_t n,
      const std::vector<std::vector<Int32Pair>> &sample_indexes,
      OfflineSpeakerDiarizationProgressCallback callback,
      void *callback_arg) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t sample_rate = meta_data.sample_rate;
    Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());

    int32_t k = 0;
    for (const auto &v : sample_indexes) {
      // One embedding per (chunk, speaker) pair: all its sample ranges are
      // fed into a single stream.
      auto stream = embedding_extractor_.CreateStream();
      for (const auto &p : v) {
        // Clamp the end index to the audio length; the last chunk may have
        // been zero-padded beyond n.
        int32_t end = (p.second <= n) ? p.second : n;
        int32_t num_samples = end - p.first;

        if (num_samples > 0) {
          stream->AcceptWaveform(sample_rate, audio + p.first, num_samples);
        }
      }

      stream->InputFinished();
      if (!embedding_extractor_.IsReady(stream.get())) {
        SHERPA_ONNX_LOGE(
            "This segment is too short, which should not happen since we have "
            "already filtered short segments");
        SHERPA_ONNX_EXIT(-1);
      }

      std::vector<float> embedding = embedding_extractor_.Compute(stream.get());

      std::copy(embedding.begin(), embedding.end(), &ans(k, 0));

      k += 1;

      if (callback) {
        callback(k, ans.rows(), callback_arg);
      }
    }

    return ans;
  }
| 439 | + | ||
| 440 | + std::unordered_map<Int32Pair, int32_t, PairHash> ConvertChunkSpeakerToCluster( | ||
| 441 | + const std::vector<Int32Pair> &chunk_speaker_pair, | ||
| 442 | + const std::vector<int32_t> &cluster_labels) const { | ||
| 443 | + std::unordered_map<Int32Pair, int32_t, PairHash> ans; | ||
| 444 | + | ||
| 445 | + int32_t k = 0; | ||
| 446 | + for (const auto &p : chunk_speaker_pair) { | ||
| 447 | + ans[p] = cluster_labels[k]; | ||
| 448 | + k += 1; | ||
| 449 | + } | ||
| 450 | + | ||
| 451 | + return ans; | ||
| 452 | + } | ||
| 453 | + | ||
| 454 | + std::vector<Matrix2DInt32> ReLabel( | ||
| 455 | + const std::vector<Matrix2DInt32> &labels, int32_t max_cluster_index, | ||
| 456 | + std::unordered_map<Int32Pair, int32_t, PairHash> chunk_speaker_to_cluster) | ||
| 457 | + const { | ||
| 458 | + std::vector<Matrix2DInt32> new_labels; | ||
| 459 | + new_labels.reserve(labels.size()); | ||
| 460 | + | ||
| 461 | + int32_t chunk_index = 0; | ||
| 462 | + for (const auto &label : labels) { | ||
| 463 | + Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1); | ||
| 464 | + new_label.setZero(); | ||
| 465 | + | ||
| 466 | + Matrix2DInt32 t = label.transpose(); | ||
| 467 | + // t: (num_speakers, num_frames) | ||
| 468 | + | ||
| 469 | + for (int32_t speaker_index = 0; speaker_index != t.rows(); | ||
| 470 | + ++speaker_index) { | ||
| 471 | + if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) { | ||
| 472 | + continue; | ||
| 473 | + } | ||
| 474 | + | ||
| 475 | + int32_t new_speaker_index = | ||
| 476 | + chunk_speaker_to_cluster.at({chunk_index, speaker_index}); | ||
| 477 | + | ||
| 478 | + for (int32_t k = 0; k != t.cols(); ++k) { | ||
| 479 | + if (t(speaker_index, k) == 1) { | ||
| 480 | + new_label(k, new_speaker_index) = 1; | ||
| 481 | + } | ||
| 482 | + } | ||
| 483 | + } | ||
| 484 | + | ||
| 485 | + new_labels.push_back(std::move(new_label)); | ||
| 486 | + | ||
| 487 | + chunk_index += 1; | ||
| 488 | + } | ||
| 489 | + | ||
| 490 | + return new_labels; | ||
| 491 | + } | ||
| 492 | + | ||
  // Sum the relabeled per-chunk activations so that count(frame, cluster) is
  // the number of chunks voting for that cluster at that frame.
  //
  // @param labels labels[i] is the relabeled matrix of chunk i with shape
  //               (num_frames_per_chunk, num_clusters). Must be non-empty.
  // @param num_samples Number of samples of the original audio.
  // @return Matrix of shape (num_frames, num_clusters); when the audio ends
  //         exactly on a chunk boundary, frames beyond the audio length are
  //         trimmed.
  Matrix2DInt32 ComputeSpeakerCount(const std::vector<Matrix2DInt32> &labels,
                                    int32_t num_samples) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;
    int32_t receptive_field_shift = meta_data.receptive_field_shift;

    int32_t num_chunks = labels.size();

    // Total number of output frames spanned by all chunks.
    int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
                             receptive_field_shift +
                         1;

    Matrix2DInt32 count(num_frames, labels[0].cols());
    count.setZero();

    for (int32_t i = 0; i != num_chunks; ++i) {
      // First output frame of chunk i (rounded to nearest integer).
      int32_t start =
          static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;

      auto seq = Eigen::seqN(start, labels[i].rows());

      count(seq, Eigen::all).array() += labels[i].array();
    }

    // NOTE(review): if num_samples < window_size, (num_samples - window_size)
    // is negative, so this is false and the result is trimmed below to the
    // audio length — TODO confirm this is intended for very short inputs.
    bool has_last_chunk = (num_samples - window_size) % window_shift > 0;

    if (has_last_chunk) {
      // The last chunk was zero-padded, so every frame in count is covered.
      return count;
    }

    int32_t last_frame = num_samples / receptive_field_shift;
    return count(Eigen::seq(0, last_frame), Eigen::all);
  }
| 527 | + | ||
  // For every frame keep only the k clusters with the highest vote counts,
  // where k is the estimated number of speakers at that frame.
  //
  // @param count Vote counts of shape (num_frames, num_clusters).
  // @param speakers_per_frame speakers_per_frame[i] is the estimated number
  //                           of active speakers at frame i.
  // @return A 0/1 matrix with the same shape as count.
  Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 &count,
                               const Int32RowVector &speakers_per_frame) const {
    int32_t num_rows = count.rows();
    int32_t num_cols = count.cols();

    Matrix2DInt32 ans(num_rows, num_cols);
    ans.setZero();

    for (int32_t i = 0; i != num_rows; ++i) {
      int32_t k = speakers_per_frame[i];
      if (k == 0) {
        // Nobody speaks at this frame.
        continue;
      }
      // Indexes of the k clusters with the most votes at frame i.
      auto top_k = TopkIndex(&count(i, 0), num_cols, k);

      for (int32_t m : top_k) {
        ans(i, m) = 1;
      }
    }

    return ans;
  }
| 550 | + | ||
  // Convert the final per-frame 0/1 labels into a diarization result with
  // start/end times in seconds.
  //
  // @param final_labels 0/1 matrix of shape (num_frames, num_speakers).
  // @return Per-speaker segments. Gaps shorter than config_.min_duration_off
  //         are merged and segments not longer than config_.min_duration_on
  //         are dropped.
  OfflineSpeakerDiarizationResult ComputeResult(
      const Matrix2DInt32 &final_labels) const {
    Matrix2DInt32 final_labels_t = final_labels.transpose();
    int32_t num_speakers = final_labels_t.rows();
    int32_t num_frames = final_labels_t.cols();

    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;
    int32_t receptive_field_shift = meta_data.receptive_field_shift;
    int32_t receptive_field_size = meta_data.receptive_field_size;
    int32_t sample_rate = meta_data.sample_rate;

    // Seconds per frame; the offset maps each frame index to the center of
    // its receptive field.
    float scale = static_cast<float>(receptive_field_shift) / sample_rate;
    float scale_offset = 0.5 * receptive_field_size / sample_rate;

    OfflineSpeakerDiarizationResult ans;

    for (int32_t speaker_index = 0; speaker_index != num_speakers;
         ++speaker_index) {
      std::vector<OfflineSpeakerDiarizationSegment> this_speaker;

      // Scan this speaker's frame sequence for contiguous active runs.
      bool is_active = final_labels_t(speaker_index, 0) > 0;
      int32_t start_index = is_active ? 0 : -1;

      for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) {
        if (is_active) {
          if (final_labels_t(speaker_index, frame_index) == 0) {
            // The active run [start_index, frame_index) just ended.
            float start_time = start_index * scale + scale_offset;
            float end_time = frame_index * scale + scale_offset;

            OfflineSpeakerDiarizationSegment segment(start_time, end_time,
                                                     speaker_index);
            this_speaker.push_back(segment);

            is_active = false;
          }
        } else if (final_labels_t(speaker_index, frame_index) == 1) {
          is_active = true;
          start_index = frame_index;
        }
      }

      if (is_active) {
        // The last run extends to the end of the audio.
        float start_time = start_index * scale + scale_offset;
        float end_time = (num_frames - 1) * scale + scale_offset;

        OfflineSpeakerDiarizationSegment segment(start_time, end_time,
                                                 speaker_index);
        this_speaker.push_back(segment);
      }

      // merge segments if the gap between them is less than min_duration_off
      MergeSegments(&this_speaker);

      for (const auto &seg : this_speaker) {
        if (seg.Duration() > config_.min_duration_on) {
          ans.Add(seg);
        }
      }
    }  // for (int32_t speaker_index = 0; speaker_index != num_speakers;

    return ans;
  }
| 615 | + | ||
| 616 | + void MergeSegments( | ||
| 617 | + std::vector<OfflineSpeakerDiarizationSegment> *segments) const { | ||
| 618 | + float min_duration_off = config_.min_duration_off; | ||
| 619 | + bool changed = true; | ||
| 620 | + while (changed) { | ||
| 621 | + changed = false; | ||
| 622 | + for (int32_t i = 0; i < static_cast<int32_t>(segments->size()) - 1; ++i) { | ||
| 623 | + auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off); | ||
| 624 | + if (s) { | ||
| 625 | + (*segments)[i] = s.value(); | ||
| 626 | + segments->erase(segments->begin() + i + 1); | ||
| 627 | + | ||
| 628 | + changed = true; | ||
| 629 | + break; | ||
| 630 | + } | ||
| 631 | + } | ||
| 632 | + } | ||
| 633 | + } | ||
| 634 | + | ||
| 635 | + private: | ||
| 636 | + OfflineSpeakerDiarizationConfig config_; | ||
| 637 | + OfflineSpeakerSegmentationPyannoteModel segmentation_model_; | ||
| 638 | + SpeakerEmbeddingExtractor embedding_extractor_; | ||
| 639 | + FastClustering clustering_; | ||
| 640 | + Matrix2DInt32 powerset_mapping_; | ||
| 641 | +}; | ||
| 642 | + | ||
| 643 | +} // namespace sherpa_onnx | ||
| 644 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization-result.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <sstream> | ||
| 9 | +#include <string> | ||
| 10 | +#include <unordered_set> | ||
| 11 | +#include <utility> | ||
| 12 | + | ||
| 13 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( | ||
| 18 | + float start, float end, int32_t speaker, const std::string &text /*= {}*/) { | ||
| 19 | + if (start > end) { | ||
| 20 | + SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end); | ||
| 21 | + SHERPA_ONNX_EXIT(-1); | ||
| 22 | + } | ||
| 23 | + | ||
| 24 | + start_ = start; | ||
| 25 | + end_ = end; | ||
| 26 | + speaker_ = speaker; | ||
| 27 | + text_ = text; | ||
| 28 | +} | ||
| 29 | + | ||
| 30 | +std::optional<OfflineSpeakerDiarizationSegment> | ||
| 31 | +OfflineSpeakerDiarizationSegment::Merge( | ||
| 32 | + const OfflineSpeakerDiarizationSegment &other, float gap) const { | ||
| 33 | + if (other.speaker_ != speaker_) { | ||
| 34 | + SHERPA_ONNX_LOGE( | ||
| 35 | + "The two segments should have the same speaker. this->speaker: %d, " | ||
| 36 | + "other.speaker: %d", | ||
| 37 | + speaker_, other.speaker_); | ||
| 38 | + return std::nullopt; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + if (end_ < other.start_ && end_ + gap >= other.start_) { | ||
| 42 | + return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_); | ||
| 43 | + } else if (other.end_ < start_ && other.end_ + gap >= start_) { | ||
| 44 | + return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_); | ||
| 45 | + } else { | ||
| 46 | + return std::nullopt; | ||
| 47 | + } | ||
| 48 | +} | ||
| 49 | + | ||
| 50 | +std::string OfflineSpeakerDiarizationSegment::ToString() const { | ||
| 51 | + char s[128]; | ||
| 52 | + snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, speaker_); | ||
| 53 | + | ||
| 54 | + std::ostringstream os; | ||
| 55 | + os << s; | ||
| 56 | + | ||
| 57 | + if (!text_.empty()) { | ||
| 58 | + os << " " << text_; | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | + return os.str(); | ||
| 62 | +} | ||
| 63 | + | ||
| 64 | +void OfflineSpeakerDiarizationResult::Add( | ||
| 65 | + const OfflineSpeakerDiarizationSegment &segment) { | ||
| 66 | + segments_.push_back(segment); | ||
| 67 | +} | ||
| 68 | + | ||
| 69 | +int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const { | ||
| 70 | + std::unordered_set<int32_t> count; | ||
| 71 | + for (const auto &s : segments_) { | ||
| 72 | + count.insert(s.Speaker()); | ||
| 73 | + } | ||
| 74 | + | ||
| 75 | + return count.size(); | ||
| 76 | +} | ||
| 77 | + | ||
| 78 | +int32_t OfflineSpeakerDiarizationResult::NumSegments() const { | ||
| 79 | + return segments_.size(); | ||
| 80 | +} | ||
| 81 | + | ||
| 82 | +// Return a list of segments sorted by segment.start time | ||
| 83 | +std::vector<OfflineSpeakerDiarizationSegment> | ||
| 84 | +OfflineSpeakerDiarizationResult::SortByStartTime() const { | ||
| 85 | + auto ans = segments_; | ||
| 86 | + std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) { | ||
| 87 | + return (a.Start() < b.Start()) || | ||
| 88 | + ((a.Start() == b.Start()) && (a.Speaker() < b.Speaker())); | ||
| 89 | + }); | ||
| 90 | + | ||
| 91 | + return ans; | ||
| 92 | +} | ||
| 93 | + | ||
| 94 | +std::vector<std::vector<OfflineSpeakerDiarizationSegment>> | ||
| 95 | +OfflineSpeakerDiarizationResult::SortBySpeaker() const { | ||
| 96 | + auto tmp = segments_; | ||
| 97 | + std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) { | ||
| 98 | + return (a.Speaker() < b.Speaker()) || | ||
| 99 | + ((a.Speaker() == b.Speaker()) && (a.Start() < b.Start())); | ||
| 100 | + }); | ||
| 101 | + | ||
| 102 | + std::vector<std::vector<OfflineSpeakerDiarizationSegment>> ans(NumSpeakers()); | ||
| 103 | + for (auto &s : tmp) { | ||
| 104 | + ans[s.Speaker()].push_back(std::move(s)); | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + return ans; | ||
| 108 | +} | ||
| 109 | + | ||
| 110 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization-result.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ | ||
| 7 | + | ||
| 8 | +#include <cstdint> | ||
| 9 | +#include <optional> | ||
| 10 | +#include <string> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
// A contiguous stretch of speech attributed to a single speaker.
class OfflineSpeakerDiarizationSegment {
 public:
  // @param start Start time in seconds. Must not be greater than end.
  // @param end End time in seconds.
  // @param speaker Speaker ID, starting from 0.
  // @param text Optional speech recognition result for this segment.
  OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker,
                                   const std::string &text = {});

  // If the gap between the two segments is less than the given gap, then we
  // merge them and return a new segment. Otherwise, it returns null.
  // Both segments must belong to the same speaker.
  std::optional<OfflineSpeakerDiarizationSegment> Merge(
      const OfflineSpeakerDiarizationSegment &other, float gap) const;

  float Start() const { return start_; }
  float End() const { return end_; }
  int32_t Speaker() const { return speaker_; }
  const std::string &Text() const { return text_; }
  float Duration() const { return end_ - start_; }

  // Human-readable form, e.g. "1.000 -- 2.500 speaker_00" plus the text.
  std::string ToString() const;

 private:
  float start_;  // in seconds
  float end_;    // in seconds
  int32_t speaker_;  // ID of the speaker, starting from 0
  std::string text_;  // If not empty, it contains the speech recognition result
                      // of this segment
};
| 40 | + | ||
// Collection of diarization segments plus helpers to view them sorted.
class OfflineSpeakerDiarizationResult {
 public:
  // Add a new segment
  void Add(const OfflineSpeakerDiarizationSegment &segment);

  // Number of distinct speakers contained in this object at this point
  int32_t NumSpeakers() const;

  int32_t NumSegments() const;

  // Return a list of segments sorted by segment.start time
  std::vector<OfflineSpeakerDiarizationSegment> SortByStartTime() const;

  // ans.size() == NumSpeakers().
  // ans[i] is for speaker_i and is sorted by start time
  std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
      const;

 public:
  // NOTE(review): public member despite the trailing underscore that usually
  // marks private data — presumably for direct access elsewhere; consider
  // making it private. TODO confirm no external code touches it directly.
  std::vector<OfflineSpeakerDiarizationSegment> segments_;
};
| 63 | +} // namespace sherpa_onnx | ||
| 64 | + | ||
| 65 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"

#include <sstream>
#include <string>
#include <utility>

#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) { | ||
| 14 | + ParseOptions po_segmentation("segmentation", po); | ||
| 15 | + segmentation.Register(&po_segmentation); | ||
| 16 | + | ||
| 17 | + ParseOptions po_embedding("embedding", po); | ||
| 18 | + embedding.Register(&po_embedding); | ||
| 19 | + | ||
| 20 | + ParseOptions po_clustering("clustering", po); | ||
| 21 | + clustering.Register(&po_clustering); | ||
| 22 | + | ||
| 23 | + po->Register("min-duration-on", &min_duration_on, | ||
| 24 | + "if a segment is less than this value, then it is discarded. " | ||
| 25 | + "Set it to 0 so that no segment is discarded"); | ||
| 26 | + | ||
| 27 | + po->Register("min-duration-off", &min_duration_off, | ||
| 28 | + "if the gap between to segments of the same speaker is less " | ||
| 29 | + "than this value, then these two segments are merged into a " | ||
| 30 | + "single segment. We do it recursively."); | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +bool OfflineSpeakerDiarizationConfig::Validate() const { | ||
| 34 | + if (!segmentation.Validate()) { | ||
| 35 | + return false; | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + if (!embedding.Validate()) { | ||
| 39 | + return false; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + if (!clustering.Validate()) { | ||
| 43 | + return false; | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + return true; | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +std::string OfflineSpeakerDiarizationConfig::ToString() const { | ||
| 50 | + std::ostringstream os; | ||
| 51 | + | ||
| 52 | + os << "OfflineSpeakerDiarizationConfig("; | ||
| 53 | + os << "segmentation=" << segmentation.ToString() << ", "; | ||
| 54 | + os << "embedding=" << embedding.ToString() << ", "; | ||
| 55 | + os << "clustering=" << clustering.ToString() << ", "; | ||
| 56 | + os << "min_duration_on=" << min_duration_on << ", "; | ||
| 57 | + os << "min_duration_off=" << min_duration_off << ")"; | ||
| 58 | + | ||
| 59 | + return os.str(); | ||
| 60 | +} | ||
| 61 | + | ||
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
    const OfflineSpeakerDiarizationConfig &config)
    : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}

// Out-of-line destructor so impl_'s unique_ptr can destroy the
// OfflineSpeakerDiarizationImpl type that is incomplete in the header.
OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
| 67 | + | ||
// Expected sample rate of the input audio samples; forwarded from the impl.
int32_t OfflineSpeakerDiarization::SampleRate() const {
  return impl_->SampleRate();
}
| 71 | + | ||
| 72 | +OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process( | ||
| 73 | + const float *audio, int32_t n, | ||
| 74 | + OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/, | ||
| 75 | + void *callback_arg /*= nullptr*/) const { | ||
| 76 | + return impl_->Process(audio, n, callback, callback_arg); | ||
| 77 | +} | ||
| 78 | + | ||
| 79 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speaker-diarization.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ | ||
| 7 | + | ||
| 8 | +#include <functional> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/fast-clustering-config.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" | ||
| 14 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" | ||
| 15 | +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 16 | + | ||
| 17 | +namespace sherpa_onnx { | ||
| 18 | + | ||
// Configuration for the whole offline diarization pipeline: segmentation
// model + speaker embedding extractor + clustering, plus post-processing
// thresholds.
struct OfflineSpeakerDiarizationConfig {
  OfflineSpeakerSegmentationModelConfig segmentation;
  SpeakerEmbeddingExtractorConfig embedding;
  FastClusteringConfig clustering;

  // if a segment is less than this value, then it is discarded
  float min_duration_on = 0.3;  // in seconds

  // if the gap between two segments of the same speaker is less than this
  // value, then these two segments are merged into a single segment.
  // We do this recursively.
  float min_duration_off = 0.5;  // in seconds

  OfflineSpeakerDiarizationConfig() = default;

  OfflineSpeakerDiarizationConfig(
      const OfflineSpeakerSegmentationModelConfig &segmentation,
      const SpeakerEmbeddingExtractorConfig &embedding,
      const FastClusteringConfig &clustering)
      : segmentation(segmentation),
        embedding(embedding),
        clustering(clustering) {}

  void Register(ParseOptions *po);
  bool Validate() const;
  std::string ToString() const;
};
| 46 | + | ||
// Defined in offline-speaker-diarization-impl.h; kept incomplete here so
// the public header does not pull in implementation details.
class OfflineSpeakerDiarizationImpl;

// Progress callback: invoked with the number of processed chunks, the total
// number of chunks, and the user-supplied pointer passed to Process().
using OfflineSpeakerDiarizationProgressCallback = std::function<int32_t(
    int32_t processed_chunks, int32_t num_chunks, void *arg)>;
| 51 | + | ||
// Public facade for offline speaker diarization; all work is delegated to a
// pimpl created from the config.
class OfflineSpeakerDiarization {
 public:
  explicit OfflineSpeakerDiarization(
      const OfflineSpeakerDiarizationConfig &config);

  ~OfflineSpeakerDiarization();

  // Expected sample rate of the input audio samples
  int32_t SampleRate() const;

  // Run diarization on n mono samples (at SampleRate()).
  //
  // @param audio Pointer to the samples.
  // @param n Number of samples.
  // @param callback Optional progress callback; see
  //                 OfflineSpeakerDiarizationProgressCallback.
  // @param callback_arg Opaque pointer forwarded to the callback.
  OfflineSpeakerDiarizationResult Process(
      const float *audio, int32_t n,
      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
      void *callback_arg = nullptr) const;

 private:
  std::unique_ptr<OfflineSpeakerDiarizationImpl> impl_;
};
| 70 | + | ||
| 71 | +} // namespace sherpa_onnx | ||
| 72 | + | ||
| 73 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" | ||
| 5 | + | ||
| 6 | +#include <sstream> | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
// Registers the command-line options of this config (and of the nested
// pyannote config) with the given option parser.
void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) {
  pyannote.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}
| 25 | + | ||
| 26 | +bool OfflineSpeakerSegmentationModelConfig::Validate() const { | ||
| 27 | + if (num_threads < 1) { | ||
| 28 | + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
| 29 | + return false; | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + if (!pyannote.model.empty()) { | ||
| 33 | + return pyannote.Validate(); | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + if (pyannote.model.empty()) { | ||
| 37 | + SHERPA_ONNX_LOGE( | ||
| 38 | + "You have to provide at least one speaker segmentation model"); | ||
| 39 | + return false; | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + return true; | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +std::string OfflineSpeakerSegmentationModelConfig::ToString() const { | ||
| 46 | + std::ostringstream os; | ||
| 47 | + | ||
| 48 | + os << "OfflineSpeakerSegmentationModelConfig("; | ||
| 49 | + os << "pyannote=" << pyannote.ToString() << ", "; | ||
| 50 | + os << "num_threads=" << num_threads << ", "; | ||
| 51 | + os << "debug=" << (debug ? "True" : "False") << ", "; | ||
| 52 | + os << "provider=\"" << provider << "\")"; | ||
| 53 | + | ||
| 54 | + return os.str(); | ||
| 55 | +} | ||
| 56 | + | ||
| 57 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
// Configuration for the offline speaker segmentation model.
struct OfflineSpeakerSegmentationModelConfig {
  // Config of the pyannote segmentation model; currently the only
  // supported segmentation model type.
  OfflineSpeakerSegmentationPyannoteModelConfig pyannote;

  // Number of threads to run the neural network.
  int32_t num_threads = 1;

  // true to print model information while loading it.
  bool debug = false;

  // onnxruntime execution provider: cpu, cuda, coreml.
  std::string provider = "cpu";

  OfflineSpeakerSegmentationModelConfig() = default;

  explicit OfflineSpeakerSegmentationModelConfig(
      const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote,
      int32_t num_threads, bool debug, const std::string &provider)
      : pyannote(pyannote),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  // Registers the command-line options with the given parser.
  void Register(ParseOptions *po);

  // Returns true if the config is valid; logs and returns false otherwise.
  bool Validate() const;

  // Returns a human-readable string representation of the config.
  std::string ToString() const;
};
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx | ||
| 39 | + | ||
| 40 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" | ||
| 5 | + | ||
| 6 | +#include <sstream> | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
// Registers the --pyannote-model command-line option.
void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) {
  po->Register("pyannote-model", &model,
               "Path to model.onnx of the Pyannote segmentation model.");
}
| 18 | + | ||
| 19 | +bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const { | ||
| 20 | + if (!FileExists(model)) { | ||
| 21 | + SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist", | ||
| 22 | + model.c_str()); | ||
| 23 | + return false; | ||
| 24 | + } | ||
| 25 | + | ||
| 26 | + return true; | ||
| 27 | +} | ||
| 28 | + | ||
| 29 | +std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const { | ||
| 30 | + std::ostringstream os; | ||
| 31 | + | ||
| 32 | + os << "OfflineSpeakerSegmentationPyannoteModelConfig("; | ||
| 33 | + os << "model=\"" << model << "\")"; | ||
| 34 | + | ||
| 35 | + return os.str(); | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
// Configuration for the pyannote speaker segmentation model.
struct OfflineSpeakerSegmentationPyannoteModelConfig {
  // Path to model.onnx of the pyannote segmentation model.
  std::string model;

  OfflineSpeakerSegmentationPyannoteModelConfig() = default;

  explicit OfflineSpeakerSegmentationPyannoteModelConfig(
      const std::string &model)
      : model(model) {}

  // Registers the --pyannote-model command-line option.
  void Register(ParseOptions *po);

  // Returns true if `model` points to an existing file; logs and
  // returns false otherwise.
  bool Validate() const;

  // Returns a human-readable string representation of the config.
  std::string ToString() const;
};
| 27 | + | ||
| 28 | +} // namespace sherpa_onnx | ||
| 29 | + | ||
| 30 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ | ||
| 7 | + | ||
| 8 | +#include <cstdint> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +// If you are not sure what each field means, please | ||
| 14 | +// have a look of the Python file in the model directory that | ||
| 15 | +// you have downloaded. | ||
struct OfflineSpeakerSegmentationPyannoteModelMetaData {
  // Expected sample rate of the input audio.
  int32_t sample_rate = 0;
  // Size of one processing window, in samples.
  int32_t window_size = 0;  // in samples
  // Shift between consecutive windows, in samples.
  int32_t window_shift = 0;  // in samples
  // Size/shift of the model's receptive field, in samples.
  // NOTE(review): presumably one output frame corresponds to
  // receptive_field_size samples shifted by receptive_field_shift --
  // confirm against the export script mentioned above.
  int32_t receptive_field_size = 0;   // in samples
  int32_t receptive_field_shift = 0;  // in samples
  // Maximum number of speakers per chunk supported by the model.
  int32_t num_speakers = 0;
  // Parameters of the powerset encoding used by the model output.
  int32_t powerset_max_classes = 0;
  int32_t num_classes = 0;
};
| 26 | + | ||
| 27 | +} // namespace sherpa_onnx | ||
| 28 | + | ||
| 29 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h" | ||
| 6 | + | ||
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | +#include "sherpa-onnx/csrc/session.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
// Actual implementation of OfflineSpeakerSegmentationPyannoteModel (pimpl).
// Owns the onnxruntime session and the meta data read from the ONNX model.
class OfflineSpeakerSegmentationPyannoteModel::Impl {
 public:
  explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    // Load the whole model file into memory and build the session from it.
    auto buf = ReadFile(config_.pyannote.model);
    Init(buf.data(), buf.size());
  }

  // Returns the meta data parsed from the ONNX model file.
  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
      const {
    return meta_data_;
  }

  // Runs the model on the input tensor and returns the first output.
  Ort::Value Forward(Ort::Value x) {
    auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1,
                          output_names_ptr_.data(), output_names_ptr_.size());

    return std::move(out[0]);
  }

 private:
  // Creates the session from the in-memory model, caches input/output
  // names, and reads the model meta data.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size");

    // The window shift is hard-coded to 10% of the window size.
    meta_data_.window_shift =
        static_cast<int32_t>(0.1 * meta_data_.window_size);

    SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size,
                               "receptive_field_size");
    SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift,
                               "receptive_field_shift");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers");
    SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes,
                               "powerset_max_classes");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes");
  }

 private:
  OfflineSpeakerSegmentationModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // Input/output tensor names; the *_ptr_ vectors hold borrowed pointers
  // into the corresponding string vectors, in the form Ort::Session::Run
  // expects.
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_;
};
| 89 | + | ||
// Constructs the model; all work is delegated to Impl.
OfflineSpeakerSegmentationPyannoteModel::
    OfflineSpeakerSegmentationPyannoteModel(
        const OfflineSpeakerSegmentationModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}
| 94 | + | ||
// Defined out-of-line so that Impl is a complete type at the point where
// the unique_ptr's deleter is instantiated.
OfflineSpeakerSegmentationPyannoteModel::
    ~OfflineSpeakerSegmentationPyannoteModel() = default;
| 97 | + | ||
// Forwards to Impl::GetModelMetaData().
const OfflineSpeakerSegmentationPyannoteModelMetaData &
OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const {
  return impl_->GetModelMetaData();
}
| 102 | + | ||
// Forwards to Impl::Forward(). See the header for the tensor shapes.
Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(
    Ort::Value x) const {
  return impl_->Forward(std::move(x));
}
| 107 | + | ||
| 108 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | + | ||
| 9 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 10 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
// Wrapper around the pyannote speaker segmentation ONNX model.
// Implemented with the pimpl idiom; see the corresponding .cc file.
class OfflineSpeakerSegmentationPyannoteModel {
 public:
  explicit OfflineSpeakerSegmentationPyannoteModel(
      const OfflineSpeakerSegmentationModelConfig &config);

  ~OfflineSpeakerSegmentationPyannoteModel();

  // Returns the meta data read from the ONNX model file.
  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
      const;

  /**
   * Runs the segmentation model.
   *
   * @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
   * @return Return a float tensor of
   *         shape (batch_size, num_frames, num_speakers). Note that
   *         num_speakers here uses powerset encoding.
   */
  Ort::Value Forward(Ort::Value x) const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx | ||
| 39 | + | ||
| 40 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ |
| @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) { | @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) { | ||
| 61 | 61 | ||
| 62 | bool TensorrtConfig::Validate() const { | 62 | bool TensorrtConfig::Validate() const { |
| 63 | if (trt_max_workspace_size < 0) { | 63 | if (trt_max_workspace_size < 0) { |
| 64 | - SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.", | ||
| 65 | - trt_max_workspace_size); | 64 | + std::ostringstream os; |
| 65 | + os << "trt_max_workspace_size: " << trt_max_workspace_size | ||
| 66 | + << " is not valid."; | ||
| 67 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 66 | return false; | 68 | return false; |
| 67 | } | 69 | } |
| 68 | if (trt_max_partition_iterations < 0) { | 70 | if (trt_max_partition_iterations < 0) { |
| @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { | @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { | ||
| 35 | api.ReleaseStatus(status); | 35 | api.ReleaseStatus(status); |
| 36 | } | 36 | } |
| 37 | 37 | ||
| 38 | -static Ort::SessionOptions GetSessionOptionsImpl( | 38 | +Ort::SessionOptions GetSessionOptionsImpl( |
| 39 | int32_t num_threads, const std::string &provider_str, | 39 | int32_t num_threads, const std::string &provider_str, |
| 40 | - const ProviderConfig *provider_config = nullptr) { | 40 | + const ProviderConfig *provider_config /*= nullptr*/) { |
| 41 | Provider p = StringToProvider(provider_str); | 41 | Provider p = StringToProvider(provider_str); |
| 42 | 42 | ||
| 43 | Ort::SessionOptions sess_opts; | 43 | Ort::SessionOptions sess_opts; |
| @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | ||
| 259 | &config.provider_config); | 259 | &config.provider_config); |
| 260 | } | 260 | } |
| 261 | 261 | ||
| 262 | -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { | ||
| 263 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 264 | -} | ||
| 265 | - | ||
| 266 | Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { | 262 | Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { |
| 267 | return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); | 263 | return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); |
| 268 | } | 264 | } |
| @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { | @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { | ||
| 271 | return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); | 267 | return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); |
| 272 | } | 268 | } |
| 273 | 269 | ||
| 274 | -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { | ||
| 275 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 276 | -} | ||
| 277 | - | ||
| 278 | -#if SHERPA_ONNX_ENABLE_TTS | ||
| 279 | -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { | ||
| 280 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 281 | -} | ||
| 282 | -#endif | ||
| 283 | - | ||
| 284 | -Ort::SessionOptions GetSessionOptions( | ||
| 285 | - const SpeakerEmbeddingExtractorConfig &config) { | ||
| 286 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 287 | -} | ||
| 288 | - | ||
| 289 | -Ort::SessionOptions GetSessionOptions( | ||
| 290 | - const SpokenLanguageIdentificationConfig &config) { | ||
| 291 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 292 | -} | ||
| 293 | - | ||
| 294 | -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { | ||
| 295 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 296 | -} | ||
| 297 | - | ||
| 298 | -Ort::SessionOptions GetSessionOptions( | ||
| 299 | - const OfflinePunctuationModelConfig &config) { | ||
| 300 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 301 | -} | ||
| 302 | - | ||
| 303 | -Ort::SessionOptions GetSessionOptions( | ||
| 304 | - const OnlinePunctuationModelConfig &config) { | ||
| 305 | - return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 306 | -} | ||
| 307 | - | ||
| 308 | } // namespace sherpa_onnx | 270 | } // namespace sherpa_onnx |
| @@ -8,53 +8,28 @@ | @@ -8,53 +8,28 @@ | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | 9 | ||
| 10 | #include "onnxruntime_cxx_api.h" // NOLINT | 10 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 11 | -#include "sherpa-onnx/csrc/audio-tagging-model-config.h" | ||
| 12 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 11 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 13 | -#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 14 | -#include "sherpa-onnx/csrc/offline-punctuation-model-config.h" | ||
| 15 | -#include "sherpa-onnx/csrc/online-punctuation-model-config.h" | ||
| 16 | #include "sherpa-onnx/csrc/online-lm-config.h" | 12 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 17 | #include "sherpa-onnx/csrc/online-model-config.h" | 13 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 18 | -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" | ||
| 19 | -#include "sherpa-onnx/csrc/spoken-language-identification.h" | ||
| 20 | -#include "sherpa-onnx/csrc/vad-model-config.h" | ||
| 21 | - | ||
| 22 | -#if SHERPA_ONNX_ENABLE_TTS | ||
| 23 | -#include "sherpa-onnx/csrc/offline-tts-model-config.h" | ||
| 24 | -#endif | ||
| 25 | 14 | ||
| 26 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| 27 | 16 | ||
| 28 | -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); | ||
| 29 | - | ||
| 30 | -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, | ||
| 31 | - const std::string &model_type); | ||
| 32 | - | ||
| 33 | -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); | 17 | +Ort::SessionOptions GetSessionOptionsImpl( |
| 18 | + int32_t num_threads, const std::string &provider_str, | ||
| 19 | + const ProviderConfig *provider_config = nullptr); | ||
| 34 | 20 | ||
| 35 | Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); | 21 | Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); |
| 36 | - | ||
| 37 | Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); | 22 | Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); |
| 38 | 23 | ||
| 39 | -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); | ||
| 40 | - | ||
| 41 | -#if SHERPA_ONNX_ENABLE_TTS | ||
| 42 | -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); | ||
| 43 | -#endif | ||
| 44 | - | ||
| 45 | -Ort::SessionOptions GetSessionOptions( | ||
| 46 | - const SpeakerEmbeddingExtractorConfig &config); | ||
| 47 | - | ||
| 48 | -Ort::SessionOptions GetSessionOptions( | ||
| 49 | - const SpokenLanguageIdentificationConfig &config); | ||
| 50 | - | ||
| 51 | -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); | 24 | +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); |
| 52 | 25 | ||
| 53 | -Ort::SessionOptions GetSessionOptions( | ||
| 54 | - const OfflinePunctuationModelConfig &config); | 26 | +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, |
| 27 | + const std::string &model_type); | ||
| 55 | 28 | ||
| 56 | -Ort::SessionOptions GetSessionOptions( | ||
| 57 | - const OnlinePunctuationModelConfig &config); | 29 | +template <typename T> |
| 30 | +Ort::SessionOptions GetSessionOptions(const T &config) { | ||
| 31 | + return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 32 | +} | ||
| 58 | 33 | ||
| 59 | } // namespace sherpa_onnx | 34 | } // namespace sherpa_onnx |
| 60 | 35 |
| 1 | +// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
| 8 | + | ||
// Progress callback passed to OfflineSpeakerDiarization::Process().
//
// @param processed_chunks Number of chunks processed so far.
// @param num_chunks Total number of chunks.
// @param arg Unused user data pointer.
// @return Always 0; the return value is currently ignored by the caller.
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
                                void * /*arg*/) {
  // Guard against division by zero if there is nothing to process.
  if (num_chunks > 0) {
    float progress = 100.0 * processed_chunks / num_chunks;
    fprintf(stderr, "progress %.2f%%\n", progress);
  }

  // the return value is currently ignored
  return 0;
}
| 17 | + | ||
// Command-line demo: runs offline speaker diarization on a single wave
// file and prints one line per diarized segment, followed by timing
// statistics (duration, elapsed time, real-time factor) on stderr.
int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Offline/Non-streaming speaker diarization with sherpa-onnx
Usage example:

Step 1: Download a speaker segmentation model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

Step 2: Download a speaker embedding extractor model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

Step 3. Download test wave files

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example

  wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

Step 4. Build sherpa-onnx

Step 5. Run it

  ./bin/sherpa-onnx-offline-speaker-diarization \
    --clustering.num-clusters=4 \
    --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
    --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
    ./0-four-speakers-zh.wav

Since we know that there are four speakers in the test wave file, we use
--clustering.num-clusters=4 in the above example.

If we don't know number of speakers in the given wave file, we can use
the argument --clustering.cluster-threshold. The following is an example:

  ./bin/sherpa-onnx-offline-speaker-diarization \
    --clustering.cluster-threshold=0.90 \
    --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
    --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
    ./0-four-speakers-zh.wav

A larger threshold leads to few clusters, i.e., few speakers;
a smaller threshold leads to more clusters, i.e., more speakers
  )usage";

  // Parse command-line options into the diarization config.
  sherpa_onnx::OfflineSpeakerDiarizationConfig config;
  sherpa_onnx::ParseOptions po(kUsageMessage);
  config.Register(&po);
  po.Read(argc, argv);

  std::cout << config.ToString() << "\n";

  if (!config.Validate()) {
    po.PrintUsage();
    std::cerr << "Errors in config!\n";
    return -1;
  }

  // Exactly one positional argument: the input wave file.
  if (po.NumArgs() != 1) {
    std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
    po.PrintUsage();
    return -1;
  }

  sherpa_onnx::OfflineSpeakerDiarization sd(config);

  std::cout << "Started\n";
  // Timing includes wave reading and processing.
  const auto begin = std::chrono::steady_clock::now();
  const std::string wav_filename = po.GetArg(1);
  int32_t sample_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
  if (!is_ok) {
    std::cerr << "Failed to read " << wav_filename.c_str() << "\n";
    return -1;
  }

  // No resampling is done here; the wave file must already match the
  // sample rate expected by the models.
  if (sample_rate != sd.SampleRate()) {
    std::cerr << "Expect sample rate " << sd.SampleRate()
              << ". Given: " << sample_rate << "\n";
    return -1;
  }

  float duration = samples.size() / static_cast<float>(sample_rate);

  // Run diarization and sort the resulting segments chronologically.
  auto result =
      sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
          .SortByStartTime();

  for (const auto &r : result) {
    std::cout << r.ToString() << "\n";
  }

  const auto end = std::chrono::steady_clock::now();
  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  fprintf(stderr, "Duration : %.3f s\n", duration);
  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  // RTF < 1 means faster than real time.
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}
| @@ -9,14 +9,15 @@ | @@ -9,14 +9,15 @@ | ||
| 9 | #include "sherpa-onnx/csrc/parse-options.h" | 9 | #include "sherpa-onnx/csrc/parse-options.h" |
| 10 | #include "sherpa-onnx/csrc/wave-writer.h" | 10 | #include "sherpa-onnx/csrc/wave-writer.h" |
| 11 | 11 | ||
| 12 | -int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) { | 12 | +static int32_t AudioCallback(const float * /*samples*/, int32_t n, |
| 13 | + float progress) { | ||
| 13 | printf("sample=%d, progress=%f\n", n, progress); | 14 | printf("sample=%d, progress=%f\n", n, progress); |
| 14 | return 1; | 15 | return 1; |
| 15 | } | 16 | } |
| 16 | 17 | ||
| 17 | int main(int32_t argc, char *argv[]) { | 18 | int main(int32_t argc, char *argv[]) { |
| 18 | const char *kUsageMessage = R"usage( | 19 | const char *kUsageMessage = R"usage( |
| 19 | -Offline text-to-speech with sherpa-onnx | 20 | +Offline/Non-streaming text-to-speech with sherpa-onnx |
| 20 | 21 | ||
| 21 | Usage example: | 22 | Usage example: |
| 22 | 23 | ||
| @@ -79,7 +80,7 @@ or details. | @@ -79,7 +80,7 @@ or details. | ||
| 79 | sherpa_onnx::OfflineTts tts(config); | 80 | sherpa_onnx::OfflineTts tts(config); |
| 80 | 81 | ||
| 81 | const auto begin = std::chrono::steady_clock::now(); | 82 | const auto begin = std::chrono::steady_clock::now(); |
| 82 | - auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback); | 83 | + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback); |
| 83 | const auto end = std::chrono::steady_clock::now(); | 84 | const auto end = std::chrono::steady_clock::now(); |
| 84 | 85 | ||
| 85 | if (audio.samples.empty()) { | 86 | if (audio.samples.empty()) { |
| @@ -19,7 +19,7 @@ The input text can contain English words. | @@ -19,7 +19,7 @@ The input text can contain English words. | ||
| 19 | Usage: | 19 | Usage: |
| 20 | 20 | ||
| 21 | Please download the model from: | 21 | Please download the model from: |
| 22 | -https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 | 22 | +https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 |
| 23 | 23 | ||
| 24 | ./bin/Release/sherpa-onnx-online-punctuation \ | 24 | ./bin/Release/sherpa-onnx-online-punctuation \ |
| 25 | --cnn-bilstm=/path/to/model.onnx \ | 25 | --cnn-bilstm=/path/to/model.onnx \ |
| @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { | @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { | ||
| 26 | 26 | ||
| 27 | bool SpeakerEmbeddingExtractorConfig::Validate() const { | 27 | bool SpeakerEmbeddingExtractorConfig::Validate() const { |
| 28 | if (model.empty()) { | 28 | if (model.empty()) { |
| 29 | - SHERPA_ONNX_LOGE("Please provide --model"); | 29 | + SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model"); |
| 30 | return false; | 30 | return false; |
| 31 | } | 31 | } |
| 32 | 32 | ||
| 33 | if (!FileExists(model)) { | 33 | if (!FileExists(model)) { |
| 34 | - SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist", | 34 | + SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist", |
| 35 | model.c_str()); | 35 | model.c_str()); |
| 36 | return false; | 36 | return false; |
| 37 | } | 37 | } |
-
请 注册 或 登录 后发表评论