Fangjun Kuang
Committed by GitHub

C++ API for speaker diarization (#1396)

正在显示 39 个修改的文件 包含 1647 行增加103 行删除
#!/usr/bin/env bash
# Test script for the sherpa-onnx offline speaker diarization binary.
#
# Expects the environment variable EXE to hold the name (or full path) of the
# sherpa-onnx-offline-speaker-diarization executable; it must be on PATH if
# only a name is given.

set -ex

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

# Fail early (set -e) if the executable cannot be located.
# command -v is POSIX; which is not. Quote $EXE to avoid word splitting.
command -v "$EXE"

# Speaker segmentation model (pyannote).
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

# Speaker embedding model.
# NOTE: "recongition" is the actual (misspelled) release tag name upstream,
# so the URL must stay as-is.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

# Test wave containing four speakers.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

log "specify number of clusters"
"$EXE" \
  --clustering.num-clusters=4 \
  --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
  --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
  ./0-four-speakers-zh.wav

log "specify threshold for clustering"

"$EXE" \
  --clustering.cluster-threshold=0.90 \
  --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
  --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
  ./0-four-speakers-zh.wav

# Clean up downloaded artifacts. ./ anchors the globs so filenames can never
# be mistaken for options.
rm -rf sherpa-onnx-pyannote-*
rm -fv ./*.onnx
rm -fv ./*.wav
@@ -29,7 +29,7 @@ jobs: @@ -29,7 +29,7 @@ jobs:
29 - name: Install pyannote 29 - name: Install pyannote
30 shell: bash 30 shell: bash
31 run: | 31 run: |
32 - pip install pyannote.audio onnx onnxruntime 32 + pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3
33 33
34 - name: Run 34 - name: Run
35 shell: bash 35 shell: bash
@@ -18,6 +18,7 @@ on: @@ -18,6 +18,7 @@ on:
18 - '.github/scripts/test-audio-tagging.sh' 18 - '.github/scripts/test-audio-tagging.sh'
19 - '.github/scripts/test-offline-punctuation.sh' 19 - '.github/scripts/test-offline-punctuation.sh'
20 - '.github/scripts/test-online-punctuation.sh' 20 - '.github/scripts/test-online-punctuation.sh'
  21 + - '.github/scripts/test-speaker-diarization.sh'
21 - 'CMakeLists.txt' 22 - 'CMakeLists.txt'
22 - 'cmake/**' 23 - 'cmake/**'
23 - 'sherpa-onnx/csrc/*' 24 - 'sherpa-onnx/csrc/*'
@@ -38,6 +39,7 @@ on: @@ -38,6 +39,7 @@ on:
38 - '.github/scripts/test-audio-tagging.sh' 39 - '.github/scripts/test-audio-tagging.sh'
39 - '.github/scripts/test-offline-punctuation.sh' 40 - '.github/scripts/test-offline-punctuation.sh'
40 - '.github/scripts/test-online-punctuation.sh' 41 - '.github/scripts/test-online-punctuation.sh'
  42 + - '.github/scripts/test-speaker-diarization.sh'
41 - 'CMakeLists.txt' 43 - 'CMakeLists.txt'
42 - 'cmake/**' 44 - 'cmake/**'
43 - 'sherpa-onnx/csrc/*' 45 - 'sherpa-onnx/csrc/*'
@@ -143,6 +145,15 @@ jobs: @@ -143,6 +145,15 @@ jobs:
143 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 145 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
144 path: install/* 146 path: install/*
145 147
  148 + - name: Test offline speaker diarization
  149 + shell: bash
  150 + run: |
  151 + du -h -d1 .
  152 + export PATH=$PWD/build/bin:$PATH
  153 + export EXE=sherpa-onnx-offline-speaker-diarization
  154 +
  155 + .github/scripts/test-speaker-diarization.sh
  156 +
146 - name: Test offline transducer 157 - name: Test offline transducer
147 shell: bash 158 shell: bash
148 run: | 159 run: |
@@ -18,6 +18,7 @@ on: @@ -18,6 +18,7 @@ on:
18 - '.github/scripts/test-audio-tagging.sh' 18 - '.github/scripts/test-audio-tagging.sh'
19 - '.github/scripts/test-offline-punctuation.sh' 19 - '.github/scripts/test-offline-punctuation.sh'
20 - '.github/scripts/test-online-punctuation.sh' 20 - '.github/scripts/test-online-punctuation.sh'
  21 + - '.github/scripts/test-speaker-diarization.sh'
21 - 'CMakeLists.txt' 22 - 'CMakeLists.txt'
22 - 'cmake/**' 23 - 'cmake/**'
23 - 'sherpa-onnx/csrc/*' 24 - 'sherpa-onnx/csrc/*'
@@ -37,6 +38,7 @@ on: @@ -37,6 +38,7 @@ on:
37 - '.github/scripts/test-audio-tagging.sh' 38 - '.github/scripts/test-audio-tagging.sh'
38 - '.github/scripts/test-offline-punctuation.sh' 39 - '.github/scripts/test-offline-punctuation.sh'
39 - '.github/scripts/test-online-punctuation.sh' 40 - '.github/scripts/test-online-punctuation.sh'
  41 + - '.github/scripts/test-speaker-diarization.sh'
40 - 'CMakeLists.txt' 42 - 'CMakeLists.txt'
41 - 'cmake/**' 43 - 'cmake/**'
42 - 'sherpa-onnx/csrc/*' 44 - 'sherpa-onnx/csrc/*'
@@ -115,6 +117,15 @@ jobs: @@ -115,6 +117,15 @@ jobs:
115 otool -L build/bin/sherpa-onnx 117 otool -L build/bin/sherpa-onnx
116 otool -l build/bin/sherpa-onnx 118 otool -l build/bin/sherpa-onnx
117 119
  120 + - name: Test offline speaker diarization
  121 + shell: bash
  122 + run: |
  123 + du -h -d1 .
  124 + export PATH=$PWD/build/bin:$PATH
  125 + export EXE=sherpa-onnx-offline-speaker-diarization
  126 +
  127 + .github/scripts/test-speaker-diarization.sh
  128 +
118 - name: Test offline transducer 129 - name: Test offline transducer
119 shell: bash 130 shell: bash
120 run: | 131 run: |
@@ -67,7 +67,7 @@ jobs: @@ -67,7 +67,7 @@ jobs:
67 curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin 67 curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
68 68
69 test_wavs=( 69 test_wavs=(
70 - 0-two-speakers-zh.wav 70 + 0-four-speakers-zh.wav
71 1-two-speakers-en.wav 71 1-two-speakers-en.wav
72 2-two-speakers-en.wav 72 2-two-speakers-en.wav
73 3-two-speakers-en.wav 73 3-two-speakers-en.wav
@@ -17,6 +17,7 @@ on: @@ -17,6 +17,7 @@ on:
17 - '.github/scripts/test-audio-tagging.sh' 17 - '.github/scripts/test-audio-tagging.sh'
18 - '.github/scripts/test-offline-punctuation.sh' 18 - '.github/scripts/test-offline-punctuation.sh'
19 - '.github/scripts/test-online-punctuation.sh' 19 - '.github/scripts/test-online-punctuation.sh'
  20 + - '.github/scripts/test-speaker-diarization.sh'
20 - 'CMakeLists.txt' 21 - 'CMakeLists.txt'
21 - 'cmake/**' 22 - 'cmake/**'
22 - 'sherpa-onnx/csrc/*' 23 - 'sherpa-onnx/csrc/*'
@@ -34,6 +35,7 @@ on: @@ -34,6 +35,7 @@ on:
34 - '.github/scripts/test-audio-tagging.sh' 35 - '.github/scripts/test-audio-tagging.sh'
35 - '.github/scripts/test-offline-punctuation.sh' 36 - '.github/scripts/test-offline-punctuation.sh'
36 - '.github/scripts/test-online-punctuation.sh' 37 - '.github/scripts/test-online-punctuation.sh'
  38 + - '.github/scripts/test-speaker-diarization.sh'
37 - 'CMakeLists.txt' 39 - 'CMakeLists.txt'
38 - 'cmake/**' 40 - 'cmake/**'
39 - 'sherpa-onnx/csrc/*' 41 - 'sherpa-onnx/csrc/*'
@@ -87,6 +89,15 @@ jobs: @@ -87,6 +89,15 @@ jobs:
87 name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }} 89 name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
88 path: build/install/* 90 path: build/install/*
89 91
  92 + - name: Test offline speaker diarization
  93 + shell: bash
  94 + run: |
  95 + du -h -d1 .
  96 + export PATH=$PWD/build/bin:$PATH
  97 + export EXE=sherpa-onnx-offline-speaker-diarization.exe
  98 +
  99 + .github/scripts/test-speaker-diarization.sh
  100 +
90 - name: Test online punctuation 101 - name: Test online punctuation
91 shell: bash 102 shell: bash
92 run: | 103 run: |
@@ -17,6 +17,7 @@ on: @@ -17,6 +17,7 @@ on:
17 - '.github/scripts/test-audio-tagging.sh' 17 - '.github/scripts/test-audio-tagging.sh'
18 - '.github/scripts/test-offline-punctuation.sh' 18 - '.github/scripts/test-offline-punctuation.sh'
19 - '.github/scripts/test-online-punctuation.sh' 19 - '.github/scripts/test-online-punctuation.sh'
  20 + - '.github/scripts/test-speaker-diarization.sh'
20 - 'CMakeLists.txt' 21 - 'CMakeLists.txt'
21 - 'cmake/**' 22 - 'cmake/**'
22 - 'sherpa-onnx/csrc/*' 23 - 'sherpa-onnx/csrc/*'
@@ -34,6 +35,7 @@ on: @@ -34,6 +35,7 @@ on:
34 - '.github/scripts/test-audio-tagging.sh' 35 - '.github/scripts/test-audio-tagging.sh'
35 - '.github/scripts/test-offline-punctuation.sh' 36 - '.github/scripts/test-offline-punctuation.sh'
36 - '.github/scripts/test-online-punctuation.sh' 37 - '.github/scripts/test-online-punctuation.sh'
  38 + - '.github/scripts/test-speaker-diarization.sh'
37 - 'CMakeLists.txt' 39 - 'CMakeLists.txt'
38 - 'cmake/**' 40 - 'cmake/**'
39 - 'sherpa-onnx/csrc/*' 41 - 'sherpa-onnx/csrc/*'
@@ -87,6 +89,15 @@ jobs: @@ -87,6 +89,15 @@ jobs:
87 name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }} 89 name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
88 path: build/install/* 90 path: build/install/*
89 91
  92 + - name: Test offline speaker diarization
  93 + shell: bash
  94 + run: |
  95 + du -h -d1 .
  96 + export PATH=$PWD/build/bin:$PATH
  97 + export EXE=sherpa-onnx-offline-speaker-diarization.exe
  98 +
  99 + .github/scripts/test-speaker-diarization.sh
  100 +
90 - name: Test online punctuation 101 - name: Test online punctuation
91 shell: bash 102 shell: bash
92 run: | 103 run: |
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
36 fprintf(stderr, "Memory error\n"); 36 fprintf(stderr, "Memory error\n");
37 return -1; 37 return -1;
38 } 38 }
39 - size_t read_bytes = fread(*buffer_out, 1, size, file); 39 + size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
40 if (read_bytes != size) { 40 if (read_bytes != size) {
41 printf("Errors occured in reading the file %s\n", filename); 41 printf("Errors occured in reading the file %s\n", filename);
42 free((void *)*buffer_out); 42 free((void *)*buffer_out);
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
36 fprintf(stderr, "Memory error\n"); 36 fprintf(stderr, "Memory error\n");
37 return -1; 37 return -1;
38 } 38 }
39 - size_t read_bytes = fread(*buffer_out, 1, size, file); 39 + size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
40 if (read_bytes != size) { 40 if (read_bytes != size) {
41 printf("Errors occured in reading the file %s\n", filename); 41 printf("Errors occured in reading the file %s\n", filename);
42 free((void *)*buffer_out); 42 free((void *)*buffer_out);
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
36 fprintf(stderr, "Memory error\n"); 36 fprintf(stderr, "Memory error\n");
37 return -1; 37 return -1;
38 } 38 }
39 - size_t read_bytes = fread(*buffer_out, 1, size, file); 39 + size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
40 if (read_bytes != size) { 40 if (read_bytes != size) {
41 printf("Errors occured in reading the file %s\n", filename); 41 printf("Errors occured in reading the file %s\n", filename);
42 free((void *)*buffer_out); 42 free((void *)*buffer_out);
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) { @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
36 fprintf(stderr, "Memory error\n"); 36 fprintf(stderr, "Memory error\n");
37 return -1; 37 return -1;
38 } 38 }
39 - size_t read_bytes = fread(*buffer_out, 1, size, file); 39 + size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
40 if (read_bytes != size) { 40 if (read_bytes != size) {
41 printf("Errors occured in reading the file %s\n", filename); 41 printf("Errors occured in reading the file %s\n", filename);
42 free((void *)*buffer_out); 42 free((void *)*buffer_out);
@@ -55,6 +55,7 @@ def get_binaries(): @@ -55,6 +55,7 @@ def get_binaries():
55 "sherpa-onnx-offline-audio-tagging", 55 "sherpa-onnx-offline-audio-tagging",
56 "sherpa-onnx-offline-language-identification", 56 "sherpa-onnx-offline-language-identification",
57 "sherpa-onnx-offline-punctuation", 57 "sherpa-onnx-offline-punctuation",
  58 + "sherpa-onnx-offline-speaker-diarization",
58 "sherpa-onnx-offline-tts", 59 "sherpa-onnx-offline-tts",
59 "sherpa-onnx-offline-tts-play", 60 "sherpa-onnx-offline-tts-play",
60 "sherpa-onnx-offline-websocket-server", 61 "sherpa-onnx-offline-websocket-server",
@@ -3,12 +3,9 @@ @@ -3,12 +3,9 @@
3 Please download test wave files from 3 Please download test wave files from
4 https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models 4 https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
5 5
6 -## 0-two-speakers-zh.wav 6 +## 0-four-speakers-zh.wav
7 7
8 -This file is from  
9 -https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0  
10 -  
11 -Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`. 8 +It was recorded by @csukuangfj
12 9
13 ## 1-two-speakers-en.wav 10 ## 1-two-speakers-en.wav
14 11
@@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav` @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav`
40 37
41 38
42 ```bash 39 ```bash
43 -sox ML16091-Audio.mp3 3-two-speakers-en.wav 40 +sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav
44 ``` 41 ```
@@ -72,7 +72,7 @@ def main(): @@ -72,7 +72,7 @@ def main():
72 model.receptive_field.duration * 16000 72 model.receptive_field.duration * 16000
73 ) 73 )
74 74
75 - opset_version = 18 75 + opset_version = 13
76 76
77 filename = "model.onnx" 77 filename = "model.onnx"
78 torch.onnx.export( 78 torch.onnx.export(
@@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
164 list(APPEND sources 164 list(APPEND sources
165 fast-clustering-config.cc 165 fast-clustering-config.cc
166 fast-clustering.cc 166 fast-clustering.cc
  167 + offline-speaker-diarization-impl.cc
  168 + offline-speaker-diarization-result.cc
  169 + offline-speaker-diarization.cc
  170 + offline-speaker-segmentation-model-config.cc
  171 + offline-speaker-segmentation-pyannote-model-config.cc
  172 + offline-speaker-segmentation-pyannote-model.cc
167 ) 173 )
168 endif() 174 endif()
169 175
@@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
260 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) 266 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
261 endif() 267 endif()
262 268
  269 + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  270 + add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc)
  271 + endif()
  272 +
263 set(main_exes 273 set(main_exes
264 sherpa-onnx 274 sherpa-onnx
265 sherpa-onnx-keyword-spotter 275 sherpa-onnx-keyword-spotter
@@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY)
276 ) 286 )
277 endif() 287 endif()
278 288
  289 + if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
  290 + list(APPEND main_exes
  291 + sherpa-onnx-offline-speaker-diarization
  292 + )
  293 + endif()
  294 +
279 foreach(exe IN LISTS main_exes) 295 foreach(exe IN LISTS main_exes)
280 target_link_libraries(${exe} sherpa-onnx-core) 296 target_link_libraries(${exe} sherpa-onnx-core)
281 endforeach() 297 endforeach()
@@ -21,16 +21,14 @@ std::string FastClusteringConfig::ToString() const { @@ -21,16 +21,14 @@ std::string FastClusteringConfig::ToString() const {
21 } 21 }
22 22
23 void FastClusteringConfig::Register(ParseOptions *po) { 23 void FastClusteringConfig::Register(ParseOptions *po) {
24 - std::string prefix = "ctc";  
25 - ParseOptions p(prefix, po);  
26 -  
27 - p.Register("num-clusters", &num_clusters,  
28 - "Number of cluster. If greater than 0, then --cluster-thresold is " 24 + po->Register(
  25 + "num-clusters", &num_clusters,
  26 + "Number of cluster. If greater than 0, then cluster threshold is "
29 "ignored. Please provide it if you know the actual number of " 27 "ignored. Please provide it if you know the actual number of "
30 "clusters in advance."); 28 "clusters in advance.");
31 29
32 - p.Register("cluster-threshold", &threshold,  
33 - "If --num-clusters is not specified, then it specifies the " 30 + po->Register("cluster-threshold", &threshold,
  31 + "If num_clusters is not specified, then it specifies the "
34 "distance threshold for clustering. smaller value -> more " 32 "distance threshold for clustering. smaller value -> more "
35 "clusters. larger value -> fewer clusters"); 33 "clusters. larger value -> fewer clusters");
36 } 34 }
@@ -5,6 +5,7 @@ @@ -5,6 +5,7 @@
5 #ifndef SHERPA_ONNX_CSRC_MACROS_H_ 5 #ifndef SHERPA_ONNX_CSRC_MACROS_H_
6 #define SHERPA_ONNX_CSRC_MACROS_H_ 6 #define SHERPA_ONNX_CSRC_MACROS_H_
7 #include <stdio.h> 7 #include <stdio.h>
  8 +#include <stdlib.h>
8 9
9 #if __ANDROID_API__ >= 8 10 #if __ANDROID_API__ >= 8
10 #include "android/log.h" 11 #include "android/log.h"
@@ -169,4 +170,6 @@ @@ -169,4 +170,6 @@
169 } \ 170 } \
170 } while (0) 171 } while (0)
171 172
  173 +#define SHERPA_ONNX_EXIT(code) exit(code)
  174 +
172 #endif // SHERPA_ONNX_CSRC_MACROS_H_ 175 #endif // SHERPA_ONNX_CSRC_MACROS_H_
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 #include <utility> 9 #include <utility>
10 10
11 #include "sherpa-onnx/csrc/macros.h" 11 #include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/onnx-utils.h"
12 #include "sherpa-onnx/csrc/session.h" 13 #include "sherpa-onnx/csrc/session.h"
13 #include "sherpa-onnx/csrc/text-utils.h" 14 #include "sherpa-onnx/csrc/text-utils.h"
14 15
  1 +// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
  6 +
  7 +#include <memory>
  8 +
  9 +#include "sherpa-onnx/csrc/macros.h"
  10 +#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +std::unique_ptr<OfflineSpeakerDiarizationImpl>
  15 +OfflineSpeakerDiarizationImpl::Create(
  16 + const OfflineSpeakerDiarizationConfig &config) {
  17 + if (!config.segmentation.pyannote.model.empty()) {
  18 + return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
  19 + }
  20 +
  21 + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");
  22 +
  23 + return nullptr;
  24 +}
  25 +
  26 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speaker-diarization-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
  7 +
  8 +#include <functional>
  9 +#include <memory>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace sherpa_onnx {

// Abstract interface for offline (non-streaming) speaker diarization.
// Concrete backends (currently only the pyannote-based one) are selected
// at runtime via Create() based on the given config.
class OfflineSpeakerDiarizationImpl {
 public:
  // Factory: returns a concrete implementation chosen from `config`,
  // or nullptr when no segmentation model is specified.
  static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
      const OfflineSpeakerDiarizationConfig &config);

  virtual ~OfflineSpeakerDiarizationImpl() = default;

  // Sample rate (in Hz) the input audio must have.
  virtual int32_t SampleRate() const = 0;

  // Runs diarization on `n` mono float samples pointed to by `audio`.
  // `callback`, if non-null, is invoked to report progress and receives
  // `callback_arg` unchanged.
  virtual OfflineSpeakerDiarizationResult Process(
      const float *audio, int32_t n,
      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
      void *callback_arg = nullptr) const = 0;
};

}  // namespace sherpa_onnx
  30 +
  31 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
  6 +
  7 +#include <algorithm>
  8 +#include <unordered_map>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "Eigen/Dense"
  13 +#include "sherpa-onnx/csrc/fast-clustering.h"
  14 +#include "sherpa-onnx/csrc/math.h"
  15 +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
  16 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
  17 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
namespace {  // NOLINT

// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41
//
// Folds the hash of `v` into *seed (boost-style hash mixing).
template <class T>
inline void hash_combine(std::size_t *seed, const T &v) {  // NOLINT
  std::size_t h = std::hash<T>{}(v);
  std::size_t s = *seed;
  *seed = s ^ (h + 0x9e3779b9 + (s << 6) + (s >> 2));  // NOLINT
}

// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47
//
// Hash functor for std::pair so pairs can key unordered containers.
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2> &pair) const {
    std::size_t seed = 0;
    hash_combine(&seed, pair.first);
    hash_combine(&seed, pair.second);
    return seed;
  }
};

}  // namespace
  41 +
// Row-major dynamic matrices. Row-major is used so that one row (one frame)
// is contiguous and ONNX output buffers can be copied in directly.
using Matrix2D =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

using Matrix2DInt32 =
    Eigen::Matrix<int32_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

using FloatRowVector = Eigen::Matrix<float, 1, Eigen::Dynamic>;
using Int32RowVector = Eigen::Matrix<int32_t, 1, Eigen::Dynamic>;

// Generic integer pair, used as (chunk_index, speaker_index) and as
// (start_sample_index, end_sample_index).
using Int32Pair = std::pair<int32_t, int32_t>;
  52 +
  53 +class OfflineSpeakerDiarizationPyannoteImpl
  54 + : public OfflineSpeakerDiarizationImpl {
  55 + public:
  56 + ~OfflineSpeakerDiarizationPyannoteImpl() override = default;
  57 +
  // Builds the whole pipeline from the config: segmentation model,
  // speaker embedding extractor, and clustering backend.
  // NOTE(review): the init list reads config_, so config_ must be declared
  // before the other members (declarations are outside this chunk — confirm).
  explicit OfflineSpeakerDiarizationPyannoteImpl(
      const OfflineSpeakerDiarizationConfig &config)
      : config_(config),
        segmentation_model_(config_.segmentation),
        embedding_extractor_(config_.embedding),
        clustering_(config_.clustering) {
    Init();
  }
  66 +
  // Sample rate expected by the segmentation model, taken from its metadata.
  int32_t SampleRate() const override {
    const auto &meta_data = segmentation_model_.GetModelMetaData();

    return meta_data.sample_rate;
  }
  72 +
  // Full diarization pipeline:
  //   1. run the segmentation model over sliding-window chunks
  //   2. convert powerset outputs to per-frame multi-label speaker activity
  //   3. extract one embedding per (chunk, speaker) pair
  //   4. cluster the embeddings to merge speaker identities across chunks
  //   5. re-label frames with cluster ids and build the final result
  // Returns an empty result if no audio/speakers are found.
  OfflineSpeakerDiarizationResult Process(
      const float *audio, int32_t n,
      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
      void *callback_arg = nullptr) const override {
    std::vector<Matrix2D> segmentations = RunSpeakerSegmentationModel(audio, n);
    // segmentations[i] is for chunk_i
    // Each matrix is of shape (num_frames, num_powerset_classes)
    if (segmentations.empty()) {
      return {};
    }

    std::vector<Matrix2DInt32> labels;
    labels.reserve(segmentations.size());

    for (const auto &m : segmentations) {
      labels.push_back(ToMultiLabel(m));
    }

    // Raw powerset scores are no longer needed; free them early.
    segmentations.clear();

    // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers)

    // speaker count per frame
    Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels);

    if (speakers_per_frame.maxCoeff() == 0) {
      SHERPA_ONNX_LOGE("No speakers found in the audio samples");
      return {};
    }

    // .first: (chunk, speaker) pairs; .second: their sample-index segments.
    auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
    Matrix2D embeddings =
        ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
                          callback, callback_arg);

    // One cluster label per (chunk, speaker) embedding row.
    std::vector<int32_t> cluster_labels = clustering_.Cluster(
        &embeddings(0, 0), embeddings.rows(), embeddings.cols());

    int32_t max_cluster_index =
        *std::max_element(cluster_labels.begin(), cluster_labels.end());

    auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster(
        chunk_speaker_samples_list_pair.first, cluster_labels);

    auto new_labels =
        ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster);

    Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n);

    Matrix2DInt32 final_labels =
        FinalizeLabels(speaker_count, speakers_per_frame);

    auto result = ComputeResult(final_labels);

    return result;
  }
  129 +
 private:
  // One-time setup invoked by the constructor.
  void Init() { InitPowersetMapping(); }
  132 +
  // see also
  // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68
  //
  // Builds powerset_mapping_, a (num_classes, num_speakers) 0/1 matrix where
  // row c lists the speakers active in powerset class c:
  //   row 0                    -> all zeros (no active speaker)
  //   next num_speakers rows   -> exactly one active speaker
  //   remaining rows           -> each unordered pair of two speakers
  // More than two simultaneous speakers per class is not implemented and
  // aborts the process.
  void InitPowersetMapping() {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t num_classes = meta_data.num_classes;
    int32_t powerset_max_classes = meta_data.powerset_max_classes;
    int32_t num_speakers = meta_data.num_speakers;

    powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers);
    powerset_mapping_.setZero();

    // k walks the class rows; row 0 stays all-zero for "no speaker".
    int32_t k = 1;
    for (int32_t i = 1; i <= powerset_max_classes; ++i) {
      if (i == 1) {
        // Singleton classes: one row per speaker.
        for (int32_t j = 0; j != num_speakers; ++j, ++k) {
          powerset_mapping_(k, j) = 1;
        }
      } else if (i == 2) {
        // Pair classes: one row per unordered speaker pair (j, m), j < m.
        for (int32_t j = 0; j != num_speakers; ++j) {
          for (int32_t m = j + 1; m < num_speakers; ++m, ++k) {
            powerset_mapping_(k, j) = 1;
            powerset_mapping_(k, m) = 1;
          }
        }
      } else {
        SHERPA_ONNX_LOGE(
            "powerset_max_classes = %d is currently not supported!", i);
        SHERPA_ONNX_EXIT(-1);
      }
    }
  }
  164 +
  // Splits the audio into windows of window_size samples, advancing by
  // window_shift, and runs the segmentation model on each window.
  // Returns one (num_frames, num_powerset_classes) matrix per chunk, or an
  // empty vector when n <= 0. Audio shorter than a window — and the final
  // partial window — is zero-padded to window_size.
  std::vector<Matrix2D> RunSpeakerSegmentationModel(const float *audio,
                                                    int32_t n) const {
    std::vector<Matrix2D> ans;

    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;

    if (n <= 0) {
      SHERPA_ONNX_LOGE(
          "number of audio samples is %d (<= 0). Please provide a positive "
          "number",
          n);
      return {};
    }

    if (n <= window_size) {
      // Single (possibly zero-padded) window covers all the audio.
      std::vector<float> buf(window_size);
      // NOTE: buf is zero initialized by default

      std::copy(audio, audio + n, buf.data());

      Matrix2D m = ProcessChunk(buf.data());

      ans.push_back(std::move(m));

      return ans;
    }

    // Number of full windows; a trailing partial window (if any samples
    // remain) becomes one extra zero-padded chunk.
    int32_t num_chunks = (n - window_size) / window_shift + 1;
    bool has_last_chunk = (n - window_size) % window_shift > 0;

    ans.reserve(num_chunks + has_last_chunk);

    const float *p = audio;

    for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) {
      Matrix2D m = ProcessChunk(p);

      ans.push_back(std::move(m));
    }

    if (has_last_chunk) {
      // p already points at the first sample after the last full window.
      std::vector<float> buf(window_size);
      std::copy(p, audio + n, buf.data());

      Matrix2D m = ProcessChunk(buf.data());

      ans.push_back(std::move(m));
    }

    return ans;
  }
  218 +
  // Runs the segmentation model on one window of exactly window_size samples.
  //
  // The input tensor has shape (1, 1, window_size) and wraps the caller's
  // buffer without copying — hence the const_cast; the model only reads it.
  // The output is copied into a (out_shape[1], out_shape[2]) row-major
  // matrix, i.e. (num_frames, num_classes); assumes out_shape[0] == 1.
  Matrix2D ProcessChunk(const float *p) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    std::array<int64_t, 3> shape = {1, 1, window_size};

    Ort::Value x =
        Ort::Value::CreateTensor(memory_info, const_cast<float *>(p),
                                 window_size, shape.data(), shape.size());

    Ort::Value out = segmentation_model_.Forward(std::move(x));
    std::vector<int64_t> out_shape = out.GetTensorTypeAndShapeInfo().GetShape();
    Matrix2D m(out_shape[1], out_shape[2]);
    // Both the tensor and Matrix2D are row-major, so a flat copy suffices.
    std::copy(out.GetTensorData<float>(), out.GetTensorData<float>() + m.size(),
              &m(0, 0));
    return m;
  }
  239 +
  240 + Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const {
  241 + int32_t num_rows = m.rows();
  242 + Matrix2DInt32 ans(num_rows, powerset_mapping_.cols());
  243 +
  244 + std::ptrdiff_t col_id;
  245 +
  246 + for (int32_t i = 0; i != num_rows; ++i) {
  247 + m.row(i).maxCoeff(&col_id);
  248 + ans.row(i) = powerset_mapping_.row(col_id);
  249 + }
  250 +
  251 + return ans;
  252 + }
  253 +
  // See also
  // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122
  //
  // Overlap-adds the per-frame speaker counts of all chunks onto one global
  // frame axis and averages where chunks overlap, rounding to the nearest
  // integer. Returns a row vector with the estimated number of active
  // speakers at each global frame.
  Int32RowVector ComputeSpeakersPerFrame(
      const std::vector<Matrix2DInt32> &labels) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;
    int32_t receptive_field_shift = meta_data.receptive_field_shift;

    int32_t num_chunks = labels.size();

    // Total global frames covered by all chunks (+1 for the end boundary).
    int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
                             receptive_field_shift +
                         1;

    FloatRowVector count(num_frames);
    FloatRowVector weight(num_frames);
    count.setZero();
    weight.setZero();

    for (int32_t i = 0; i != num_chunks; ++i) {
      // First global frame of chunk i; +0.5 rounds to nearest.
      int32_t start =
          static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;

      auto seq = Eigen::seqN(start, labels[i].rows());

      // rowwise().sum() = number of active speakers per frame in this chunk.
      count(seq).array() += labels[i].rowwise().sum().array().cast<float>();

      weight(seq).array() += 1;
    }

    // Divide by overlap count (epsilon guards empty slots); +0.5 rounds.
    return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast<int32_t>();
  }
  287 +
  // ans.first: a list of (chunk_id, speaker_id)
  // ans.second: a list of list of (start_sample_index, end_sample_index)
  //
  // ans.first[i] corresponds to ans.second[i]
  //
  // For each (chunk, speaker) pair active for at least 10 non-overlapped
  // frames, collects the contiguous runs of active frames and maps them to
  // absolute sample ranges in the original audio.
  std::pair<std::vector<Int32Pair>, std::vector<std::vector<Int32Pair>>>
  GetChunkSpeakerSampleIndexes(const std::vector<Matrix2DInt32> &labels) const {
    // Embeddings are extracted only from single-speaker regions, so frames
    // with more than one active speaker are zeroed out first.
    auto new_labels = ExcludeOverlap(labels);

    std::vector<Int32Pair> chunk_speaker_list;
    std::vector<std::vector<Int32Pair>> samples_index_list;

    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;
    int32_t receptive_field_shift = meta_data.receptive_field_shift;
    int32_t num_speakers = meta_data.num_speakers;

    int32_t chunk_index = 0;
    for (const auto &label : new_labels) {
      Matrix2DInt32 tmp = label.transpose();
      // tmp: (num_speakers, num_frames)
      int32_t num_frames = tmp.cols();

      // Absolute sample index at which this chunk starts.
      int32_t sample_offset = chunk_index * window_shift;

      for (int32_t speaker_index = 0; speaker_index != num_speakers;
           ++speaker_index) {
        auto d = tmp.row(speaker_index);
        if (d.sum() < 10) {
          // skip segments less than 10 frames
          continue;
        }

        Int32Pair this_chunk_speaker = {chunk_index, speaker_index};
        std::vector<Int32Pair> this_speaker_samples;

        // Run-length scan over the 0/1 activity row: each maximal run of
        // ones becomes one (start_sample, end_sample) segment.
        bool is_active = false;
        int32_t start_index;  // first frame of the current run; only valid
                              // while is_active is true

        for (int32_t k = 0; k != num_frames; ++k) {
          if (d[k] != 0) {
            if (!is_active) {
              is_active = true;
              start_index = k;
            }
          } else if (is_active) {
            is_active = false;

            // Map frame index -> sample index proportionally within the
            // window, then shift by the chunk's absolute offset.
            int32_t start_samples =
                static_cast<float>(start_index) / num_frames * window_size +
                sample_offset;
            int32_t end_samples =
                static_cast<float>(k) / num_frames * window_size +
                sample_offset;

            this_speaker_samples.emplace_back(start_samples, end_samples);
          }
        }

        if (is_active) {
          // Close a run that extends to the end of the chunk.
          int32_t start_samples =
              static_cast<float>(start_index) / num_frames * window_size +
              sample_offset;
          int32_t end_samples =
              static_cast<float>(num_frames - 1) / num_frames * window_size +
              sample_offset;
          this_speaker_samples.emplace_back(start_samples, end_samples);
        }

        chunk_speaker_list.push_back(std::move(this_chunk_speaker));
        samples_index_list.push_back(std::move(this_speaker_samples));
      }  // for (int32_t speaker_index = 0;
      chunk_index += 1;
    }  // for (const auto &label : new_labels)

    return {chunk_speaker_list, samples_index_list};
  }
  365 +
  // If there are multiple speakers at a frame, then this frame is excluded.
  //
  // For each chunk, any frame whose number of simultaneously active speakers
  // is >= 2 is zeroed out; all other frames are copied unchanged.
  //
  // @param labels labels[i] has 0/1 entries; rows are frames, columns are
  //        speakers (as produced upstream for chunk i).
  // @return A copy of labels with overlapped frames cleared.
  std::vector<Matrix2DInt32> ExcludeOverlap(
      const std::vector<Matrix2DInt32> &labels) const {
    int32_t num_chunks = labels.size();
    std::vector<Matrix2DInt32> ans;
    ans.reserve(num_chunks);

    for (const auto &label : labels) {
      Matrix2DInt32 new_label(label.rows(), label.cols());
      new_label.setZero();
      // v[i] is the number of active speakers at frame i
      Int32RowVector v = label.rowwise().sum();

      for (int32_t i = 0; i != v.cols(); ++i) {
        if (v[i] < 2) {
          // keep frames with at most one active speaker
          new_label.row(i) = label.row(i);
        }
      }

      ans.push_back(std::move(new_label));
    }

    return ans;
  }
  389 +
  /**
   * Compute one speaker embedding per (chunk, speaker) pair by feeding the
   * pair's active sample ranges into the embedding extractor.
   *
   * @param audio Pointer to the whole mono audio stream.
   * @param n Number of samples in audio.
   * @param sample_indexes[i] contains the sample segment start and end indexes
   *                          for the i-th (chunk, speaker) pair
   * @param callback Optional progress callback; invoked after each embedding
   *                 with (num_processed, total, callback_arg).
   * @param callback_arg Passed verbatim to callback.
   * @return Return a matrix of shape (sample_indexes.size(), embedding_dim)
   *         where ans.row[i] contains the embedding for the
   *         i-th (chunk, speaker) pair
   */
  Matrix2D ComputeEmbeddings(
      const float *audio, int32_t n,
      const std::vector<std::vector<Int32Pair>> &sample_indexes,
      OfflineSpeakerDiarizationProgressCallback callback,
      void *callback_arg) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t sample_rate = meta_data.sample_rate;
    Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());

    int32_t k = 0;
    for (const auto &v : sample_indexes) {
      auto stream = embedding_extractor_.CreateStream();
      for (const auto &p : v) {
        // clamp the segment end to the actual number of samples
        int32_t end = (p.second <= n) ? p.second : n;
        int32_t num_samples = end - p.first;

        if (num_samples > 0) {
          stream->AcceptWaveform(sample_rate, audio + p.first, num_samples);
        }
      }

      stream->InputFinished();
      if (!embedding_extractor_.IsReady(stream.get())) {
        SHERPA_ONNX_LOGE(
            "This segment is too short, which should not happen since we have "
            "already filtered short segments");
        SHERPA_ONNX_EXIT(-1);
      }

      std::vector<float> embedding = embedding_extractor_.Compute(stream.get());

      // NOTE(review): assumes Matrix2D is row-major so that &ans(k, 0)
      // addresses a contiguous row -- confirm with the Matrix2D typedef.
      std::copy(embedding.begin(), embedding.end(), &ans(k, 0));

      k += 1;

      if (callback) {
        callback(k, ans.rows(), callback_arg);
      }
    }

    return ans;
  }
  439 +
  440 + std::unordered_map<Int32Pair, int32_t, PairHash> ConvertChunkSpeakerToCluster(
  441 + const std::vector<Int32Pair> &chunk_speaker_pair,
  442 + const std::vector<int32_t> &cluster_labels) const {
  443 + std::unordered_map<Int32Pair, int32_t, PairHash> ans;
  444 +
  445 + int32_t k = 0;
  446 + for (const auto &p : chunk_speaker_pair) {
  447 + ans[p] = cluster_labels[k];
  448 + k += 1;
  449 + }
  450 +
  451 + return ans;
  452 + }
  453 +
  454 + std::vector<Matrix2DInt32> ReLabel(
  455 + const std::vector<Matrix2DInt32> &labels, int32_t max_cluster_index,
  456 + std::unordered_map<Int32Pair, int32_t, PairHash> chunk_speaker_to_cluster)
  457 + const {
  458 + std::vector<Matrix2DInt32> new_labels;
  459 + new_labels.reserve(labels.size());
  460 +
  461 + int32_t chunk_index = 0;
  462 + for (const auto &label : labels) {
  463 + Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1);
  464 + new_label.setZero();
  465 +
  466 + Matrix2DInt32 t = label.transpose();
  467 + // t: (num_speakers, num_frames)
  468 +
  469 + for (int32_t speaker_index = 0; speaker_index != t.rows();
  470 + ++speaker_index) {
  471 + if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) {
  472 + continue;
  473 + }
  474 +
  475 + int32_t new_speaker_index =
  476 + chunk_speaker_to_cluster.at({chunk_index, speaker_index});
  477 +
  478 + for (int32_t k = 0; k != t.cols(); ++k) {
  479 + if (t(speaker_index, k) == 1) {
  480 + new_label(k, new_speaker_index) = 1;
  481 + }
  482 + }
  483 + }
  484 +
  485 + new_labels.push_back(std::move(new_label));
  486 +
  487 + chunk_index += 1;
  488 + }
  489 +
  490 + return new_labels;
  491 + }
  492 +
  // Accumulate the per-chunk 0/1 labels into per-frame vote counts over the
  // whole audio stream.
  //
  // @param labels labels[i] contains the relabeled 0/1 matrix of chunk i.
  // @param num_samples Number of samples in the whole audio stream.
  // @return A matrix of shape (total_num_frames, num_clusters); entry (f, c)
  //         is the number of chunks voting for cluster c at frame f.
  Matrix2DInt32 ComputeSpeakerCount(const std::vector<Matrix2DInt32> &labels,
                                    int32_t num_samples) const {
    const auto &meta_data = segmentation_model_.GetModelMetaData();
    int32_t window_size = meta_data.window_size;
    int32_t window_shift = meta_data.window_shift;
    int32_t receptive_field_shift = meta_data.receptive_field_shift;

    int32_t num_chunks = labels.size();

    // total number of frames covered by all (overlapping) chunks
    int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
                             receptive_field_shift +
                         1;

    Matrix2DInt32 count(num_frames, labels[0].cols());
    count.setZero();

    for (int32_t i = 0; i != num_chunks; ++i) {
      // frame offset of chunk i, rounded to the nearest frame
      int32_t start =
          static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;

      auto seq = Eigen::seqN(start, labels[i].rows());

      count(seq, Eigen::all).array() += labels[i].array();
    }

    // NOTE(review): presumably true when the tail of the audio was processed
    // as an extra padded chunk upstream -- confirm against the chunking code.
    bool has_last_chunk = (num_samples - window_size) % window_shift > 0;

    if (has_last_chunk) {
      return count;
    }

    // no padded tail chunk: trim to frames backed by real samples
    int32_t last_frame = num_samples / receptive_field_shift;
    return count(Eigen::seq(0, last_frame), Eigen::all);
  }
  527 +
  // Turn per-frame vote counts into final 0/1 labels: at frame i, keep the
  // speakers_per_frame[i] clusters with the most votes.
  //
  // @param count Shape (num_frames, num_clusters); vote counts per frame.
  // @param speakers_per_frame speakers_per_frame[i] is the number of
  //        simultaneously active speakers at frame i.
  // @return A 0/1 matrix of the same shape as count.
  Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 &count,
                               const Int32RowVector &speakers_per_frame) const {
    int32_t num_rows = count.rows();
    int32_t num_cols = count.cols();

    Matrix2DInt32 ans(num_rows, num_cols);
    ans.setZero();

    for (int32_t i = 0; i != num_rows; ++i) {
      int32_t k = speakers_per_frame[i];
      if (k == 0) {
        // nobody speaks at this frame
        continue;
      }
      // NOTE(review): assumes Matrix2DInt32 is row-major so &count(i, 0)
      // points at a contiguous row of num_cols entries -- confirm with the
      // Matrix2DInt32 typedef.
      auto top_k = TopkIndex(&count(i, 0), num_cols, k);

      for (int32_t m : top_k) {
        ans(i, m) = 1;
      }
    }

    return ans;
  }
  550 +
  551 + OfflineSpeakerDiarizationResult ComputeResult(
  552 + const Matrix2DInt32 &final_labels) const {
  553 + Matrix2DInt32 final_labels_t = final_labels.transpose();
  554 + int32_t num_speakers = final_labels_t.rows();
  555 + int32_t num_frames = final_labels_t.cols();
  556 +
  557 + const auto &meta_data = segmentation_model_.GetModelMetaData();
  558 + int32_t window_size = meta_data.window_size;
  559 + int32_t window_shift = meta_data.window_shift;
  560 + int32_t receptive_field_shift = meta_data.receptive_field_shift;
  561 + int32_t receptive_field_size = meta_data.receptive_field_size;
  562 + int32_t sample_rate = meta_data.sample_rate;
  563 +
  564 + float scale = static_cast<float>(receptive_field_shift) / sample_rate;
  565 + float scale_offset = 0.5 * receptive_field_size / sample_rate;
  566 +
  567 + OfflineSpeakerDiarizationResult ans;
  568 +
  569 + for (int32_t speaker_index = 0; speaker_index != num_speakers;
  570 + ++speaker_index) {
  571 + std::vector<OfflineSpeakerDiarizationSegment> this_speaker;
  572 +
  573 + bool is_active = final_labels_t(speaker_index, 0) > 0;
  574 + int32_t start_index = is_active ? 0 : -1;
  575 +
  576 + for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) {
  577 + if (is_active) {
  578 + if (final_labels_t(speaker_index, frame_index) == 0) {
  579 + float start_time = start_index * scale + scale_offset;
  580 + float end_time = frame_index * scale + scale_offset;
  581 +
  582 + OfflineSpeakerDiarizationSegment segment(start_time, end_time,
  583 + speaker_index);
  584 + this_speaker.push_back(segment);
  585 +
  586 + is_active = false;
  587 + }
  588 + } else if (final_labels_t(speaker_index, frame_index) == 1) {
  589 + is_active = true;
  590 + start_index = frame_index;
  591 + }
  592 + }
  593 +
  594 + if (is_active) {
  595 + float start_time = start_index * scale + scale_offset;
  596 + float end_time = (num_frames - 1) * scale + scale_offset;
  597 +
  598 + OfflineSpeakerDiarizationSegment segment(start_time, end_time,
  599 + speaker_index);
  600 + this_speaker.push_back(segment);
  601 + }
  602 +
  603 + // merge segments if the gap between them is less than min_duration_off
  604 + MergeSegments(&this_speaker);
  605 +
  606 + for (const auto &seg : this_speaker) {
  607 + if (seg.Duration() > config_.min_duration_on) {
  608 + ans.Add(seg);
  609 + }
  610 + }
  611 + } // for (int32_t speaker_index = 0; speaker_index != num_speakers;
  612 +
  613 + return ans;
  614 + }
  615 +
  616 + void MergeSegments(
  617 + std::vector<OfflineSpeakerDiarizationSegment> *segments) const {
  618 + float min_duration_off = config_.min_duration_off;
  619 + bool changed = true;
  620 + while (changed) {
  621 + changed = false;
  622 + for (int32_t i = 0; i < static_cast<int32_t>(segments->size()) - 1; ++i) {
  623 + auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off);
  624 + if (s) {
  625 + (*segments)[i] = s.value();
  626 + segments->erase(segments->begin() + i + 1);
  627 +
  628 + changed = true;
  629 + break;
  630 + }
  631 + }
  632 + }
  633 + }
  634 +
  635 + private:
  636 + OfflineSpeakerDiarizationConfig config_;
  637 + OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
  638 + SpeakerEmbeddingExtractor embedding_extractor_;
  639 + FastClustering clustering_;
  640 + Matrix2DInt32 powerset_mapping_;
  641 +};
  642 +
  643 +} // namespace sherpa_onnx
  644 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-speaker-diarization-result.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
  6 +
  7 +#include <algorithm>
  8 +#include <sstream>
  9 +#include <string>
  10 +#include <unordered_set>
  11 +#include <utility>
  12 +
  13 +#include "sherpa-onnx/csrc/macros.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment(
  18 + float start, float end, int32_t speaker, const std::string &text /*= {}*/) {
  19 + if (start > end) {
  20 + SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end);
  21 + SHERPA_ONNX_EXIT(-1);
  22 + }
  23 +
  24 + start_ = start;
  25 + end_ = end;
  26 + speaker_ = speaker;
  27 + text_ = text;
  28 +}
  29 +
// Merge two disjoint segments of the same speaker when the gap between them
// is at most `gap` seconds; return std::nullopt otherwise.
std::optional<OfflineSpeakerDiarizationSegment>
OfflineSpeakerDiarizationSegment::Merge(
    const OfflineSpeakerDiarizationSegment &other, float gap) const {
  // Only segments of the same speaker may be merged.
  if (other.speaker_ != speaker_) {
    SHERPA_ONNX_LOGE(
        "The two segments should have the same speaker. this->speaker: %d, "
        "other.speaker: %d",
        speaker_, other.speaker_);
    return std::nullopt;
  }

  // NOTE(review): segments that overlap in time match neither branch and
  // fall through to std::nullopt -- confirm this is intended.
  if (end_ < other.start_ && end_ + gap >= other.start_) {
    // `this` ends before `other` starts
    return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_);
  } else if (other.end_ < start_ && other.end_ + gap >= start_) {
    // `other` ends before `this` starts
    return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_);
  } else {
    return std::nullopt;
  }
}
  49 +
  50 +std::string OfflineSpeakerDiarizationSegment::ToString() const {
  51 + char s[128];
  52 + snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, speaker_);
  53 +
  54 + std::ostringstream os;
  55 + os << s;
  56 +
  57 + if (!text_.empty()) {
  58 + os << " " << text_;
  59 + }
  60 +
  61 + return os.str();
  62 +}
  63 +
  64 +void OfflineSpeakerDiarizationResult::Add(
  65 + const OfflineSpeakerDiarizationSegment &segment) {
  66 + segments_.push_back(segment);
  67 +}
  68 +
  69 +int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const {
  70 + std::unordered_set<int32_t> count;
  71 + for (const auto &s : segments_) {
  72 + count.insert(s.Speaker());
  73 + }
  74 +
  75 + return count.size();
  76 +}
  77 +
  78 +int32_t OfflineSpeakerDiarizationResult::NumSegments() const {
  79 + return segments_.size();
  80 +}
  81 +
  82 +// Return a list of segments sorted by segment.start time
  83 +std::vector<OfflineSpeakerDiarizationSegment>
  84 +OfflineSpeakerDiarizationResult::SortByStartTime() const {
  85 + auto ans = segments_;
  86 + std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) {
  87 + return (a.Start() < b.Start()) ||
  88 + ((a.Start() == b.Start()) && (a.Speaker() < b.Speaker()));
  89 + });
  90 +
  91 + return ans;
  92 +}
  93 +
  94 +std::vector<std::vector<OfflineSpeakerDiarizationSegment>>
  95 +OfflineSpeakerDiarizationResult::SortBySpeaker() const {
  96 + auto tmp = segments_;
  97 + std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) {
  98 + return (a.Speaker() < b.Speaker()) ||
  99 + ((a.Speaker() == b.Speaker()) && (a.Start() < b.Start()));
  100 + });
  101 +
  102 + std::vector<std::vector<OfflineSpeakerDiarizationSegment>> ans(NumSpeakers());
  103 + for (auto &s : tmp) {
  104 + ans[s.Speaker()].push_back(std::move(s));
  105 + }
  106 +
  107 + return ans;
  108 +}
  109 +
  110 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speaker-diarization-result.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
  7 +
  8 +#include <cstdint>
  9 +#include <optional>
  10 +#include <string>
  11 +#include <vector>
  12 +
  13 +namespace sherpa_onnx {
  14 +
// A contiguous time interval attributed to one speaker, optionally carrying
// the speech recognition result of that interval.
class OfflineSpeakerDiarizationSegment {
 public:
  // Requires start <= end; otherwise an error is logged and the process
  // terminates.
  OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker,
                                   const std::string &text = {});

  // If the gap between the two segments is less than the given gap, then we
  // merge them and return a new segment. Otherwise, it returns null.
  // Both segments must belong to the same speaker.
  std::optional<OfflineSpeakerDiarizationSegment> Merge(
      const OfflineSpeakerDiarizationSegment &other, float gap) const;

  float Start() const { return start_; }        // start time in seconds
  float End() const { return end_; }            // end time in seconds
  int32_t Speaker() const { return speaker_; }  // speaker ID, from 0
  const std::string &Text() const { return text_; }
  float Duration() const { return end_ - start_; }  // in seconds

  // e.g. "1.000 -- 2.500 speaker_00 some text"
  std::string ToString() const;

 private:
  float start_;  // in seconds
  float end_;    // in seconds
  int32_t speaker_;  // ID of the speaker, starting from 0
  std::string text_;  // If not empty, it contains the speech recognition result
                      // of this segment
};
  40 +
// A collection of diarization segments with sorted/grouped accessors.
class OfflineSpeakerDiarizationResult {
 public:
  // Add a new segment
  void Add(const OfflineSpeakerDiarizationSegment &segment);

  // Number of distinct speakers contained in this object at this point
  int32_t NumSpeakers() const;

  // Number of segments contained in this object at this point
  int32_t NumSegments() const;

  // Return a list of segments sorted by segment.start time
  std::vector<OfflineSpeakerDiarizationSegment> SortByStartTime() const;

  // ans.size() == NumSpeakers().
  // ans[i] is for speaker_i and is sorted by start time
  std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
      const;

 public:
  // All segments added so far, in insertion order.
  std::vector<OfflineSpeakerDiarizationSegment> segments_;
};
  62 +
  63 +} // namespace sherpa_onnx
  64 +
  65 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
  1 +// sherpa-onnx/csrc/offline-speaker-diarization.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Register all diarization command-line options with the option parser.
// Sub-configurations are registered under the prefixes "segmentation",
// "embedding" and "clustering" (e.g. --segmentation.pyannote-model).
void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) {
  ParseOptions po_segmentation("segmentation", po);
  segmentation.Register(&po_segmentation);

  ParseOptions po_embedding("embedding", po);
  embedding.Register(&po_embedding);

  ParseOptions po_clustering("clustering", po);
  clustering.Register(&po_clustering);

  po->Register("min-duration-on", &min_duration_on,
               "if a segment is less than this value, then it is discarded. "
               "Set it to 0 so that no segment is discarded");

  po->Register("min-duration-off", &min_duration_off,
               "if the gap between to segments of the same speaker is less "
               "than this value, then these two segments are merged into a "
               "single segment. We do it recursively.");
}
  32 +
  33 +bool OfflineSpeakerDiarizationConfig::Validate() const {
  34 + if (!segmentation.Validate()) {
  35 + return false;
  36 + }
  37 +
  38 + if (!embedding.Validate()) {
  39 + return false;
  40 + }
  41 +
  42 + if (!clustering.Validate()) {
  43 + return false;
  44 + }
  45 +
  46 + return true;
  47 +}
  48 +
  49 +std::string OfflineSpeakerDiarizationConfig::ToString() const {
  50 + std::ostringstream os;
  51 +
  52 + os << "OfflineSpeakerDiarizationConfig(";
  53 + os << "segmentation=" << segmentation.ToString() << ", ";
  54 + os << "embedding=" << embedding.ToString() << ", ";
  55 + os << "clustering=" << clustering.ToString() << ", ";
  56 + os << "min_duration_on=" << min_duration_on << ", ";
  57 + os << "min_duration_off=" << min_duration_off << ")";
  58 +
  59 + return os.str();
  60 +}
  61 +
// Construct the public wrapper; dispatches to an implementation selected
// from the config (pimpl idiom).
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
    const OfflineSpeakerDiarizationConfig &config)
    : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}

// Out-of-line so the unique_ptr<Impl> destructor sees the complete type.
OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;

// Expected sample rate of the input audio samples.
int32_t OfflineSpeakerDiarization::SampleRate() const {
  return impl_->SampleRate();
}

// Run diarization on `n` mono samples at SampleRate(); forwards to the
// implementation. The optional callback reports progress.
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
    const float *audio, int32_t n,
    OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
    void *callback_arg /*= nullptr*/) const {
  return impl_->Process(audio, n, callback, callback_arg);
}
  78 +
  79 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speaker-diarization.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
  7 +
  8 +#include <functional>
  9 +#include <memory>
  10 +#include <string>
  11 +
  12 +#include "sherpa-onnx/csrc/fast-clustering-config.h"
  13 +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
  14 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
  15 +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
  16 +
  17 +namespace sherpa_onnx {
  18 +
// Configuration for offline (non-streaming) speaker diarization: a
// segmentation model, a speaker embedding extractor, a clustering method,
// and post-processing thresholds.
struct OfflineSpeakerDiarizationConfig {
  OfflineSpeakerSegmentationModelConfig segmentation;
  SpeakerEmbeddingExtractorConfig embedding;
  FastClusteringConfig clustering;

  // if a segment is less than this value, then it is discarded
  float min_duration_on = 0.3;  // in seconds

  // if the gap between to segments of the same speaker is less than this value,
  // then these two segments are merged into a single segment.
  // We do this recursively.
  float min_duration_off = 0.5;  // in seconds

  OfflineSpeakerDiarizationConfig() = default;

  OfflineSpeakerDiarizationConfig(
      const OfflineSpeakerSegmentationModelConfig &segmentation,
      const SpeakerEmbeddingExtractorConfig &embedding,
      const FastClusteringConfig &clustering)
      : segmentation(segmentation),
        embedding(embedding),
        clustering(clustering) {}

  // Register command-line options for all fields.
  void Register(ParseOptions *po);
  // Return true if all sub-configurations are valid.
  bool Validate() const;
  // Human-readable representation for logging.
  std::string ToString() const;
};
  46 +
  47 +class OfflineSpeakerDiarizationImpl;
  48 +
  49 +using OfflineSpeakerDiarizationProgressCallback = std::function<int32_t(
  50 + int32_t processed_chunks, int32_t num_chunks, void *arg)>;
  51 +
// Public entry point for offline speaker diarization (pimpl wrapper).
class OfflineSpeakerDiarization {
 public:
  explicit OfflineSpeakerDiarization(
      const OfflineSpeakerDiarizationConfig &config);

  ~OfflineSpeakerDiarization();

  // Expected sample rate of the input audio samples
  int32_t SampleRate() const;

  // Diarize `n` mono samples (sampled at SampleRate()).
  // `callback`, if not null, is invoked periodically with
  // (processed_chunks, num_chunks, callback_arg) to report progress.
  OfflineSpeakerDiarizationResult Process(
      const float *audio, int32_t n,
      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
      void *callback_arg = nullptr) const;

 private:
  std::unique_ptr<OfflineSpeakerDiarizationImpl> impl_;
};
  70 +
  71 +} // namespace sherpa_onnx
  72 +
  73 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
  5 +
  6 +#include <sstream>
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/macros.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Register segmentation-model command-line options (pyannote model path,
// plus shared runtime options).
void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) {
  pyannote.Register(po);

  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");

  po->Register("debug", &debug,
               "true to print model information while loading it.");

  po->Register("provider", &provider,
               "Specify a provider to use: cpu, cuda, coreml");
}
  25 +
  26 +bool OfflineSpeakerSegmentationModelConfig::Validate() const {
  27 + if (num_threads < 1) {
  28 + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
  29 + return false;
  30 + }
  31 +
  32 + if (!pyannote.model.empty()) {
  33 + return pyannote.Validate();
  34 + }
  35 +
  36 + if (pyannote.model.empty()) {
  37 + SHERPA_ONNX_LOGE(
  38 + "You have to provide at least one speaker segmentation model");
  39 + return false;
  40 + }
  41 +
  42 + return true;
  43 +}
  44 +
  45 +std::string OfflineSpeakerSegmentationModelConfig::ToString() const {
  46 + std::ostringstream os;
  47 +
  48 + os << "OfflineSpeakerSegmentationModelConfig(";
  49 + os << "pyannote=" << pyannote.ToString() << ", ";
  50 + os << "num_threads=" << num_threads << ", ";
  51 + os << "debug=" << (debug ? "True" : "False") << ", ";
  52 + os << "provider=\"" << provider << "\")";
  53 +
  54 + return os.str();
  55 +}
  56 +
  57 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
// Configuration for the offline speaker segmentation model. Currently only
// pyannote-style segmentation models are supported.
struct OfflineSpeakerSegmentationModelConfig {
  OfflineSpeakerSegmentationPyannoteModelConfig pyannote;

  int32_t num_threads = 1;     // number of onnxruntime intra-op threads
  bool debug = false;          // print model metadata while loading
  std::string provider = "cpu";  // onnxruntime execution provider

  OfflineSpeakerSegmentationModelConfig() = default;

  explicit OfflineSpeakerSegmentationModelConfig(
      const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote,
      int32_t num_threads, bool debug, const std::string &provider)
      : pyannote(pyannote),
        num_threads(num_threads),
        debug(debug),
        provider(provider) {}

  // Register command-line options for all fields.
  void Register(ParseOptions *po);

  // Return true if num_threads > 0 and a valid pyannote model is given.
  bool Validate() const;

  // Human-readable representation for logging.
  std::string ToString() const;
};
  37 +
  38 +} // namespace sherpa_onnx
  39 +
  40 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
  5 +
  6 +#include <sstream>
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
// Register the pyannote model path option (--pyannote-model, possibly
// prefixed by the enclosing config).
void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) {
  po->Register("pyannote-model", &model,
               "Path to model.onnx of the Pyannote segmentation model.");
}
  18 +
  19 +bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const {
  20 + if (!FileExists(model)) {
  21 + SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist",
  22 + model.c_str());
  23 + return false;
  24 + }
  25 +
  26 + return true;
  27 +}
  28 +
  29 +std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const {
  30 + std::ostringstream os;
  31 +
  32 + os << "OfflineSpeakerSegmentationPyannoteModelConfig(";
  33 + os << "model=\"" << model << "\")";
  34 +
  35 + return os.str();
  36 +}
  37 +
  38 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Configuration for a pyannote-style speaker segmentation model.
struct OfflineSpeakerSegmentationPyannoteModelConfig {
  // Path to model.onnx of the Pyannote segmentation model.
  std::string model;

  OfflineSpeakerSegmentationPyannoteModelConfig() = default;

  explicit OfflineSpeakerSegmentationPyannoteModelConfig(
      const std::string &model)
      : model(model) {}

  // Register the --pyannote-model option.
  void Register(ParseOptions *po);
  // Return true if the model file exists.
  bool Validate() const;

  // Human-readable representation for logging.
  std::string ToString() const;
};
  27 +
  28 +} // namespace sherpa_onnx
  29 +
  30 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
  7 +
  8 +#include <cstdint>
  9 +#include <string>
  10 +
  11 +namespace sherpa_onnx {
  12 +
// If you are not sure what each field means, please
// have a look of the Python file in the model directory that
// you have downloaded.
//
// All fields are read from the ONNX model's metadata at load time, except
// window_shift which is derived as 10% of window_size.
struct OfflineSpeakerSegmentationPyannoteModelMetaData {
  int32_t sample_rate = 0;           // expected audio sample rate
  int32_t window_size = 0;           // in samples
  int32_t window_shift = 0;          // in samples
  int32_t receptive_field_size = 0;  // in samples
  int32_t receptive_field_shift = 0;  // in samples
  int32_t num_speakers = 0;           // maximum speakers per chunk
  int32_t powerset_max_classes = 0;   // max simultaneous speakers per class
  int32_t num_classes = 0;            // number of powerset-encoded classes
};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
  6 +
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/onnx-utils.h"
  12 +#include "sherpa-onnx/csrc/session.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
// Private implementation: owns the onnxruntime session for the pyannote
// segmentation model and exposes the metadata read from the model file.
class OfflineSpeakerSegmentationPyannoteModel::Impl {
 public:
  // Load the model file given in config.pyannote.model and create the
  // onnxruntime session.
  explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config_.pyannote.model);
    Init(buf.data(), buf.size());
  }

  // Metadata parsed from the ONNX model in Init().
  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
      const {
    return meta_data_;
  }

  // Run the model on one input tensor and return its first output.
  Ort::Value Forward(Ort::Value x) {
    auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1,
                          output_names_ptr_.data(), output_names_ptr_.size());

    return std::move(out[0]);
  }

 private:
  // Create the session from an in-memory model and read its metadata.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size");

    // window_shift is not stored in the model; use 10% of the window size
    meta_data_.window_shift =
        static_cast<int32_t>(0.1 * meta_data_.window_size);

    SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size,
                               "receptive_field_size");
    SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift,
                               "receptive_field_shift");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers");
    SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes,
                               "powerset_max_classes");
    SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes");
  }

 private:
  OfflineSpeakerSegmentationModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_;
};
  89 +
// Public wrapper: construct the private implementation from the config.
OfflineSpeakerSegmentationPyannoteModel::
    OfflineSpeakerSegmentationPyannoteModel(
        const OfflineSpeakerSegmentationModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

// Out-of-line so the unique_ptr<Impl> destructor sees the complete type.
OfflineSpeakerSegmentationPyannoteModel::
    ~OfflineSpeakerSegmentationPyannoteModel() = default;

// Metadata read from the ONNX model at load time.
const OfflineSpeakerSegmentationPyannoteModelMetaData &
OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const {
  return impl_->GetModelMetaData();
}

// Run the segmentation model on one input tensor; see the header for the
// expected input/output shapes.
Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(
    Ort::Value x) const {
  return impl_->Forward(std::move(x));
}
  107 +
  108 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
  6 +
  7 +#include <memory>
  8 +
  9 +#include "onnxruntime_cxx_api.h" // NOLINT
  10 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
  11 +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
// Pyannote speaker segmentation model (pimpl wrapper around an onnxruntime
// session).
class OfflineSpeakerSegmentationPyannoteModel {
 public:
  explicit OfflineSpeakerSegmentationPyannoteModel(
      const OfflineSpeakerSegmentationModelConfig &config);

  ~OfflineSpeakerSegmentationPyannoteModel();

  // Metadata read from the ONNX model file at load time.
  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
      const;

  /**
   * @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
   * @return Return a float tensor of
   *         shape (batch_size, num_frames, num_speakers). Note that
   *         num_speakers here uses powerset encoding.
   */
  Ort::Value Forward(Ort::Value x) const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
  37 +
  38 +} // namespace sherpa_onnx
  39 +
  40 +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
@@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) { @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) {
61 61
62 bool TensorrtConfig::Validate() const { 62 bool TensorrtConfig::Validate() const {
63 if (trt_max_workspace_size < 0) { 63 if (trt_max_workspace_size < 0) {
64 - SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.",  
65 - trt_max_workspace_size); 64 + std::ostringstream os;
  65 + os << "trt_max_workspace_size: " << trt_max_workspace_size
  66 + << " is not valid.";
  67 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
66 return false; 68 return false;
67 } 69 }
68 if (trt_max_partition_iterations < 0) { 70 if (trt_max_partition_iterations < 0) {
@@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
35 api.ReleaseStatus(status); 35 api.ReleaseStatus(status);
36 } 36 }
37 37
38 -static Ort::SessionOptions GetSessionOptionsImpl( 38 +Ort::SessionOptions GetSessionOptionsImpl(
39 int32_t num_threads, const std::string &provider_str, 39 int32_t num_threads, const std::string &provider_str,
40 - const ProviderConfig *provider_config = nullptr) { 40 + const ProviderConfig *provider_config /*= nullptr*/) {
41 Provider p = StringToProvider(provider_str); 41 Provider p = StringToProvider(provider_str);
42 42
43 Ort::SessionOptions sess_opts; 43 Ort::SessionOptions sess_opts;
@@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
259 &config.provider_config); 259 &config.provider_config);
260 } 260 }
261 261
262 -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {  
263 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
264 -}  
265 -  
266 Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { 262 Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
267 return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); 263 return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
268 } 264 }
@@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
271 return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); 267 return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
272 } 268 }
273 269
274 -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {  
275 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
276 -}  
277 -  
278 -#if SHERPA_ONNX_ENABLE_TTS  
279 -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {  
280 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
281 -}  
282 -#endif  
283 -  
284 -Ort::SessionOptions GetSessionOptions(  
285 - const SpeakerEmbeddingExtractorConfig &config) {  
286 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
287 -}  
288 -  
289 -Ort::SessionOptions GetSessionOptions(  
290 - const SpokenLanguageIdentificationConfig &config) {  
291 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
292 -}  
293 -  
294 -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {  
295 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
296 -}  
297 -  
298 -Ort::SessionOptions GetSessionOptions(  
299 - const OfflinePunctuationModelConfig &config) {  
300 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
301 -}  
302 -  
303 -Ort::SessionOptions GetSessionOptions(  
304 - const OnlinePunctuationModelConfig &config) {  
305 - return GetSessionOptionsImpl(config.num_threads, config.provider);  
306 -}  
307 -  
308 } // namespace sherpa_onnx 270 } // namespace sherpa_onnx
@@ -8,53 +8,28 @@ @@ -8,53 +8,28 @@
8 #include <string> 8 #include <string>
9 9
10 #include "onnxruntime_cxx_api.h" // NOLINT 10 #include "onnxruntime_cxx_api.h" // NOLINT
11 -#include "sherpa-onnx/csrc/audio-tagging-model-config.h"  
12 #include "sherpa-onnx/csrc/offline-lm-config.h" 11 #include "sherpa-onnx/csrc/offline-lm-config.h"
13 -#include "sherpa-onnx/csrc/offline-model-config.h"  
14 -#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"  
15 -#include "sherpa-onnx/csrc/online-punctuation-model-config.h"  
16 #include "sherpa-onnx/csrc/online-lm-config.h" 12 #include "sherpa-onnx/csrc/online-lm-config.h"
17 #include "sherpa-onnx/csrc/online-model-config.h" 13 #include "sherpa-onnx/csrc/online-model-config.h"
18 -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"  
19 -#include "sherpa-onnx/csrc/spoken-language-identification.h"  
20 -#include "sherpa-onnx/csrc/vad-model-config.h"  
21 -  
22 -#if SHERPA_ONNX_ENABLE_TTS  
23 -#include "sherpa-onnx/csrc/offline-tts-model-config.h"  
24 -#endif  
25 14
26 namespace sherpa_onnx { 15 namespace sherpa_onnx {
27 16
28 -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);  
29 -  
30 -Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,  
31 - const std::string &model_type);  
32 -  
33 -Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); 17 +Ort::SessionOptions GetSessionOptionsImpl(
  18 + int32_t num_threads, const std::string &provider_str,
  19 + const ProviderConfig *provider_config = nullptr);
34 20
35 Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); 21 Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
36 -  
37 Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); 22 Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
38 23
39 -Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);  
40 -  
41 -#if SHERPA_ONNX_ENABLE_TTS  
42 -Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);  
43 -#endif  
44 -  
45 -Ort::SessionOptions GetSessionOptions(  
46 - const SpeakerEmbeddingExtractorConfig &config);  
47 -  
48 -Ort::SessionOptions GetSessionOptions(  
49 - const SpokenLanguageIdentificationConfig &config);  
50 -  
51 -Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); 24 +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
52 25
53 -Ort::SessionOptions GetSessionOptions(  
54 - const OfflinePunctuationModelConfig &config); 26 +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
  27 + const std::string &model_type);
55 28
56 -Ort::SessionOptions GetSessionOptions(  
57 - const OnlinePunctuationModelConfig &config); 29 +template <typename T>
  30 +Ort::SessionOptions GetSessionOptions(const T &config) {
  31 + return GetSessionOptionsImpl(config.num_threads, config.provider);
  32 +}
58 33
59 } // namespace sherpa_onnx 34 } // namespace sherpa_onnx
60 35
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
  6 +#include "sherpa-onnx/csrc/parse-options.h"
  7 +#include "sherpa-onnx/csrc/wave-reader.h"
  8 +
  9 +static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
  10 + void *arg) {
  11 + float progress = 100.0 * processed_chunks / num_chunks;
  12 + fprintf(stderr, "progress %.2f%%\n", progress);
  13 +
  14 + // the return value is currently ignored
  15 + return 0;
  16 +}
  17 +
  18 +int main(int32_t argc, char *argv[]) {
  19 + const char *kUsageMessage = R"usage(
  20 +Offline/Non-streaming speaker diarization with sherpa-onnx
  21 +Usage example:
  22 +
  23 +Step 1: Download a speaker segmentation model
  24 +
  25 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  26 +for a list of available models. The following is an example
  27 +
  28 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  29 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  30 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
  31 +
  32 +Step 2: Download a speaker embedding extractor model
  33 +
  34 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
  35 +for a list of available models. The following is an example
  36 +
  37 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
  38 +
  39 +Step 3: Download test wave files
  40 +
  41 +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
  42 +for a list of available test wave files. The following is an example
  43 +
  44 + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
  45 +
  46 +Step 4: Build sherpa-onnx
  47 +
  48 +Step 5: Run it
  49 +
  50 + ./bin/sherpa-onnx-offline-speaker-diarization \
  51 + --clustering.num-clusters=4 \
  52 + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
  53 + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
  54 + ./0-four-speakers-zh.wav
  55 +
  56 +Since we know that there are four speakers in the test wave file, we use
  57 +--clustering.num-clusters=4 in the above example.
  58 +
  59 +If we don't know the number of speakers in the given wave file, we can use
  60 +the argument --clustering.cluster-threshold. The following is an example:
  61 +
  62 + ./bin/sherpa-onnx-offline-speaker-diarization \
  63 + --clustering.cluster-threshold=0.90 \
  64 + --segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
  65 + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
  66 + ./0-four-speakers-zh.wav
  67 +
  68 +A larger threshold leads to fewer clusters, i.e., fewer speakers;
  69 +a smaller threshold leads to more clusters, i.e., more speakers.
  70 + )usage";
  71 + sherpa_onnx::OfflineSpeakerDiarizationConfig config;
  72 + sherpa_onnx::ParseOptions po(kUsageMessage);
  73 + config.Register(&po);
  74 + po.Read(argc, argv);
  75 +
  76 + std::cout << config.ToString() << "\n";
  77 +
  78 + if (!config.Validate()) {
  79 + po.PrintUsage();
  80 + std::cerr << "Errors in config!\n";
  81 + return -1;
  82 + }
  83 +
  84 + if (po.NumArgs() != 1) {
  85 + std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
  86 + po.PrintUsage();
  87 + return -1;
  88 + }
  89 +
  90 + sherpa_onnx::OfflineSpeakerDiarization sd(config);
  91 +
  92 + std::cout << "Started\n";
  93 + const auto begin = std::chrono::steady_clock::now();
  94 + const std::string wav_filename = po.GetArg(1);
  95 + int32_t sample_rate = -1;
  96 + bool is_ok = false;
  97 + const std::vector<float> samples =
  98 + sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
  99 + if (!is_ok) {
  100 + std::cerr << "Failed to read " << wav_filename.c_str() << "\n";
  101 + return -1;
  102 + }
  103 +
  104 + if (sample_rate != sd.SampleRate()) {
  105 + std::cerr << "Expect sample rate " << sd.SampleRate()
  106 + << ". Given: " << sample_rate << "\n";
  107 + return -1;
  108 + }
  109 +
  110 + float duration = samples.size() / static_cast<float>(sample_rate);
  111 +
  112 + auto result =
  113 + sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
  114 + .SortByStartTime();
  115 +
  116 + for (const auto &r : result) {
  117 + std::cout << r.ToString() << "\n";
  118 + }
  119 +
  120 + const auto end = std::chrono::steady_clock::now();
  121 + float elapsed_seconds =
  122 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  123 + .count() /
  124 + 1000.;
  125 +
  126 + fprintf(stderr, "Duration : %.3f s\n", duration);
  127 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  128 + float rtf = elapsed_seconds / duration;
  129 + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
  130 + elapsed_seconds, duration, rtf);
  131 +
  132 + return 0;
  133 +}
@@ -9,14 +9,15 @@ @@ -9,14 +9,15 @@
9 #include "sherpa-onnx/csrc/parse-options.h" 9 #include "sherpa-onnx/csrc/parse-options.h"
10 #include "sherpa-onnx/csrc/wave-writer.h" 10 #include "sherpa-onnx/csrc/wave-writer.h"
11 11
12 -int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) { 12 +static int32_t AudioCallback(const float * /*samples*/, int32_t n,
  13 + float progress) {
13 printf("sample=%d, progress=%f\n", n, progress); 14 printf("sample=%d, progress=%f\n", n, progress);
14 return 1; 15 return 1;
15 } 16 }
16 17
17 int main(int32_t argc, char *argv[]) { 18 int main(int32_t argc, char *argv[]) {
18 const char *kUsageMessage = R"usage( 19 const char *kUsageMessage = R"usage(
19 -Offline text-to-speech with sherpa-onnx 20 +Offline/Non-streaming text-to-speech with sherpa-onnx
20 21
21 Usage example: 22 Usage example:
22 23
@@ -79,7 +80,7 @@ or details. @@ -79,7 +80,7 @@ or details.
79 sherpa_onnx::OfflineTts tts(config); 80 sherpa_onnx::OfflineTts tts(config);
80 81
81 const auto begin = std::chrono::steady_clock::now(); 82 const auto begin = std::chrono::steady_clock::now();
82 - auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback); 83 + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback);
83 const auto end = std::chrono::steady_clock::now(); 84 const auto end = std::chrono::steady_clock::now();
84 85
85 if (audio.samples.empty()) { 86 if (audio.samples.empty()) {
@@ -19,7 +19,7 @@ The input text can contain English words. @@ -19,7 +19,7 @@ The input text can contain English words.
19 Usage: 19 Usage:
20 20
21 Please download the model from: 21 Please download the model from:
22 -https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 22 +https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
23 23
24 ./bin/Release/sherpa-onnx-online-punctuation \ 24 ./bin/Release/sherpa-onnx-online-punctuation \
25 --cnn-bilstm=/path/to/model.onnx \ 25 --cnn-bilstm=/path/to/model.onnx \
@@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
26 26
27 bool SpeakerEmbeddingExtractorConfig::Validate() const { 27 bool SpeakerEmbeddingExtractorConfig::Validate() const {
28 if (model.empty()) { 28 if (model.empty()) {
29 - SHERPA_ONNX_LOGE("Please provide --model"); 29 + SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model");
30 return false; 30 return false;
31 } 31 }
32 32
33 if (!FileExists(model)) { 33 if (!FileExists(model)) {
34 - SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist", 34 + SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist",
35 model.c_str()); 35 model.c_str());
36 return false; 36 return false;
37 } 37 }