Toggle navigation
Toggle navigation
此项目
正在载入...
Sign in
xuning
/
sherpaonnx
转到一个项目
Toggle navigation
项目
群组
代码片段
帮助
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
Fangjun Kuang
2024-10-09 12:01:20 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Committed by
GitHub
2024-10-09 12:01:20 +0800
Commit
59407edcad3a4a26342cee7dc7f0fac6d1ff50b4
59407edc
1 parent
70165cb4
C++ API for speaker diarization (#1396)
隐藏空白字符变更
内嵌
并排对比
正在显示
39 个修改的文件
包含
1652 行增加
和
108 行删除
.github/scripts/test-speaker-diarization.sh
.github/workflows/export-pyannote-segmentation-to-onnx.yaml
.github/workflows/linux.yaml
.github/workflows/macos.yaml
.github/workflows/speaker-diarization.yaml
.github/workflows/windows-x64.yaml
.github/workflows/windows-x86.yaml
c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c
c-api-examples/streaming-ctc-buffered-tokens-c-api.c
c-api-examples/streaming-paraformer-buffered-tokens-c-api.c
c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c
cmake/cmake_extension.py
scripts/pyannote/segmentation/README.md
scripts/pyannote/segmentation/export-onnx.py
sherpa-onnx/csrc/CMakeLists.txt
sherpa-onnx/csrc/fast-clustering-config.cc
sherpa-onnx/csrc/macros.h
sherpa-onnx/csrc/offline-sense-voice-model.cc
sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
sherpa-onnx/csrc/offline-speaker-diarization-impl.h
sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
sherpa-onnx/csrc/offline-speaker-diarization-result.cc
sherpa-onnx/csrc/offline-speaker-diarization-result.h
sherpa-onnx/csrc/offline-speaker-diarization.cc
sherpa-onnx/csrc/offline-speaker-diarization.h
sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc
sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
sherpa-onnx/csrc/provider-config.cc
sherpa-onnx/csrc/session.cc
sherpa-onnx/csrc/session.h
sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
sherpa-onnx/csrc/speaker-embedding-extractor.cc
.github/scripts/test-speaker-diarization.sh
0 → 100755
查看文件 @
59407ed
#!/usr/bin/env bash
set
-ex
log
()
{
# This function is from espnet
local
fname
=
${
BASH_SOURCE
[1]##*/
}
echo
-e
"
$(
date
'+%Y-%m-%d %H:%M:%S'
)
(
${
fname
}
:
${
BASH_LINENO
[0]
}
:
${
FUNCNAME
[1]
}
)
$*
"
}
echo
"EXE is
$EXE
"
echo
"PATH:
$PATH
"
which
$EXE
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
log
"specify number of clusters"
$EXE
\
--clustering.num-clusters
=
4
\
--segmentation.pyannote-model
=
./sherpa-onnx-pyannote-segmentation-3-0/model.onnx
\
--embedding.model
=
./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
\
./0-four-speakers-zh.wav
log
"specify threshold for clustering"
$EXE
\
--clustering.cluster-threshold
=
0.90
\
--segmentation.pyannote-model
=
./sherpa-onnx-pyannote-segmentation-3-0/model.onnx
\
--embedding.model
=
./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
\
./0-four-speakers-zh.wav
rm -rf sherpa-onnx-pyannote-
*
rm -fv
*
.onnx
rm -fv
*
.wav
...
...
.github/workflows/export-pyannote-segmentation-to-onnx.yaml
查看文件 @
59407ed
...
...
@@ -29,7 +29,7 @@ jobs:
-
name
:
Install pyannote
shell
:
bash
run
:
|
pip install pyannote.audio onnx
onnxruntime
pip install pyannote.audio onnx
==1.15.0 onnxruntime==1.16.3
-
name
:
Run
shell
:
bash
...
...
.github/workflows/linux.yaml
查看文件 @
59407ed
...
...
@@ -18,6 +18,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -38,6 +39,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -143,6 +145,15 @@ jobs:
name
:
release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path
:
install/*
-
name
:
Test offline speaker diarization
shell
:
bash
run
:
|
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization
.github/scripts/test-speaker-diarization.sh
-
name
:
Test offline transducer
shell
:
bash
run
:
|
...
...
.github/workflows/macos.yaml
查看文件 @
59407ed
...
...
@@ -18,6 +18,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -37,6 +38,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -115,6 +117,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
-
name
:
Test offline speaker diarization
shell
:
bash
run
:
|
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization
.github/scripts/test-speaker-diarization.sh
-
name
:
Test offline transducer
shell
:
bash
run
:
|
...
...
.github/workflows/speaker-diarization.yaml
查看文件 @
59407ed
...
...
@@ -67,7 +67,7 @@ jobs:
curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
test_wavs=(
0-
two
-speakers-zh.wav
0-
four
-speakers-zh.wav
1-two-speakers-en.wav
2-two-speakers-en.wav
3-two-speakers-en.wav
...
...
.github/workflows/windows-x64.yaml
查看文件 @
59407ed
...
...
@@ -17,6 +17,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -34,6 +35,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -87,6 +89,15 @@ jobs:
name
:
release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path
:
build/install/*
-
name
:
Test offline speaker diarization
shell
:
bash
run
:
|
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
-
name
:
Test online punctuation
shell
:
bash
run
:
|
...
...
.github/workflows/windows-x86.yaml
查看文件 @
59407ed
...
...
@@ -17,6 +17,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -34,6 +35,7 @@ on:
-
'
.github/scripts/test-audio-tagging.sh'
-
'
.github/scripts/test-offline-punctuation.sh'
-
'
.github/scripts/test-online-punctuation.sh'
-
'
.github/scripts/test-speaker-diarization.sh'
-
'
CMakeLists.txt'
-
'
cmake/**'
-
'
sherpa-onnx/csrc/*'
...
...
@@ -87,6 +89,15 @@ jobs:
name
:
release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path
:
build/install/*
-
name
:
Test offline speaker diarization
shell
:
bash
run
:
|
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
-
name
:
Test online punctuation
shell
:
bash
run
:
|
...
...
c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c
查看文件 @
59407ed
...
...
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf
(
stderr
,
"Memory error
\n
"
);
return
-
1
;
}
size_t
read_bytes
=
fread
(
*
buffer_out
,
1
,
size
,
file
);
size_t
read_bytes
=
fread
(
(
void
*
)
*
buffer_out
,
1
,
size
,
file
);
if
(
read_bytes
!=
size
)
{
printf
(
"Errors occured in reading the file %s
\n
"
,
filename
);
free
((
void
*
)
*
buffer_out
);
...
...
c-api-examples/streaming-ctc-buffered-tokens-c-api.c
查看文件 @
59407ed
...
...
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf
(
stderr
,
"Memory error
\n
"
);
return
-
1
;
}
size_t
read_bytes
=
fread
(
*
buffer_out
,
1
,
size
,
file
);
size_t
read_bytes
=
fread
(
(
void
*
)
*
buffer_out
,
1
,
size
,
file
);
if
(
read_bytes
!=
size
)
{
printf
(
"Errors occured in reading the file %s
\n
"
,
filename
);
free
((
void
*
)
*
buffer_out
);
...
...
c-api-examples/streaming-paraformer-buffered-tokens-c-api.c
查看文件 @
59407ed
...
...
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf
(
stderr
,
"Memory error
\n
"
);
return
-
1
;
}
size_t
read_bytes
=
fread
(
*
buffer_out
,
1
,
size
,
file
);
size_t
read_bytes
=
fread
(
(
void
*
)
*
buffer_out
,
1
,
size
,
file
);
if
(
read_bytes
!=
size
)
{
printf
(
"Errors occured in reading the file %s
\n
"
,
filename
);
free
((
void
*
)
*
buffer_out
);
...
...
c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c
查看文件 @
59407ed
...
...
@@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf
(
stderr
,
"Memory error
\n
"
);
return
-
1
;
}
size_t
read_bytes
=
fread
(
*
buffer_out
,
1
,
size
,
file
);
size_t
read_bytes
=
fread
(
(
void
*
)
*
buffer_out
,
1
,
size
,
file
);
if
(
read_bytes
!=
size
)
{
printf
(
"Errors occured in reading the file %s
\n
"
,
filename
);
free
((
void
*
)
*
buffer_out
);
...
...
cmake/cmake_extension.py
查看文件 @
59407ed
...
...
@@ -55,6 +55,7 @@ def get_binaries():
"sherpa-onnx-offline-audio-tagging"
,
"sherpa-onnx-offline-language-identification"
,
"sherpa-onnx-offline-punctuation"
,
"sherpa-onnx-offline-speaker-diarization"
,
"sherpa-onnx-offline-tts"
,
"sherpa-onnx-offline-tts-play"
,
"sherpa-onnx-offline-websocket-server"
,
...
...
scripts/pyannote/segmentation/README.md
查看文件 @
59407ed
...
...
@@ -3,12 +3,9 @@
Please download test wave files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
## 0-
two
-speakers-zh.wav
## 0-
four
-speakers-zh.wav
This file is from
https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0
Note that we have renamed it from
`2speakers_example.wav` to `0-two-speakers-zh.wav`
.
It is recorded by @csukuangfj
## 1-two-speakers-en.wav
...
...
@@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav`
```
bash
sox ML16091-Audio.mp3 3-two-speakers-en.wav
sox ML16091-Audio.mp3
-r 16k
3-two-speakers-en.wav
```
...
...
scripts/pyannote/segmentation/export-onnx.py
查看文件 @
59407ed
...
...
@@ -72,7 +72,7 @@ def main():
model
.
receptive_field
.
duration
*
16000
)
opset_version
=
1
8
opset_version
=
1
3
filename
=
"model.onnx"
torch
.
onnx
.
export
(
...
...
sherpa-onnx/csrc/CMakeLists.txt
查看文件 @
59407ed
...
...
@@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list
(
APPEND sources
fast-clustering-config.cc
fast-clustering.cc
offline-speaker-diarization-impl.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
offline-speaker-segmentation-model-config.cc
offline-speaker-segmentation-pyannote-model-config.cc
offline-speaker-segmentation-pyannote-model.cc
)
endif
()
...
...
@@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable
(
sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc
)
endif
()
if
(
SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION
)
add_executable
(
sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc
)
endif
()
set
(
main_exes
sherpa-onnx
sherpa-onnx-keyword-spotter
...
...
@@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY)
)
endif
()
if
(
SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION
)
list
(
APPEND main_exes
sherpa-onnx-offline-speaker-diarization
)
endif
()
foreach
(
exe IN LISTS main_exes
)
target_link_libraries
(
${
exe
}
sherpa-onnx-core
)
endforeach
()
...
...
sherpa-onnx/csrc/fast-clustering-config.cc
查看文件 @
59407ed
...
...
@@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const {
}
void
FastClusteringConfig
::
Register
(
ParseOptions
*
po
)
{
std
::
string
prefix
=
"ctc"
;
ParseOptions
p
(
prefix
,
po
);
p
.
Register
(
"num-clusters"
,
&
num_clusters
,
"Number of cluster. If greater than 0, then --cluster-thresold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance."
);
p
.
Register
(
"cluster-threshold"
,
&
threshold
,
"If --num-clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters"
);
po
->
Register
(
"num-clusters"
,
&
num_clusters
,
"Number of cluster. If greater than 0, then cluster threshold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance."
);
po
->
Register
(
"cluster-threshold"
,
&
threshold
,
"If num_clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters"
);
}
bool
FastClusteringConfig
::
Validate
()
const
{
...
...
sherpa-onnx/csrc/macros.h
查看文件 @
59407ed
...
...
@@ -5,6 +5,7 @@
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_CSRC_MACROS_H_
#include <stdio.h>
#include <stdlib.h>
#if __ANDROID_API__ >= 8
#include "android/log.h"
...
...
@@ -169,4 +170,6 @@
} \
} while (0)
#define SHERPA_ONNX_EXIT(code) exit(code)
#endif // SHERPA_ONNX_CSRC_MACROS_H_
...
...
sherpa-onnx/csrc/offline-sense-voice-model.cc
查看文件 @
59407ed
...
...
@@ -9,6 +9,7 @@
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
...
...
sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
#include <memory>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"
namespace
sherpa_onnx
{
std
::
unique_ptr
<
OfflineSpeakerDiarizationImpl
>
OfflineSpeakerDiarizationImpl
::
Create
(
const
OfflineSpeakerDiarizationConfig
&
config
)
{
if
(
!
config
.
segmentation
.
pyannote
.
model
.
empty
())
{
return
std
::
make_unique
<
OfflineSpeakerDiarizationPyannoteImpl
>
(
config
);
}
SHERPA_ONNX_LOGE
(
"Please specify a speaker segmentation model."
);
return
nullptr
;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-speaker-diarization-impl.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
#include <functional>
#include <memory>
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace
sherpa_onnx
{
class
OfflineSpeakerDiarizationImpl
{
public
:
static
std
::
unique_ptr
<
OfflineSpeakerDiarizationImpl
>
Create
(
const
OfflineSpeakerDiarizationConfig
&
config
);
virtual
~
OfflineSpeakerDiarizationImpl
()
=
default
;
virtual
int32_t
SampleRate
()
const
=
0
;
virtual
OfflineSpeakerDiarizationResult
Process
(
const
float
*
audio
,
int32_t
n
,
OfflineSpeakerDiarizationProgressCallback
callback
=
nullptr
,
void
*
callback_arg
=
nullptr
)
const
=
0
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
...
...
sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/fast-clustering.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace
sherpa_onnx
{
namespace
{
// NOLINT
// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41
template
<
class
T
>
inline
void
hash_combine
(
std
::
size_t
*
seed
,
const
T
&
v
)
{
// NOLINT
std
::
hash
<
T
>
hasher
;
*
seed
^=
hasher
(
v
)
+
0x9e3779b9
+
((
*
seed
)
<<
6
)
+
((
*
seed
)
>>
2
);
// NOLINT
}
// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47
struct
PairHash
{
template
<
class
T1
,
class
T2
>
std
::
size_t
operator
()(
const
std
::
pair
<
T1
,
T2
>
&
pair
)
const
{
std
::
size_t
result
=
0
;
hash_combine
(
&
result
,
pair
.
first
);
hash_combine
(
&
result
,
pair
.
second
);
return
result
;
}
};
}
// namespace
using
Matrix2D
=
Eigen
::
Matrix
<
float
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
;
using
Matrix2DInt32
=
Eigen
::
Matrix
<
int32_t
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
;
using
FloatRowVector
=
Eigen
::
Matrix
<
float
,
1
,
Eigen
::
Dynamic
>
;
using
Int32RowVector
=
Eigen
::
Matrix
<
int32_t
,
1
,
Eigen
::
Dynamic
>
;
using
Int32Pair
=
std
::
pair
<
int32_t
,
int32_t
>
;
class
OfflineSpeakerDiarizationPyannoteImpl
:
public
OfflineSpeakerDiarizationImpl
{
public
:
~
OfflineSpeakerDiarizationPyannoteImpl
()
override
=
default
;
explicit
OfflineSpeakerDiarizationPyannoteImpl
(
const
OfflineSpeakerDiarizationConfig
&
config
)
:
config_
(
config
),
segmentation_model_
(
config_
.
segmentation
),
embedding_extractor_
(
config_
.
embedding
),
clustering_
(
config_
.
clustering
)
{
Init
();
}
int32_t
SampleRate
()
const
override
{
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
return
meta_data
.
sample_rate
;
}
OfflineSpeakerDiarizationResult
Process
(
const
float
*
audio
,
int32_t
n
,
OfflineSpeakerDiarizationProgressCallback
callback
=
nullptr
,
void
*
callback_arg
=
nullptr
)
const
override
{
std
::
vector
<
Matrix2D
>
segmentations
=
RunSpeakerSegmentationModel
(
audio
,
n
);
// segmentations[i] is for chunk_i
// Each matrix is of shape (num_frames, num_powerset_classes)
if
(
segmentations
.
empty
())
{
return
{};
}
std
::
vector
<
Matrix2DInt32
>
labels
;
labels
.
reserve
(
segmentations
.
size
());
for
(
const
auto
&
m
:
segmentations
)
{
labels
.
push_back
(
ToMultiLabel
(
m
));
}
segmentations
.
clear
();
// labels[i] is a 0-1 matrix of shape (num_frames, num_speakers)
// speaker count per frame
Int32RowVector
speakers_per_frame
=
ComputeSpeakersPerFrame
(
labels
);
if
(
speakers_per_frame
.
maxCoeff
()
==
0
)
{
SHERPA_ONNX_LOGE
(
"No speakers found in the audio samples"
);
return
{};
}
auto
chunk_speaker_samples_list_pair
=
GetChunkSpeakerSampleIndexes
(
labels
);
Matrix2D
embeddings
=
ComputeEmbeddings
(
audio
,
n
,
chunk_speaker_samples_list_pair
.
second
,
callback
,
callback_arg
);
std
::
vector
<
int32_t
>
cluster_labels
=
clustering_
.
Cluster
(
&
embeddings
(
0
,
0
),
embeddings
.
rows
(),
embeddings
.
cols
());
int32_t
max_cluster_index
=
*
std
::
max_element
(
cluster_labels
.
begin
(),
cluster_labels
.
end
());
auto
chunk_speaker_to_cluster
=
ConvertChunkSpeakerToCluster
(
chunk_speaker_samples_list_pair
.
first
,
cluster_labels
);
auto
new_labels
=
ReLabel
(
labels
,
max_cluster_index
,
chunk_speaker_to_cluster
);
Matrix2DInt32
speaker_count
=
ComputeSpeakerCount
(
new_labels
,
n
);
Matrix2DInt32
final_labels
=
FinalizeLabels
(
speaker_count
,
speakers_per_frame
);
auto
result
=
ComputeResult
(
final_labels
);
return
result
;
}
private
:
void
Init
()
{
InitPowersetMapping
();
}
// see also
// https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68
void
InitPowersetMapping
()
{
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
num_classes
=
meta_data
.
num_classes
;
int32_t
powerset_max_classes
=
meta_data
.
powerset_max_classes
;
int32_t
num_speakers
=
meta_data
.
num_speakers
;
powerset_mapping_
=
Matrix2DInt32
(
num_classes
,
num_speakers
);
powerset_mapping_
.
setZero
();
int32_t
k
=
1
;
for
(
int32_t
i
=
1
;
i
<=
powerset_max_classes
;
++
i
)
{
if
(
i
==
1
)
{
for
(
int32_t
j
=
0
;
j
!=
num_speakers
;
++
j
,
++
k
)
{
powerset_mapping_
(
k
,
j
)
=
1
;
}
}
else
if
(
i
==
2
)
{
for
(
int32_t
j
=
0
;
j
!=
num_speakers
;
++
j
)
{
for
(
int32_t
m
=
j
+
1
;
m
<
num_speakers
;
++
m
,
++
k
)
{
powerset_mapping_
(
k
,
j
)
=
1
;
powerset_mapping_
(
k
,
m
)
=
1
;
}
}
}
else
{
SHERPA_ONNX_LOGE
(
"powerset_max_classes = %d is currently not supported!"
,
i
);
SHERPA_ONNX_EXIT
(
-
1
);
}
}
}
std
::
vector
<
Matrix2D
>
RunSpeakerSegmentationModel
(
const
float
*
audio
,
int32_t
n
)
const
{
std
::
vector
<
Matrix2D
>
ans
;
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
window_size
=
meta_data
.
window_size
;
int32_t
window_shift
=
meta_data
.
window_shift
;
if
(
n
<=
0
)
{
SHERPA_ONNX_LOGE
(
"number of audio samples is %d (<= 0). Please provide a positive "
"number"
,
n
);
return
{};
}
if
(
n
<=
window_size
)
{
std
::
vector
<
float
>
buf
(
window_size
);
// NOTE: buf is zero initialized by default
std
::
copy
(
audio
,
audio
+
n
,
buf
.
data
());
Matrix2D
m
=
ProcessChunk
(
buf
.
data
());
ans
.
push_back
(
std
::
move
(
m
));
return
ans
;
}
int32_t
num_chunks
=
(
n
-
window_size
)
/
window_shift
+
1
;
bool
has_last_chunk
=
(
n
-
window_size
)
%
window_shift
>
0
;
ans
.
reserve
(
num_chunks
+
has_last_chunk
);
const
float
*
p
=
audio
;
for
(
int32_t
i
=
0
;
i
!=
num_chunks
;
++
i
,
p
+=
window_shift
)
{
Matrix2D
m
=
ProcessChunk
(
p
);
ans
.
push_back
(
std
::
move
(
m
));
}
if
(
has_last_chunk
)
{
std
::
vector
<
float
>
buf
(
window_size
);
std
::
copy
(
p
,
audio
+
n
,
buf
.
data
());
Matrix2D
m
=
ProcessChunk
(
buf
.
data
());
ans
.
push_back
(
std
::
move
(
m
));
}
return
ans
;
}
Matrix2D
ProcessChunk
(
const
float
*
p
)
const
{
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
window_size
=
meta_data
.
window_size
;
auto
memory_info
=
Ort
::
MemoryInfo
::
CreateCpu
(
OrtDeviceAllocator
,
OrtMemTypeDefault
);
std
::
array
<
int64_t
,
3
>
shape
=
{
1
,
1
,
window_size
};
Ort
::
Value
x
=
Ort
::
Value
::
CreateTensor
(
memory_info
,
const_cast
<
float
*>
(
p
),
window_size
,
shape
.
data
(),
shape
.
size
());
Ort
::
Value
out
=
segmentation_model_
.
Forward
(
std
::
move
(
x
));
std
::
vector
<
int64_t
>
out_shape
=
out
.
GetTensorTypeAndShapeInfo
().
GetShape
();
Matrix2D
m
(
out_shape
[
1
],
out_shape
[
2
]);
std
::
copy
(
out
.
GetTensorData
<
float
>
(),
out
.
GetTensorData
<
float
>
()
+
m
.
size
(),
&
m
(
0
,
0
));
return
m
;
}
Matrix2DInt32
ToMultiLabel
(
const
Matrix2D
&
m
)
const
{
int32_t
num_rows
=
m
.
rows
();
Matrix2DInt32
ans
(
num_rows
,
powerset_mapping_
.
cols
());
std
::
ptrdiff_t
col_id
;
for
(
int32_t
i
=
0
;
i
!=
num_rows
;
++
i
)
{
m
.
row
(
i
).
maxCoeff
(
&
col_id
);
ans
.
row
(
i
)
=
powerset_mapping_
.
row
(
col_id
);
}
return
ans
;
}
// See also
// https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122
Int32RowVector
ComputeSpeakersPerFrame
(
const
std
::
vector
<
Matrix2DInt32
>
&
labels
)
const
{
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
window_size
=
meta_data
.
window_size
;
int32_t
window_shift
=
meta_data
.
window_shift
;
int32_t
receptive_field_shift
=
meta_data
.
receptive_field_shift
;
int32_t
num_chunks
=
labels
.
size
();
int32_t
num_frames
=
(
window_size
+
(
num_chunks
-
1
)
*
window_shift
)
/
receptive_field_shift
+
1
;
FloatRowVector
count
(
num_frames
);
FloatRowVector
weight
(
num_frames
);
count
.
setZero
();
weight
.
setZero
();
for
(
int32_t
i
=
0
;
i
!=
num_chunks
;
++
i
)
{
int32_t
start
=
static_cast
<
float
>
(
i
)
*
window_shift
/
receptive_field_shift
+
0
.
5
;
auto
seq
=
Eigen
::
seqN
(
start
,
labels
[
i
].
rows
());
count
(
seq
).
array
()
+=
labels
[
i
].
rowwise
().
sum
().
array
().
cast
<
float
>
();
weight
(
seq
).
array
()
+=
1
;
}
return
((
count
.
array
()
/
(
weight
.
array
()
+
1e-12
f
))
+
0
.
5
).
cast
<
int32_t
>
();
}
// ans.first: a list of (chunk_id, speaker_id)
// ans.second: a list of list of (start_sample_index, end_sample_index)
//
// ans.first[i] corresponds to ans.second[i]
std
::
pair
<
std
::
vector
<
Int32Pair
>
,
std
::
vector
<
std
::
vector
<
Int32Pair
>>>
GetChunkSpeakerSampleIndexes
(
const
std
::
vector
<
Matrix2DInt32
>
&
labels
)
const
{
auto
new_labels
=
ExcludeOverlap
(
labels
);
std
::
vector
<
Int32Pair
>
chunk_speaker_list
;
std
::
vector
<
std
::
vector
<
Int32Pair
>>
samples_index_list
;
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
window_size
=
meta_data
.
window_size
;
int32_t
window_shift
=
meta_data
.
window_shift
;
int32_t
receptive_field_shift
=
meta_data
.
receptive_field_shift
;
int32_t
num_speakers
=
meta_data
.
num_speakers
;
int32_t
chunk_index
=
0
;
for
(
const
auto
&
label
:
new_labels
)
{
Matrix2DInt32
tmp
=
label
.
transpose
();
// tmp: (num_speakers, num_frames)
int32_t
num_frames
=
tmp
.
cols
();
int32_t
sample_offset
=
chunk_index
*
window_shift
;
for
(
int32_t
speaker_index
=
0
;
speaker_index
!=
num_speakers
;
++
speaker_index
)
{
auto
d
=
tmp
.
row
(
speaker_index
);
if
(
d
.
sum
()
<
10
)
{
// skip segments less than 10 frames
continue
;
}
Int32Pair
this_chunk_speaker
=
{
chunk_index
,
speaker_index
};
std
::
vector
<
Int32Pair
>
this_speaker_samples
;
bool
is_active
=
false
;
int32_t
start_index
;
for
(
int32_t
k
=
0
;
k
!=
num_frames
;
++
k
)
{
if
(
d
[
k
]
!=
0
)
{
if
(
!
is_active
)
{
is_active
=
true
;
start_index
=
k
;
}
}
else
if
(
is_active
)
{
is_active
=
false
;
int32_t
start_samples
=
static_cast
<
float
>
(
start_index
)
/
num_frames
*
window_size
+
sample_offset
;
int32_t
end_samples
=
static_cast
<
float
>
(
k
)
/
num_frames
*
window_size
+
sample_offset
;
this_speaker_samples
.
emplace_back
(
start_samples
,
end_samples
);
}
}
if
(
is_active
)
{
int32_t
start_samples
=
static_cast
<
float
>
(
start_index
)
/
num_frames
*
window_size
+
sample_offset
;
int32_t
end_samples
=
static_cast
<
float
>
(
num_frames
-
1
)
/
num_frames
*
window_size
+
sample_offset
;
this_speaker_samples
.
emplace_back
(
start_samples
,
end_samples
);
}
chunk_speaker_list
.
push_back
(
std
::
move
(
this_chunk_speaker
));
samples_index_list
.
push_back
(
std
::
move
(
this_speaker_samples
));
}
// for (int32_t speaker_index = 0;
chunk_index
+=
1
;
}
// for (const auto &label : new_labels)
return
{
chunk_speaker_list
,
samples_index_list
};
}
// If there are multiple speakers at a frame, then this frame is excluded.
std
::
vector
<
Matrix2DInt32
>
ExcludeOverlap
(
const
std
::
vector
<
Matrix2DInt32
>
&
labels
)
const
{
int32_t
num_chunks
=
labels
.
size
();
std
::
vector
<
Matrix2DInt32
>
ans
;
ans
.
reserve
(
num_chunks
);
for
(
const
auto
&
label
:
labels
)
{
Matrix2DInt32
new_label
(
label
.
rows
(),
label
.
cols
());
new_label
.
setZero
();
Int32RowVector
v
=
label
.
rowwise
().
sum
();
for
(
int32_t
i
=
0
;
i
!=
v
.
cols
();
++
i
)
{
if
(
v
[
i
]
<
2
)
{
new_label
.
row
(
i
)
=
label
.
row
(
i
);
}
}
ans
.
push_back
(
std
::
move
(
new_label
));
}
return
ans
;
}
/**
* @param sample_indexes[i] contains the sample segment start and end indexes
* for the i-th (chunk, speaker) pair
* @return Return a matrix of shape (sample_indexes.size(), embedding_dim)
* where ans.row[i] contains the embedding for the
* i-th (chunk, speaker) pair
*/
Matrix2D
ComputeEmbeddings
(
const
float
*
audio
,
int32_t
n
,
const
std
::
vector
<
std
::
vector
<
Int32Pair
>>
&
sample_indexes
,
OfflineSpeakerDiarizationProgressCallback
callback
,
void
*
callback_arg
)
const
{
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
sample_rate
=
meta_data
.
sample_rate
;
Matrix2D
ans
(
sample_indexes
.
size
(),
embedding_extractor_
.
Dim
());
int32_t
k
=
0
;
for
(
const
auto
&
v
:
sample_indexes
)
{
auto
stream
=
embedding_extractor_
.
CreateStream
();
for
(
const
auto
&
p
:
v
)
{
int32_t
end
=
(
p
.
second
<=
n
)
?
p
.
second
:
n
;
int32_t
num_samples
=
end
-
p
.
first
;
if
(
num_samples
>
0
)
{
stream
->
AcceptWaveform
(
sample_rate
,
audio
+
p
.
first
,
num_samples
);
}
}
stream
->
InputFinished
();
if
(
!
embedding_extractor_
.
IsReady
(
stream
.
get
()))
{
SHERPA_ONNX_LOGE
(
"This segment is too short, which should not happen since we have "
"already filtered short segments"
);
SHERPA_ONNX_EXIT
(
-
1
);
}
std
::
vector
<
float
>
embedding
=
embedding_extractor_
.
Compute
(
stream
.
get
());
std
::
copy
(
embedding
.
begin
(),
embedding
.
end
(),
&
ans
(
k
,
0
));
k
+=
1
;
if
(
callback
)
{
callback
(
k
,
ans
.
rows
(),
callback_arg
);
}
}
return
ans
;
}
std
::
unordered_map
<
Int32Pair
,
int32_t
,
PairHash
>
ConvertChunkSpeakerToCluster
(
const
std
::
vector
<
Int32Pair
>
&
chunk_speaker_pair
,
const
std
::
vector
<
int32_t
>
&
cluster_labels
)
const
{
std
::
unordered_map
<
Int32Pair
,
int32_t
,
PairHash
>
ans
;
int32_t
k
=
0
;
for
(
const
auto
&
p
:
chunk_speaker_pair
)
{
ans
[
p
]
=
cluster_labels
[
k
];
k
+=
1
;
}
return
ans
;
}
std
::
vector
<
Matrix2DInt32
>
ReLabel
(
const
std
::
vector
<
Matrix2DInt32
>
&
labels
,
int32_t
max_cluster_index
,
std
::
unordered_map
<
Int32Pair
,
int32_t
,
PairHash
>
chunk_speaker_to_cluster
)
const
{
std
::
vector
<
Matrix2DInt32
>
new_labels
;
new_labels
.
reserve
(
labels
.
size
());
int32_t
chunk_index
=
0
;
for
(
const
auto
&
label
:
labels
)
{
Matrix2DInt32
new_label
(
label
.
rows
(),
max_cluster_index
+
1
);
new_label
.
setZero
();
Matrix2DInt32
t
=
label
.
transpose
();
// t: (num_speakers, num_frames)
for
(
int32_t
speaker_index
=
0
;
speaker_index
!=
t
.
rows
();
++
speaker_index
)
{
if
(
chunk_speaker_to_cluster
.
count
({
chunk_index
,
speaker_index
})
==
0
)
{
continue
;
}
int32_t
new_speaker_index
=
chunk_speaker_to_cluster
.
at
({
chunk_index
,
speaker_index
});
for
(
int32_t
k
=
0
;
k
!=
t
.
cols
();
++
k
)
{
if
(
t
(
speaker_index
,
k
)
==
1
)
{
new_label
(
k
,
new_speaker_index
)
=
1
;
}
}
}
new_labels
.
push_back
(
std
::
move
(
new_label
));
chunk_index
+=
1
;
}
return
new_labels
;
}
Matrix2DInt32
ComputeSpeakerCount
(
const
std
::
vector
<
Matrix2DInt32
>
&
labels
,
int32_t
num_samples
)
const
{
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
window_size
=
meta_data
.
window_size
;
int32_t
window_shift
=
meta_data
.
window_shift
;
int32_t
receptive_field_shift
=
meta_data
.
receptive_field_shift
;
int32_t
num_chunks
=
labels
.
size
();
int32_t
num_frames
=
(
window_size
+
(
num_chunks
-
1
)
*
window_shift
)
/
receptive_field_shift
+
1
;
Matrix2DInt32
count
(
num_frames
,
labels
[
0
].
cols
());
count
.
setZero
();
for
(
int32_t
i
=
0
;
i
!=
num_chunks
;
++
i
)
{
int32_t
start
=
static_cast
<
float
>
(
i
)
*
window_shift
/
receptive_field_shift
+
0
.
5
;
auto
seq
=
Eigen
::
seqN
(
start
,
labels
[
i
].
rows
());
count
(
seq
,
Eigen
::
all
).
array
()
+=
labels
[
i
].
array
();
}
bool
has_last_chunk
=
(
num_samples
-
window_size
)
%
window_shift
>
0
;
if
(
has_last_chunk
)
{
return
count
;
}
int32_t
last_frame
=
num_samples
/
receptive_field_shift
;
return
count
(
Eigen
::
seq
(
0
,
last_frame
),
Eigen
::
all
);
}
Matrix2DInt32
FinalizeLabels
(
const
Matrix2DInt32
&
count
,
const
Int32RowVector
&
speakers_per_frame
)
const
{
int32_t
num_rows
=
count
.
rows
();
int32_t
num_cols
=
count
.
cols
();
Matrix2DInt32
ans
(
num_rows
,
num_cols
);
ans
.
setZero
();
for
(
int32_t
i
=
0
;
i
!=
num_rows
;
++
i
)
{
int32_t
k
=
speakers_per_frame
[
i
];
if
(
k
==
0
)
{
continue
;
}
auto
top_k
=
TopkIndex
(
&
count
(
i
,
0
),
num_cols
,
k
);
for
(
int32_t
m
:
top_k
)
{
ans
(
i
,
m
)
=
1
;
}
}
return
ans
;
}
OfflineSpeakerDiarizationResult
ComputeResult
(
const
Matrix2DInt32
&
final_labels
)
const
{
Matrix2DInt32
final_labels_t
=
final_labels
.
transpose
();
int32_t
num_speakers
=
final_labels_t
.
rows
();
int32_t
num_frames
=
final_labels_t
.
cols
();
const
auto
&
meta_data
=
segmentation_model_
.
GetModelMetaData
();
int32_t
window_size
=
meta_data
.
window_size
;
int32_t
window_shift
=
meta_data
.
window_shift
;
int32_t
receptive_field_shift
=
meta_data
.
receptive_field_shift
;
int32_t
receptive_field_size
=
meta_data
.
receptive_field_size
;
int32_t
sample_rate
=
meta_data
.
sample_rate
;
float
scale
=
static_cast
<
float
>
(
receptive_field_shift
)
/
sample_rate
;
float
scale_offset
=
0
.
5
*
receptive_field_size
/
sample_rate
;
OfflineSpeakerDiarizationResult
ans
;
for
(
int32_t
speaker_index
=
0
;
speaker_index
!=
num_speakers
;
++
speaker_index
)
{
std
::
vector
<
OfflineSpeakerDiarizationSegment
>
this_speaker
;
bool
is_active
=
final_labels_t
(
speaker_index
,
0
)
>
0
;
int32_t
start_index
=
is_active
?
0
:
-
1
;
for
(
int32_t
frame_index
=
1
;
frame_index
!=
num_frames
;
++
frame_index
)
{
if
(
is_active
)
{
if
(
final_labels_t
(
speaker_index
,
frame_index
)
==
0
)
{
float
start_time
=
start_index
*
scale
+
scale_offset
;
float
end_time
=
frame_index
*
scale
+
scale_offset
;
OfflineSpeakerDiarizationSegment
segment
(
start_time
,
end_time
,
speaker_index
);
this_speaker
.
push_back
(
segment
);
is_active
=
false
;
}
}
else
if
(
final_labels_t
(
speaker_index
,
frame_index
)
==
1
)
{
is_active
=
true
;
start_index
=
frame_index
;
}
}
if
(
is_active
)
{
float
start_time
=
start_index
*
scale
+
scale_offset
;
float
end_time
=
(
num_frames
-
1
)
*
scale
+
scale_offset
;
OfflineSpeakerDiarizationSegment
segment
(
start_time
,
end_time
,
speaker_index
);
this_speaker
.
push_back
(
segment
);
}
// merge segments if the gap between them is less than min_duration_off
MergeSegments
(
&
this_speaker
);
for
(
const
auto
&
seg
:
this_speaker
)
{
if
(
seg
.
Duration
()
>
config_
.
min_duration_on
)
{
ans
.
Add
(
seg
);
}
}
}
// for (int32_t speaker_index = 0; speaker_index != num_speakers;
return
ans
;
}
void
MergeSegments
(
std
::
vector
<
OfflineSpeakerDiarizationSegment
>
*
segments
)
const
{
float
min_duration_off
=
config_
.
min_duration_off
;
bool
changed
=
true
;
while
(
changed
)
{
changed
=
false
;
for
(
int32_t
i
=
0
;
i
<
static_cast
<
int32_t
>
(
segments
->
size
())
-
1
;
++
i
)
{
auto
s
=
(
*
segments
)[
i
].
Merge
((
*
segments
)[
i
+
1
],
min_duration_off
);
if
(
s
)
{
(
*
segments
)[
i
]
=
s
.
value
();
segments
->
erase
(
segments
->
begin
()
+
i
+
1
);
changed
=
true
;
break
;
}
}
}
}
private
:
OfflineSpeakerDiarizationConfig
config_
;
OfflineSpeakerSegmentationPyannoteModel
segmentation_model_
;
SpeakerEmbeddingExtractor
embedding_extractor_
;
FastClustering
clustering_
;
Matrix2DInt32
powerset_mapping_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
...
...
sherpa-onnx/csrc/offline-speaker-diarization-result.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
namespace
sherpa_onnx
{
OfflineSpeakerDiarizationSegment
::
OfflineSpeakerDiarizationSegment
(
float
start
,
float
end
,
int32_t
speaker
,
const
std
::
string
&
text
/*= {}*/
)
{
if
(
start
>
end
)
{
SHERPA_ONNX_LOGE
(
"start %.3f should be less than end %.3f"
,
start
,
end
);
SHERPA_ONNX_EXIT
(
-
1
);
}
start_
=
start
;
end_
=
end
;
speaker_
=
speaker
;
text_
=
text
;
}
std
::
optional
<
OfflineSpeakerDiarizationSegment
>
OfflineSpeakerDiarizationSegment
::
Merge
(
const
OfflineSpeakerDiarizationSegment
&
other
,
float
gap
)
const
{
if
(
other
.
speaker_
!=
speaker_
)
{
SHERPA_ONNX_LOGE
(
"The two segments should have the same speaker. this->speaker: %d, "
"other.speaker: %d"
,
speaker_
,
other
.
speaker_
);
return
std
::
nullopt
;
}
if
(
end_
<
other
.
start_
&&
end_
+
gap
>=
other
.
start_
)
{
return
OfflineSpeakerDiarizationSegment
(
start_
,
other
.
end_
,
speaker_
);
}
else
if
(
other
.
end_
<
start_
&&
other
.
end_
+
gap
>=
start_
)
{
return
OfflineSpeakerDiarizationSegment
(
other
.
start_
,
end_
,
speaker_
);
}
else
{
return
std
::
nullopt
;
}
}
std
::
string
OfflineSpeakerDiarizationSegment
::
ToString
()
const
{
char
s
[
128
];
snprintf
(
s
,
sizeof
(
s
),
"%.3f -- %.3f speaker_%02d"
,
start_
,
end_
,
speaker_
);
std
::
ostringstream
os
;
os
<<
s
;
if
(
!
text_
.
empty
())
{
os
<<
" "
<<
text_
;
}
return
os
.
str
();
}
void
OfflineSpeakerDiarizationResult
::
Add
(
const
OfflineSpeakerDiarizationSegment
&
segment
)
{
segments_
.
push_back
(
segment
);
}
int32_t
OfflineSpeakerDiarizationResult
::
NumSpeakers
()
const
{
std
::
unordered_set
<
int32_t
>
count
;
for
(
const
auto
&
s
:
segments_
)
{
count
.
insert
(
s
.
Speaker
());
}
return
count
.
size
();
}
int32_t
OfflineSpeakerDiarizationResult
::
NumSegments
()
const
{
return
segments_
.
size
();
}
// Return a list of segments sorted by segment.start time
std
::
vector
<
OfflineSpeakerDiarizationSegment
>
OfflineSpeakerDiarizationResult
::
SortByStartTime
()
const
{
auto
ans
=
segments_
;
std
::
sort
(
ans
.
begin
(),
ans
.
end
(),
[](
const
auto
&
a
,
const
auto
&
b
)
{
return
(
a
.
Start
()
<
b
.
Start
())
||
((
a
.
Start
()
==
b
.
Start
())
&&
(
a
.
Speaker
()
<
b
.
Speaker
()));
});
return
ans
;
}
std
::
vector
<
std
::
vector
<
OfflineSpeakerDiarizationSegment
>>
OfflineSpeakerDiarizationResult
::
SortBySpeaker
()
const
{
auto
tmp
=
segments_
;
std
::
sort
(
tmp
.
begin
(),
tmp
.
end
(),
[](
const
auto
&
a
,
const
auto
&
b
)
{
return
(
a
.
Speaker
()
<
b
.
Speaker
())
||
((
a
.
Speaker
()
==
b
.
Speaker
())
&&
(
a
.
Start
()
<
b
.
Start
()));
});
std
::
vector
<
std
::
vector
<
OfflineSpeakerDiarizationSegment
>>
ans
(
NumSpeakers
());
for
(
auto
&
s
:
tmp
)
{
ans
[
s
.
Speaker
()].
push_back
(
std
::
move
(
s
));
}
return
ans
;
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-speaker-diarization-result.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization-result.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
namespace
sherpa_onnx
{
class
OfflineSpeakerDiarizationSegment
{
public
:
OfflineSpeakerDiarizationSegment
(
float
start
,
float
end
,
int32_t
speaker
,
const
std
::
string
&
text
=
{});
// If the gap between the two segments is less than the given gap, then we
// merge them and return a new segment. Otherwise, it returns null.
std
::
optional
<
OfflineSpeakerDiarizationSegment
>
Merge
(
const
OfflineSpeakerDiarizationSegment
&
other
,
float
gap
)
const
;
float
Start
()
const
{
return
start_
;
}
float
End
()
const
{
return
end_
;
}
int32_t
Speaker
()
const
{
return
speaker_
;
}
const
std
::
string
&
Text
()
const
{
return
text_
;
}
float
Duration
()
const
{
return
end_
-
start_
;
}
std
::
string
ToString
()
const
;
private
:
float
start_
;
// in seconds
float
end_
;
// in seconds
int32_t
speaker_
;
// ID of the speaker, starting from 0
std
::
string
text_
;
// If not empty, it contains the speech recognition result
// of this segment
};
class
OfflineSpeakerDiarizationResult
{
public
:
// Add a new segment
void
Add
(
const
OfflineSpeakerDiarizationSegment
&
segment
);
// Number of distinct speakers contained in this object at this point
int32_t
NumSpeakers
()
const
;
int32_t
NumSegments
()
const
;
// Return a list of segments sorted by segment.start time
std
::
vector
<
OfflineSpeakerDiarizationSegment
>
SortByStartTime
()
const
;
// ans.size() == NumSpeakers().
// ans[i] is for speaker_i and is sorted by start time
std
::
vector
<
std
::
vector
<
OfflineSpeakerDiarizationSegment
>>
SortBySpeaker
()
const
;
public
:
std
::
vector
<
OfflineSpeakerDiarizationSegment
>
segments_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
...
...
sherpa-onnx/csrc/offline-speaker-diarization.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include <string>
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
namespace
sherpa_onnx
{
void
OfflineSpeakerDiarizationConfig
::
Register
(
ParseOptions
*
po
)
{
ParseOptions
po_segmentation
(
"segmentation"
,
po
);
segmentation
.
Register
(
&
po_segmentation
);
ParseOptions
po_embedding
(
"embedding"
,
po
);
embedding
.
Register
(
&
po_embedding
);
ParseOptions
po_clustering
(
"clustering"
,
po
);
clustering
.
Register
(
&
po_clustering
);
po
->
Register
(
"min-duration-on"
,
&
min_duration_on
,
"if a segment is less than this value, then it is discarded. "
"Set it to 0 so that no segment is discarded"
);
po
->
Register
(
"min-duration-off"
,
&
min_duration_off
,
"if the gap between to segments of the same speaker is less "
"than this value, then these two segments are merged into a "
"single segment. We do it recursively."
);
}
bool
OfflineSpeakerDiarizationConfig
::
Validate
()
const
{
if
(
!
segmentation
.
Validate
())
{
return
false
;
}
if
(
!
embedding
.
Validate
())
{
return
false
;
}
if
(
!
clustering
.
Validate
())
{
return
false
;
}
return
true
;
}
std
::
string
OfflineSpeakerDiarizationConfig
::
ToString
()
const
{
std
::
ostringstream
os
;
os
<<
"OfflineSpeakerDiarizationConfig("
;
os
<<
"segmentation="
<<
segmentation
.
ToString
()
<<
", "
;
os
<<
"embedding="
<<
embedding
.
ToString
()
<<
", "
;
os
<<
"clustering="
<<
clustering
.
ToString
()
<<
", "
;
os
<<
"min_duration_on="
<<
min_duration_on
<<
", "
;
os
<<
"min_duration_off="
<<
min_duration_off
<<
")"
;
return
os
.
str
();
}
OfflineSpeakerDiarization
::
OfflineSpeakerDiarization
(
const
OfflineSpeakerDiarizationConfig
&
config
)
:
impl_
(
OfflineSpeakerDiarizationImpl
::
Create
(
config
))
{}
OfflineSpeakerDiarization
::~
OfflineSpeakerDiarization
()
=
default
;
int32_t
OfflineSpeakerDiarization
::
SampleRate
()
const
{
return
impl_
->
SampleRate
();
}
OfflineSpeakerDiarizationResult
OfflineSpeakerDiarization
::
Process
(
const
float
*
audio
,
int32_t
n
,
OfflineSpeakerDiarizationProgressCallback
callback
/*= nullptr*/
,
void
*
callback_arg
/*= nullptr*/
)
const
{
return
impl_
->
Process
(
audio
,
n
,
callback
,
callback_arg
);
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-speaker-diarization.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-diarization.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#include <functional>
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/fast-clustering-config.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace
sherpa_onnx
{
struct
OfflineSpeakerDiarizationConfig
{
OfflineSpeakerSegmentationModelConfig
segmentation
;
SpeakerEmbeddingExtractorConfig
embedding
;
FastClusteringConfig
clustering
;
// if a segment is less than this value, then it is discarded
float
min_duration_on
=
0
.
3
;
// in seconds
// if the gap between to segments of the same speaker is less than this value,
// then these two segments are merged into a single segment.
// We do this recursively.
float
min_duration_off
=
0
.
5
;
// in seconds
OfflineSpeakerDiarizationConfig
()
=
default
;
OfflineSpeakerDiarizationConfig
(
const
OfflineSpeakerSegmentationModelConfig
&
segmentation
,
const
SpeakerEmbeddingExtractorConfig
&
embedding
,
const
FastClusteringConfig
&
clustering
)
:
segmentation
(
segmentation
),
embedding
(
embedding
),
clustering
(
clustering
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
std
::
string
ToString
()
const
;
};
class
OfflineSpeakerDiarizationImpl
;
using
OfflineSpeakerDiarizationProgressCallback
=
std
::
function
<
int32_t
(
int32_t
processed_chunks
,
int32_t
num_chunks
,
void
*
arg
)
>
;
class
OfflineSpeakerDiarization
{
public
:
explicit
OfflineSpeakerDiarization
(
const
OfflineSpeakerDiarizationConfig
&
config
);
~
OfflineSpeakerDiarization
();
// Expected sample rate of the input audio samples
int32_t
SampleRate
()
const
;
OfflineSpeakerDiarizationResult
Process
(
const
float
*
audio
,
int32_t
n
,
OfflineSpeakerDiarizationProgressCallback
callback
=
nullptr
,
void
*
callback_arg
=
nullptr
)
const
;
private
:
std
::
unique_ptr
<
OfflineSpeakerDiarizationImpl
>
impl_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
namespace
sherpa_onnx
{
void
OfflineSpeakerSegmentationModelConfig
::
Register
(
ParseOptions
*
po
)
{
pyannote
.
Register
(
po
);
po
->
Register
(
"num-threads"
,
&
num_threads
,
"Number of threads to run the neural network"
);
po
->
Register
(
"debug"
,
&
debug
,
"true to print model information while loading it."
);
po
->
Register
(
"provider"
,
&
provider
,
"Specify a provider to use: cpu, cuda, coreml"
);
}
bool
OfflineSpeakerSegmentationModelConfig
::
Validate
()
const
{
if
(
num_threads
<
1
)
{
SHERPA_ONNX_LOGE
(
"num_threads should be > 0. Given %d"
,
num_threads
);
return
false
;
}
if
(
!
pyannote
.
model
.
empty
())
{
return
pyannote
.
Validate
();
}
if
(
pyannote
.
model
.
empty
())
{
SHERPA_ONNX_LOGE
(
"You have to provide at least one speaker segmentation model"
);
return
false
;
}
return
true
;
}
std
::
string
OfflineSpeakerSegmentationModelConfig
::
ToString
()
const
{
std
::
ostringstream
os
;
os
<<
"OfflineSpeakerSegmentationModelConfig("
;
os
<<
"pyannote="
<<
pyannote
.
ToString
()
<<
", "
;
os
<<
"num_threads="
<<
num_threads
<<
", "
;
os
<<
"debug="
<<
(
debug
?
"True"
:
"False"
)
<<
", "
;
os
<<
"provider=
\"
"
<<
provider
<<
"
\"
)"
;
return
os
.
str
();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace
sherpa_onnx
{
struct
OfflineSpeakerSegmentationModelConfig
{
OfflineSpeakerSegmentationPyannoteModelConfig
pyannote
;
int32_t
num_threads
=
1
;
bool
debug
=
false
;
std
::
string
provider
=
"cpu"
;
OfflineSpeakerSegmentationModelConfig
()
=
default
;
explicit
OfflineSpeakerSegmentationModelConfig
(
const
OfflineSpeakerSegmentationPyannoteModelConfig
&
pyannote
,
int32_t
num_threads
,
bool
debug
,
const
std
::
string
&
provider
)
:
pyannote
(
pyannote
),
num_threads
(
num_threads
),
debug
(
debug
),
provider
(
provider
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
std
::
string
ToString
()
const
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace
sherpa_onnx
{
void
OfflineSpeakerSegmentationPyannoteModelConfig
::
Register
(
ParseOptions
*
po
)
{
po
->
Register
(
"pyannote-model"
,
&
model
,
"Path to model.onnx of the Pyannote segmentation model."
);
}
bool
OfflineSpeakerSegmentationPyannoteModelConfig
::
Validate
()
const
{
if
(
!
FileExists
(
model
))
{
SHERPA_ONNX_LOGE
(
"Pyannote segmentation model: '%s' does not exist"
,
model
.
c_str
());
return
false
;
}
return
true
;
}
std
::
string
OfflineSpeakerSegmentationPyannoteModelConfig
::
ToString
()
const
{
std
::
ostringstream
os
;
os
<<
"OfflineSpeakerSegmentationPyannoteModelConfig("
;
os
<<
"model=
\"
"
<<
model
<<
"
\"
)"
;
return
os
.
str
();
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace
sherpa_onnx
{
struct
OfflineSpeakerSegmentationPyannoteModelConfig
{
std
::
string
model
;
OfflineSpeakerSegmentationPyannoteModelConfig
()
=
default
;
explicit
OfflineSpeakerSegmentationPyannoteModelConfig
(
const
std
::
string
&
model
)
:
model
(
model
)
{}
void
Register
(
ParseOptions
*
po
);
bool
Validate
()
const
;
std
::
string
ToString
()
const
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
namespace
sherpa_onnx
{
// If you are not sure what each field means, please
// have a look of the Python file in the model directory that
// you have downloaded.
struct
OfflineSpeakerSegmentationPyannoteModelMetaData
{
int32_t
sample_rate
=
0
;
int32_t
window_size
=
0
;
// in samples
int32_t
window_shift
=
0
;
// in samples
int32_t
receptive_field_size
=
0
;
// in samples
int32_t
receptive_field_shift
=
0
;
// in samples
int32_t
num_speakers
=
0
;
int32_t
powerset_max_classes
=
0
;
int32_t
num_classes
=
0
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
namespace
sherpa_onnx
{
class
OfflineSpeakerSegmentationPyannoteModel
::
Impl
{
public
:
explicit
Impl
(
const
OfflineSpeakerSegmentationModelConfig
&
config
)
:
config_
(
config
),
env_
(
ORT_LOGGING_LEVEL_ERROR
),
sess_opts_
(
GetSessionOptions
(
config
)),
allocator_
{}
{
auto
buf
=
ReadFile
(
config_
.
pyannote
.
model
);
Init
(
buf
.
data
(),
buf
.
size
());
}
const
OfflineSpeakerSegmentationPyannoteModelMetaData
&
GetModelMetaData
()
const
{
return
meta_data_
;
}
Ort
::
Value
Forward
(
Ort
::
Value
x
)
{
auto
out
=
sess_
->
Run
({},
input_names_ptr_
.
data
(),
&
x
,
1
,
output_names_ptr_
.
data
(),
output_names_ptr_
.
size
());
return
std
::
move
(
out
[
0
]);
}
private
:
void
Init
(
void
*
model_data
,
size_t
model_data_length
)
{
sess_
=
std
::
make_unique
<
Ort
::
Session
>
(
env_
,
model_data
,
model_data_length
,
sess_opts_
);
GetInputNames
(
sess_
.
get
(),
&
input_names_
,
&
input_names_ptr_
);
GetOutputNames
(
sess_
.
get
(),
&
output_names_
,
&
output_names_ptr_
);
// get meta data
Ort
::
ModelMetadata
meta_data
=
sess_
->
GetModelMetadata
();
if
(
config_
.
debug
)
{
std
::
ostringstream
os
;
PrintModelMetadata
(
os
,
meta_data
);
SHERPA_ONNX_LOGE
(
"%s
\n
"
,
os
.
str
().
c_str
());
}
Ort
::
AllocatorWithDefaultOptions
allocator
;
// used in the macro below
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
sample_rate
,
"sample_rate"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
window_size
,
"window_size"
);
meta_data_
.
window_shift
=
static_cast
<
int32_t
>
(
0.1
*
meta_data_
.
window_size
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
receptive_field_size
,
"receptive_field_size"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
receptive_field_shift
,
"receptive_field_shift"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
num_speakers
,
"num_speakers"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
powerset_max_classes
,
"powerset_max_classes"
);
SHERPA_ONNX_READ_META_DATA
(
meta_data_
.
num_classes
,
"num_classes"
);
}
private
:
OfflineSpeakerSegmentationModelConfig
config_
;
Ort
::
Env
env_
;
Ort
::
SessionOptions
sess_opts_
;
Ort
::
AllocatorWithDefaultOptions
allocator_
;
std
::
unique_ptr
<
Ort
::
Session
>
sess_
;
std
::
vector
<
std
::
string
>
input_names_
;
std
::
vector
<
const
char
*>
input_names_ptr_
;
std
::
vector
<
std
::
string
>
output_names_
;
std
::
vector
<
const
char
*>
output_names_ptr_
;
OfflineSpeakerSegmentationPyannoteModelMetaData
meta_data_
;
};
OfflineSpeakerSegmentationPyannoteModel
::
OfflineSpeakerSegmentationPyannoteModel
(
const
OfflineSpeakerSegmentationModelConfig
&
config
)
:
impl_
(
std
::
make_unique
<
Impl
>
(
config
))
{}
OfflineSpeakerSegmentationPyannoteModel
::
~
OfflineSpeakerSegmentationPyannoteModel
()
=
default
;
const
OfflineSpeakerSegmentationPyannoteModelMetaData
&
OfflineSpeakerSegmentationPyannoteModel
::
GetModelMetaData
()
const
{
return
impl_
->
GetModelMetaData
();
}
Ort
::
Value
OfflineSpeakerSegmentationPyannoteModel
::
Forward
(
Ort
::
Value
x
)
const
{
return
impl_
->
Forward
(
std
::
move
(
x
));
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
namespace
sherpa_onnx
{
class
OfflineSpeakerSegmentationPyannoteModel
{
public
:
explicit
OfflineSpeakerSegmentationPyannoteModel
(
const
OfflineSpeakerSegmentationModelConfig
&
config
);
~
OfflineSpeakerSegmentationPyannoteModel
();
const
OfflineSpeakerSegmentationPyannoteModelMetaData
&
GetModelMetaData
()
const
;
/**
* @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
* @return Return a float tensor of
* shape (batch_size, num_frames, num_speakers). Note that
* num_speakers here uses powerset encoding.
*/
Ort
::
Value
Forward
(
Ort
::
Value
x
)
const
;
private
:
class
Impl
;
std
::
unique_ptr
<
Impl
>
impl_
;
};
}
// namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
...
...
sherpa-onnx/csrc/provider-config.cc
查看文件 @
59407ed
...
...
@@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) {
bool
TensorrtConfig
::
Validate
()
const
{
if
(
trt_max_workspace_size
<
0
)
{
SHERPA_ONNX_LOGE
(
"trt_max_workspace_size: %ld is not valid."
,
trt_max_workspace_size
);
std
::
ostringstream
os
;
os
<<
"trt_max_workspace_size: "
<<
trt_max_workspace_size
<<
" is not valid."
;
SHERPA_ONNX_LOGE
(
"%s"
,
os
.
str
().
c_str
());
return
false
;
}
if
(
trt_max_partition_iterations
<
0
)
{
...
...
sherpa-onnx/csrc/session.cc
查看文件 @
59407ed
...
...
@@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
api
.
ReleaseStatus
(
status
);
}
static
Ort
::
SessionOptions
GetSessionOptionsImpl
(
Ort
::
SessionOptions
GetSessionOptionsImpl
(
int32_t
num_threads
,
const
std
::
string
&
provider_str
,
const
ProviderConfig
*
provider_config
=
nullptr
)
{
const
ProviderConfig
*
provider_config
/*= nullptr*/
)
{
Provider
p
=
StringToProvider
(
provider_str
);
Ort
::
SessionOptions
sess_opts
;
...
...
@@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
&
config
.
provider_config
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineModelConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineLMConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
lm_num_threads
,
config
.
lm_provider
);
}
...
...
@@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
return
GetSessionOptionsImpl
(
config
.
lm_num_threads
,
config
.
lm_provider
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
VadModelConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
#if SHERPA_ONNX_ENABLE_TTS
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineTtsModelConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
#endif
Ort
::
SessionOptions
GetSessionOptions
(
const
SpeakerEmbeddingExtractorConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
SpokenLanguageIdentificationConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
AudioTaggingModelConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflinePunctuationModelConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlinePunctuationModelConfig
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/session.h
查看文件 @
59407ed
...
...
@@ -8,53 +8,28 @@
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
#if SHERPA_ONNX_ENABLE_TTS
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#endif
namespace
sherpa_onnx
{
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineModelConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineModelConfig
&
config
,
const
std
::
string
&
model_type
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineModelConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptionsImpl
(
int32_t
num_threads
,
const
std
::
string
&
provider_str
,
const
ProviderConfig
*
provider_config
=
nullptr
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineLMConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineLMConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
VadModelConfig
&
config
);
#if SHERPA_ONNX_ENABLE_TTS
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflineTtsModelConfig
&
config
);
#endif
Ort
::
SessionOptions
GetSessionOptions
(
const
SpeakerEmbeddingExtractorConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
SpokenLanguageIdentificationConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
AudioTaggingModelConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineModelConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OfflinePunctuationModelConfig
&
config
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlineModelConfig
&
config
,
const
std
::
string
&
model_type
);
Ort
::
SessionOptions
GetSessionOptions
(
const
OnlinePunctuationModelConfig
&
config
);
template
<
typename
T
>
Ort
::
SessionOptions
GetSessionOptions
(
const
T
&
config
)
{
return
GetSessionOptionsImpl
(
config
.
num_threads
,
config
.
provider
);
}
}
// namespace sherpa_onnx
...
...
sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
0 → 100644
查看文件 @
59407ed
// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
static
int32_t
ProgressCallback
(
int32_t
processed_chunks
,
int32_t
num_chunks
,
void
*
arg
)
{
float
progress
=
100.0
*
processed_chunks
/
num_chunks
;
fprintf
(
stderr
,
"progress %.2f%%
\n
"
,
progress
);
// the return value is currently ignored
return
0
;
}
int
main
(
int32_t
argc
,
char
*
argv
[])
{
const
char
*
kUsageMessage
=
R"usage(
Offline/Non-streaming speaker diarization with sherpa-onnx
Usage example:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Build sherpa-onnx
Step 5. Run it
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
Since we know that there are four speakers in the test wave file, we use
--clustering.num-clusters=4 in the above example.
If we don't know number of speakers in the given wave file, we can use
the argument --clustering.cluster-threshold. The following is an example:
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
A larger threshold leads to few clusters, i.e., few speakers;
a smaller threshold leads to more clusters, i.e., more speakers
)usage"
;
sherpa_onnx
::
OfflineSpeakerDiarizationConfig
config
;
sherpa_onnx
::
ParseOptions
po
(
kUsageMessage
);
config
.
Register
(
&
po
);
po
.
Read
(
argc
,
argv
);
std
::
cout
<<
config
.
ToString
()
<<
"
\n
"
;
if
(
!
config
.
Validate
())
{
po
.
PrintUsage
();
std
::
cerr
<<
"Errors in config!
\n
"
;
return
-
1
;
}
if
(
po
.
NumArgs
()
!=
1
)
{
std
::
cerr
<<
"Error: Please provide exactly 1 wave file.
\n\n
"
;
po
.
PrintUsage
();
return
-
1
;
}
sherpa_onnx
::
OfflineSpeakerDiarization
sd
(
config
);
std
::
cout
<<
"Started
\n
"
;
const
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
const
std
::
string
wav_filename
=
po
.
GetArg
(
1
);
int32_t
sample_rate
=
-
1
;
bool
is_ok
=
false
;
const
std
::
vector
<
float
>
samples
=
sherpa_onnx
::
ReadWave
(
wav_filename
,
&
sample_rate
,
&
is_ok
);
if
(
!
is_ok
)
{
std
::
cerr
<<
"Failed to read "
<<
wav_filename
.
c_str
()
<<
"
\n
"
;
return
-
1
;
}
if
(
sample_rate
!=
sd
.
SampleRate
())
{
std
::
cerr
<<
"Expect sample rate "
<<
sd
.
SampleRate
()
<<
". Given: "
<<
sample_rate
<<
"
\n
"
;
return
-
1
;
}
float
duration
=
samples
.
size
()
/
static_cast
<
float
>
(
sample_rate
);
auto
result
=
sd
.
Process
(
samples
.
data
(),
samples
.
size
(),
ProgressCallback
,
nullptr
)
.
SortByStartTime
();
for
(
const
auto
&
r
:
result
)
{
std
::
cout
<<
r
.
ToString
()
<<
"
\n
"
;
}
const
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
elapsed_seconds
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
milliseconds
>
(
end
-
begin
)
.
count
()
/
1000.
;
fprintf
(
stderr
,
"Duration : %.3f s
\n
"
,
duration
);
fprintf
(
stderr
,
"Elapsed seconds: %.3f s
\n
"
,
elapsed_seconds
);
float
rtf
=
elapsed_seconds
/
duration
;
fprintf
(
stderr
,
"Real time factor (RTF): %.3f / %.3f = %.3f
\n
"
,
elapsed_seconds
,
duration
,
rtf
);
return
0
;
}
...
...
sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc
查看文件 @
59407ed
...
...
@@ -9,14 +9,15 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"
int32_t
audioCallback
(
const
float
*
/*samples*/
,
int32_t
n
,
float
progress
)
{
static
int32_t
AudioCallback
(
const
float
*
/*samples*/
,
int32_t
n
,
float
progress
)
{
printf
(
"sample=%d, progress=%f
\n
"
,
n
,
progress
);
return
1
;
}
int
main
(
int32_t
argc
,
char
*
argv
[])
{
const
char
*
kUsageMessage
=
R"usage(
Offline text-to-speech with sherpa-onnx
Offline
/Non-streaming
text-to-speech with sherpa-onnx
Usage example:
...
...
@@ -79,7 +80,7 @@ or details.
sherpa_onnx
::
OfflineTts
tts
(
config
);
const
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
auto
audio
=
tts
.
Generate
(
po
.
GetArg
(
1
),
sid
,
1.0
,
a
udioCallback
);
auto
audio
=
tts
.
Generate
(
po
.
GetArg
(
1
),
sid
,
1.0
,
A
udioCallback
);
const
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
if
(
audio
.
samples
.
empty
())
{
...
...
sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
查看文件 @
59407ed
...
...
@@ -19,7 +19,7 @@ The input text can contain English words.
Usage:
Please download the model from:
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-
punct-ct-transformer-zh-en-vocab272727-2024-04-12
.tar.bz2
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-
online-punct-en-2024-08-06
.tar.bz2
./bin/Release/sherpa-onnx-online-punctuation \
--cnn-bilstm=/path/to/model.onnx \
...
...
sherpa-onnx/csrc/speaker-embedding-extractor.cc
查看文件 @
59407ed
...
...
@@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
bool
SpeakerEmbeddingExtractorConfig
::
Validate
()
const
{
if
(
model
.
empty
())
{
SHERPA_ONNX_LOGE
(
"Please provide
--
model"
);
SHERPA_ONNX_LOGE
(
"Please provide
a speaker embedding extractor
model"
);
return
false
;
}
if
(
!
FileExists
(
model
))
{
SHERPA_ONNX_LOGE
(
"
--speaker-embedding-
model: '%s' does not exist"
,
SHERPA_ONNX_LOGE
(
"
speaker embedding extractor
model: '%s' does not exist"
,
model
.
c_str
());
return
false
;
}
...
...
请
注册
或
登录
后发表评论