Fangjun Kuang
Committed by GitHub

C++ API for speaker diarization (#1396)

Showing 39 changed files with 1,652 additions and 108 deletions
#!/usr/bin/env bash
set -ex
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
echo "EXE is $EXE"
echo "PATH: $PATH"
which $EXE
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
log "specify number of clusters"
$EXE \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
log "specify threshold for clustering"
$EXE \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
rm -rf sherpa-onnx-pyannote-*
rm -fv *.onnx
rm -fv *.wav
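For reference, each run above prints one line per detected segment in the form `0.000 -- 2.345 speaker_00` (produced by `OfflineSpeakerDiarizationSegment::ToString()` introduced later in this commit); the numbers here are illustrative, not captured output.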
... ...
... ... @@ -29,7 +29,7 @@ jobs:
- name: Install pyannote
shell: bash
run: |
pip install pyannote.audio onnx onnxruntime
pip install pyannote.audio onnx==1.15.0 onnxruntime==1.16.3
- name: Run
shell: bash
... ...
... ... @@ -18,6 +18,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -38,6 +39,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -143,6 +145,15 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: install/*
- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization
.github/scripts/test-speaker-diarization.sh
- name: Test offline transducer
shell: bash
run: |
... ...
... ... @@ -18,6 +18,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -37,6 +38,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -115,6 +117,15 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization
.github/scripts/test-speaker-diarization.sh
- name: Test offline transducer
shell: bash
run: |
... ...
... ... @@ -67,7 +67,7 @@ jobs:
curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
test_wavs=(
0-two-speakers-zh.wav
0-four-speakers-zh.wav
1-two-speakers-en.wav
2-two-speakers-en.wav
3-two-speakers-en.wav
... ...
... ... @@ -17,6 +17,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -34,6 +35,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -87,6 +89,15 @@ jobs:
name: release-windows-x64-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*
- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
- name: Test online punctuation
shell: bash
run: |
... ...
... ... @@ -17,6 +17,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -34,6 +35,7 @@ on:
- '.github/scripts/test-audio-tagging.sh'
- '.github/scripts/test-offline-punctuation.sh'
- '.github/scripts/test-online-punctuation.sh'
- '.github/scripts/test-speaker-diarization.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -87,6 +89,15 @@ jobs:
name: release-windows-x86-${{ matrix.shared_lib }}-${{ matrix.with_tts }}
path: build/install/*
- name: Test offline speaker diarization
shell: bash
run: |
du -h -d1 .
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-speaker-diarization.exe
.github/scripts/test-speaker-diarization.sh
- name: Test online punctuation
shell: bash
run: |
... ...
... ... @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
... ...
... ... @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
... ...
... ... @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
... ...
... ... @@ -36,7 +36,7 @@ static size_t ReadFile(const char *filename, const char **buffer_out) {
fprintf(stderr, "Memory error\n");
return -1;
}
size_t read_bytes = fread(*buffer_out, 1, size, file);
size_t read_bytes = fread((void *)*buffer_out, 1, size, file);
if (read_bytes != size) {
printf("Errors occured in reading the file %s\n", filename);
free((void *)*buffer_out);
... ...
... ... @@ -55,6 +55,7 @@ def get_binaries():
"sherpa-onnx-offline-audio-tagging",
"sherpa-onnx-offline-language-identification",
"sherpa-onnx-offline-punctuation",
"sherpa-onnx-offline-speaker-diarization",
"sherpa-onnx-offline-tts",
"sherpa-onnx-offline-tts-play",
"sherpa-onnx-offline-websocket-server",
... ...
... ... @@ -3,12 +3,9 @@
Please download test wave files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
## 0-two-speakers-zh.wav
## 0-four-speakers-zh.wav
This file is from
https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0
Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`.
It is recorded by @csukuangfj
## 1-two-speakers-en.wav
... ... @@ -40,5 +37,5 @@ commands to convert it to `3-two-speakers-en.wav`
```bash
sox ML16091-Audio.mp3 3-two-speakers-en.wav
sox ML16091-Audio.mp3 -r 16k 3-two-speakers-en.wav
```
... ...
... ... @@ -72,7 +72,7 @@ def main():
model.receptive_field.duration * 16000
)
opset_version = 18
opset_version = 13
filename = "model.onnx"
torch.onnx.export(
... ...
... ... @@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND sources
fast-clustering-config.cc
fast-clustering.cc
offline-speaker-diarization-impl.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
offline-speaker-segmentation-model-config.cc
offline-speaker-segmentation-pyannote-model-config.cc
offline-speaker-segmentation-pyannote-model.cc
)
endif()
... ... @@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
endif()
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc)
endif()
set(main_exes
sherpa-onnx
sherpa-onnx-keyword-spotter
... ... @@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY)
)
endif()
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND main_exes
sherpa-onnx-offline-speaker-diarization
)
endif()
foreach(exe IN LISTS main_exes)
target_link_libraries(${exe} sherpa-onnx-core)
endforeach()
... ...
... ... @@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const {
}
void FastClusteringConfig::Register(ParseOptions *po) {
std::string prefix = "ctc";
ParseOptions p(prefix, po);
p.Register("num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then --cluster-thresold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");
p.Register("cluster-threshold", &threshold,
"If --num-clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
po->Register(
"num-clusters", &num_clusters,
"Number of cluster. If greater than 0, then cluster threshold is "
"ignored. Please provide it if you know the actual number of "
"clusters in advance.");
po->Register("cluster-threshold", &threshold,
"If num_clusters is not specified, then it specifies the "
"distance threshold for clustering. smaller value -> more "
"clusters. larger value -> fewer clusters");
}
bool FastClusteringConfig::Validate() const {
... ...
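A minimal sketch of how these clustering options surface on the command line, assuming the prefixed ParseOptions behavior used by OfflineSpeakerDiarizationConfig::Register later in this commit:

```cpp
// Sketch only: the caller supplies the option prefix.
sherpa_onnx::ParseOptions po("usage message");
sherpa_onnx::ParseOptions po_clustering("clustering", &po);

sherpa_onnx::FastClusteringConfig config;
config.Register(&po_clustering);
// The binary then accepts --clustering.num-clusters and
// --clustering.cluster-threshold, matching the flags in the test script above.
```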
... ... @@ -5,6 +5,7 @@
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
#define SHERPA_ONNX_CSRC_MACROS_H_
#include <stdio.h>
#include <stdlib.h>
#if __ANDROID_API__ >= 8
#include "android/log.h"
... ... @@ -169,4 +170,6 @@
} \
} while (0)
#define SHERPA_ONNX_EXIT(code) exit(code)
#endif // SHERPA_ONNX_CSRC_MACROS_H_
... ...
... ... @@ -9,6 +9,7 @@
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
... ...
// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
#include <memory>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"
namespace sherpa_onnx {
std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
}
SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");
return nullptr;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speaker-diarization-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
#include <functional>
#include <memory>
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace sherpa_onnx {
class OfflineSpeakerDiarizationImpl {
public:
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
const OfflineSpeakerDiarizationConfig &config);
virtual ~OfflineSpeakerDiarizationImpl() = default;
virtual int32_t SampleRate() const = 0;
virtual OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
void *callback_arg = nullptr) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "sherpa-onnx/csrc/fast-clustering.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
namespace { // NOLINT
// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41
template <class T>
inline void hash_combine(std::size_t *seed, const T &v) { // NOLINT
std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); // NOLINT
}
// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47
struct PairHash {
template <class T1, class T2>
std::size_t operator()(const std::pair<T1, T2> &pair) const {
std::size_t result = 0;
hash_combine(&result, pair.first);
hash_combine(&result, pair.second);
return result;
}
};
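// PairHash is used below to key an std::unordered_map by a
// (chunk_index, speaker_index) pair; see ConvertChunkSpeakerToCluster.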
} // namespace
using Matrix2D =
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using Matrix2DInt32 =
Eigen::Matrix<int32_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using FloatRowVector = Eigen::Matrix<float, 1, Eigen::Dynamic>;
using Int32RowVector = Eigen::Matrix<int32_t, 1, Eigen::Dynamic>;
using Int32Pair = std::pair<int32_t, int32_t>;
class OfflineSpeakerDiarizationPyannoteImpl
: public OfflineSpeakerDiarizationImpl {
public:
~OfflineSpeakerDiarizationPyannoteImpl() override = default;
explicit OfflineSpeakerDiarizationPyannoteImpl(
const OfflineSpeakerDiarizationConfig &config)
: config_(config),
segmentation_model_(config_.segmentation),
embedding_extractor_(config_.embedding),
clustering_(config_.clustering) {
Init();
}
int32_t SampleRate() const override {
const auto &meta_data = segmentation_model_.GetModelMetaData();
return meta_data.sample_rate;
}
OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
void *callback_arg = nullptr) const override {
std::vector<Matrix2D> segmentations = RunSpeakerSegmentationModel(audio, n);
// segmentations[i] is for chunk_i
// Each matrix is of shape (num_frames, num_powerset_classes)
if (segmentations.empty()) {
return {};
}
std::vector<Matrix2DInt32> labels;
labels.reserve(segmentations.size());
for (const auto &m : segmentations) {
labels.push_back(ToMultiLabel(m));
}
segmentations.clear();
// labels[i] is a 0-1 matrix of shape (num_frames, num_speakers)
// speaker count per frame
Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels);
if (speakers_per_frame.maxCoeff() == 0) {
SHERPA_ONNX_LOGE("No speakers found in the audio samples");
return {};
}
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
callback, callback_arg);
std::vector<int32_t> cluster_labels = clustering_.Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
int32_t max_cluster_index =
*std::max_element(cluster_labels.begin(), cluster_labels.end());
auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster(
chunk_speaker_samples_list_pair.first, cluster_labels);
auto new_labels =
ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster);
Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n);
Matrix2DInt32 final_labels =
FinalizeLabels(speaker_count, speakers_per_frame);
auto result = ComputeResult(final_labels);
return result;
}
private:
void Init() { InitPowersetMapping(); }
// see also
// https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68
void InitPowersetMapping() {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t num_classes = meta_data.num_classes;
int32_t powerset_max_classes = meta_data.powerset_max_classes;
int32_t num_speakers = meta_data.num_speakers;
powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers);
powerset_mapping_.setZero();
int32_t k = 1;
for (int32_t i = 1; i <= powerset_max_classes; ++i) {
if (i == 1) {
for (int32_t j = 0; j != num_speakers; ++j, ++k) {
powerset_mapping_(k, j) = 1;
}
} else if (i == 2) {
for (int32_t j = 0; j != num_speakers; ++j) {
for (int32_t m = j + 1; m < num_speakers; ++m, ++k) {
powerset_mapping_(k, j) = 1;
powerset_mapping_(k, m) = 1;
}
}
} else {
SHERPA_ONNX_LOGE(
"powerset_max_classes = %d is currently not supported!", i);
SHERPA_ONNX_EXIT(-1);
}
}
}
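// Illustrative mapping, assuming num_speakers = 3 and
// powerset_max_classes = 2 (so num_classes = 7, as in pyannote
// segmentation-3.0): row 0 = (0,0,0) means "no active speaker";
// rows 1-3 = (1,0,0), (0,1,0), (0,0,1) are single speakers;
// rows 4-6 = (1,1,0), (1,0,1), (0,1,1) are two-speaker overlaps.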
std::vector<Matrix2D> RunSpeakerSegmentationModel(const float *audio,
int32_t n) const {
std::vector<Matrix2D> ans;
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
int32_t window_shift = meta_data.window_shift;
if (n <= 0) {
SHERPA_ONNX_LOGE(
"number of audio samples is %d (<= 0). Please provide a positive "
"number",
n);
return {};
}
if (n <= window_size) {
std::vector<float> buf(window_size);
// NOTE: buf is zero initialized by default
std::copy(audio, audio + n, buf.data());
Matrix2D m = ProcessChunk(buf.data());
ans.push_back(std::move(m));
return ans;
}
int32_t num_chunks = (n - window_size) / window_shift + 1;
bool has_last_chunk = (n - window_size) % window_shift > 0;
ans.reserve(num_chunks + has_last_chunk);
const float *p = audio;
for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) {
Matrix2D m = ProcessChunk(p);
ans.push_back(std::move(m));
}
if (has_last_chunk) {
std::vector<float> buf(window_size);
std::copy(p, audio + n, buf.data());
Matrix2D m = ProcessChunk(buf.data());
ans.push_back(std::move(m));
}
return ans;
}
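// Illustrative numbers, assuming sample_rate = 16000, window_size = 160000
// (10 s) and window_shift = 16000 (0.1 * window_size, as set in
// offline-speaker-segmentation-pyannote-model.cc): for 12 s of audio
// (n = 192000), num_chunks = (192000 - 160000) / 16000 + 1 = 3, starting at
// 0 s, 1 s and 2 s; (n - window_size) % window_shift == 0, so there is no
// zero-padded last chunk.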
Matrix2D ProcessChunk(const float *p) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> shape = {1, 1, window_size};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, const_cast<float *>(p),
window_size, shape.data(), shape.size());
Ort::Value out = segmentation_model_.Forward(std::move(x));
std::vector<int64_t> out_shape = out.GetTensorTypeAndShapeInfo().GetShape();
Matrix2D m(out_shape[1], out_shape[2]);
std::copy(out.GetTensorData<float>(), out.GetTensorData<float>() + m.size(),
&m(0, 0));
return m;
}
Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const {
int32_t num_rows = m.rows();
Matrix2DInt32 ans(num_rows, powerset_mapping_.cols());
std::ptrdiff_t col_id;
for (int32_t i = 0; i != num_rows; ++i) {
m.row(i).maxCoeff(&col_id);
ans.row(i) = powerset_mapping_.row(col_id);
}
return ans;
}
// See also
// https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122
Int32RowVector ComputeSpeakersPerFrame(
const std::vector<Matrix2DInt32> &labels) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
int32_t window_shift = meta_data.window_shift;
int32_t receptive_field_shift = meta_data.receptive_field_shift;
int32_t num_chunks = labels.size();
int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
receptive_field_shift +
1;
FloatRowVector count(num_frames);
FloatRowVector weight(num_frames);
count.setZero();
weight.setZero();
for (int32_t i = 0; i != num_chunks; ++i) {
int32_t start =
static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;
auto seq = Eigen::seqN(start, labels[i].rows());
count(seq).array() += labels[i].rowwise().sum().array().cast<float>();
weight(seq).array() += 1;
}
return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast<int32_t>();
}
// ans.first: a list of (chunk_id, speaker_id)
// ans.second: a list of list of (start_sample_index, end_sample_index)
//
// ans.first[i] corresponds to ans.second[i]
std::pair<std::vector<Int32Pair>, std::vector<std::vector<Int32Pair>>>
GetChunkSpeakerSampleIndexes(const std::vector<Matrix2DInt32> &labels) const {
auto new_labels = ExcludeOverlap(labels);
std::vector<Int32Pair> chunk_speaker_list;
std::vector<std::vector<Int32Pair>> samples_index_list;
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
int32_t window_shift = meta_data.window_shift;
int32_t receptive_field_shift = meta_data.receptive_field_shift;
int32_t num_speakers = meta_data.num_speakers;
int32_t chunk_index = 0;
for (const auto &label : new_labels) {
Matrix2DInt32 tmp = label.transpose();
// tmp: (num_speakers, num_frames)
int32_t num_frames = tmp.cols();
int32_t sample_offset = chunk_index * window_shift;
for (int32_t speaker_index = 0; speaker_index != num_speakers;
++speaker_index) {
auto d = tmp.row(speaker_index);
if (d.sum() < 10) {
// skip segments shorter than 10 frames
continue;
}
Int32Pair this_chunk_speaker = {chunk_index, speaker_index};
std::vector<Int32Pair> this_speaker_samples;
bool is_active = false;
int32_t start_index;
for (int32_t k = 0; k != num_frames; ++k) {
if (d[k] != 0) {
if (!is_active) {
is_active = true;
start_index = k;
}
} else if (is_active) {
is_active = false;
int32_t start_samples =
static_cast<float>(start_index) / num_frames * window_size +
sample_offset;
int32_t end_samples =
static_cast<float>(k) / num_frames * window_size +
sample_offset;
this_speaker_samples.emplace_back(start_samples, end_samples);
}
}
if (is_active) {
int32_t start_samples =
static_cast<float>(start_index) / num_frames * window_size +
sample_offset;
int32_t end_samples =
static_cast<float>(num_frames - 1) / num_frames * window_size +
sample_offset;
this_speaker_samples.emplace_back(start_samples, end_samples);
}
chunk_speaker_list.push_back(std::move(this_chunk_speaker));
samples_index_list.push_back(std::move(this_speaker_samples));
} // for (int32_t speaker_index = 0;
chunk_index += 1;
} // for (const auto &label : new_labels)
return {chunk_speaker_list, samples_index_list};
}
// If there are multiple speakers at a frame, then this frame is excluded.
std::vector<Matrix2DInt32> ExcludeOverlap(
const std::vector<Matrix2DInt32> &labels) const {
int32_t num_chunks = labels.size();
std::vector<Matrix2DInt32> ans;
ans.reserve(num_chunks);
for (const auto &label : labels) {
Matrix2DInt32 new_label(label.rows(), label.cols());
new_label.setZero();
Int32RowVector v = label.rowwise().sum();
for (int32_t i = 0; i != v.cols(); ++i) {
if (v[i] < 2) {
new_label.row(i) = label.row(i);
}
}
ans.push_back(std::move(new_label));
}
return ans;
}
/**
* @param sample_indexes[i] contains the sample segment start and end indexes
* for the i-th (chunk, speaker) pair
* @return Return a matrix of shape (sample_indexes.size(), embedding_dim)
* where ans.row[i] contains the embedding for the
* i-th (chunk, speaker) pair
*/
Matrix2D ComputeEmbeddings(
const float *audio, int32_t n,
const std::vector<std::vector<Int32Pair>> &sample_indexes,
OfflineSpeakerDiarizationProgressCallback callback,
void *callback_arg) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t sample_rate = meta_data.sample_rate;
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
int32_t k = 0;
for (const auto &v : sample_indexes) {
auto stream = embedding_extractor_.CreateStream();
for (const auto &p : v) {
int32_t end = (p.second <= n) ? p.second : n;
int32_t num_samples = end - p.first;
if (num_samples > 0) {
stream->AcceptWaveform(sample_rate, audio + p.first, num_samples);
}
}
stream->InputFinished();
if (!embedding_extractor_.IsReady(stream.get())) {
SHERPA_ONNX_LOGE(
"This segment is too short, which should not happen since we have "
"already filtered short segments");
SHERPA_ONNX_EXIT(-1);
}
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
k += 1;
if (callback) {
callback(k, ans.rows(), callback_arg);
}
}
return ans;
}
std::unordered_map<Int32Pair, int32_t, PairHash> ConvertChunkSpeakerToCluster(
const std::vector<Int32Pair> &chunk_speaker_pair,
const std::vector<int32_t> &cluster_labels) const {
std::unordered_map<Int32Pair, int32_t, PairHash> ans;
int32_t k = 0;
for (const auto &p : chunk_speaker_pair) {
ans[p] = cluster_labels[k];
k += 1;
}
return ans;
}
std::vector<Matrix2DInt32> ReLabel(
const std::vector<Matrix2DInt32> &labels, int32_t max_cluster_index,
std::unordered_map<Int32Pair, int32_t, PairHash> chunk_speaker_to_cluster)
const {
std::vector<Matrix2DInt32> new_labels;
new_labels.reserve(labels.size());
int32_t chunk_index = 0;
for (const auto &label : labels) {
Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1);
new_label.setZero();
Matrix2DInt32 t = label.transpose();
// t: (num_speakers, num_frames)
for (int32_t speaker_index = 0; speaker_index != t.rows();
++speaker_index) {
if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) {
continue;
}
int32_t new_speaker_index =
chunk_speaker_to_cluster.at({chunk_index, speaker_index});
for (int32_t k = 0; k != t.cols(); ++k) {
if (t(speaker_index, k) == 1) {
new_label(k, new_speaker_index) = 1;
}
}
}
new_labels.push_back(std::move(new_label));
chunk_index += 1;
}
return new_labels;
}
Matrix2DInt32 ComputeSpeakerCount(const std::vector<Matrix2DInt32> &labels,
int32_t num_samples) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
int32_t window_shift = meta_data.window_shift;
int32_t receptive_field_shift = meta_data.receptive_field_shift;
int32_t num_chunks = labels.size();
int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
receptive_field_shift +
1;
Matrix2DInt32 count(num_frames, labels[0].cols());
count.setZero();
for (int32_t i = 0; i != num_chunks; ++i) {
int32_t start =
static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;
auto seq = Eigen::seqN(start, labels[i].rows());
count(seq, Eigen::all).array() += labels[i].array();
}
bool has_last_chunk = (num_samples - window_size) % window_shift > 0;
if (has_last_chunk) {
return count;
}
int32_t last_frame = num_samples / receptive_field_shift;
return count(Eigen::seq(0, last_frame), Eigen::all);
}
Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 &count,
const Int32RowVector &speakers_per_frame) const {
int32_t num_rows = count.rows();
int32_t num_cols = count.cols();
Matrix2DInt32 ans(num_rows, num_cols);
ans.setZero();
for (int32_t i = 0; i != num_rows; ++i) {
int32_t k = speakers_per_frame[i];
if (k == 0) {
continue;
}
auto top_k = TopkIndex(&count(i, 0), num_cols, k);
for (int32_t m : top_k) {
ans(i, m) = 1;
}
}
return ans;
}
OfflineSpeakerDiarizationResult ComputeResult(
const Matrix2DInt32 &final_labels) const {
Matrix2DInt32 final_labels_t = final_labels.transpose();
int32_t num_speakers = final_labels_t.rows();
int32_t num_frames = final_labels_t.cols();
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t window_size = meta_data.window_size;
int32_t window_shift = meta_data.window_shift;
int32_t receptive_field_shift = meta_data.receptive_field_shift;
int32_t receptive_field_size = meta_data.receptive_field_size;
int32_t sample_rate = meta_data.sample_rate;
float scale = static_cast<float>(receptive_field_shift) / sample_rate;
float scale_offset = 0.5 * receptive_field_size / sample_rate;
OfflineSpeakerDiarizationResult ans;
for (int32_t speaker_index = 0; speaker_index != num_speakers;
++speaker_index) {
std::vector<OfflineSpeakerDiarizationSegment> this_speaker;
bool is_active = final_labels_t(speaker_index, 0) > 0;
int32_t start_index = is_active ? 0 : -1;
for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) {
if (is_active) {
if (final_labels_t(speaker_index, frame_index) == 0) {
float start_time = start_index * scale + scale_offset;
float end_time = frame_index * scale + scale_offset;
OfflineSpeakerDiarizationSegment segment(start_time, end_time,
speaker_index);
this_speaker.push_back(segment);
is_active = false;
}
} else if (final_labels_t(speaker_index, frame_index) == 1) {
is_active = true;
start_index = frame_index;
}
}
if (is_active) {
float start_time = start_index * scale + scale_offset;
float end_time = (num_frames - 1) * scale + scale_offset;
OfflineSpeakerDiarizationSegment segment(start_time, end_time,
speaker_index);
this_speaker.push_back(segment);
}
// merge segments if the gap between them is less than min_duration_off
MergeSegments(&this_speaker);
for (const auto &seg : this_speaker) {
if (seg.Duration() > config_.min_duration_on) {
ans.Add(seg);
}
}
} // for (int32_t speaker_index = 0; speaker_index != num_speakers;
return ans;
}
void MergeSegments(
std::vector<OfflineSpeakerDiarizationSegment> *segments) const {
float min_duration_off = config_.min_duration_off;
bool changed = true;
while (changed) {
changed = false;
for (int32_t i = 0; i < static_cast<int32_t>(segments->size()) - 1; ++i) {
auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off);
if (s) {
(*segments)[i] = s.value();
segments->erase(segments->begin() + i + 1);
changed = true;
break;
}
}
}
}
private:
OfflineSpeakerDiarizationConfig config_;
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
SpeakerEmbeddingExtractor embedding_extractor_;
FastClustering clustering_;
Matrix2DInt32 powerset_mapping_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
... ...
// sherpa-onnx/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment(
float start, float end, int32_t speaker, const std::string &text /*= {}*/) {
if (start > end) {
SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end);
SHERPA_ONNX_EXIT(-1);
}
start_ = start;
end_ = end;
speaker_ = speaker;
text_ = text;
}
std::optional<OfflineSpeakerDiarizationSegment>
OfflineSpeakerDiarizationSegment::Merge(
const OfflineSpeakerDiarizationSegment &other, float gap) const {
if (other.speaker_ != speaker_) {
SHERPA_ONNX_LOGE(
"The two segments should have the same speaker. this->speaker: %d, "
"other.speaker: %d",
speaker_, other.speaker_);
return std::nullopt;
}
if (end_ < other.start_ && end_ + gap >= other.start_) {
return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_);
} else if (other.end_ < start_ && other.end_ + gap >= start_) {
return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_);
} else {
return std::nullopt;
}
}
std::string OfflineSpeakerDiarizationSegment::ToString() const {
char s[128];
snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, speaker_);
std::ostringstream os;
os << s;
if (!text_.empty()) {
os << " " << text_;
}
return os.str();
}
void OfflineSpeakerDiarizationResult::Add(
const OfflineSpeakerDiarizationSegment &segment) {
segments_.push_back(segment);
}
int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const {
std::unordered_set<int32_t> count;
for (const auto &s : segments_) {
count.insert(s.Speaker());
}
return count.size();
}
int32_t OfflineSpeakerDiarizationResult::NumSegments() const {
return segments_.size();
}
// Return a list of segments sorted by segment.start time
std::vector<OfflineSpeakerDiarizationSegment>
OfflineSpeakerDiarizationResult::SortByStartTime() const {
auto ans = segments_;
std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) {
return (a.Start() < b.Start()) ||
((a.Start() == b.Start()) && (a.Speaker() < b.Speaker()));
});
return ans;
}
std::vector<std::vector<OfflineSpeakerDiarizationSegment>>
OfflineSpeakerDiarizationResult::SortBySpeaker() const {
auto tmp = segments_;
std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) {
return (a.Speaker() < b.Speaker()) ||
((a.Speaker() == b.Speaker()) && (a.Start() < b.Start()));
});
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> ans(NumSpeakers());
for (auto &s : tmp) {
ans[s.Speaker()].push_back(std::move(s));
}
return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speaker-diarization-result.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
namespace sherpa_onnx {
class OfflineSpeakerDiarizationSegment {
public:
OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker,
const std::string &text = {});
// If the gap between the two segments is less than the given gap, then we
// merge them and return a new segment. Otherwise, it returns std::nullopt.
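// Illustrative, with gap = 0.5: merging (0.0, 1.0, speaker 0) and
// (1.3, 2.0, speaker 0) yields (0.0, 2.0, speaker 0), since
// 1.0 < 1.3 and 1.0 + 0.5 >= 1.3.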
std::optional<OfflineSpeakerDiarizationSegment> Merge(
const OfflineSpeakerDiarizationSegment &other, float gap) const;
float Start() const { return start_; }
float End() const { return end_; }
int32_t Speaker() const { return speaker_; }
const std::string &Text() const { return text_; }
float Duration() const { return end_ - start_; }
std::string ToString() const;
private:
float start_; // in seconds
float end_; // in seconds
int32_t speaker_; // ID of the speaker, starting from 0
std::string text_; // If not empty, it contains the speech recognition result
// of this segment
};
class OfflineSpeakerDiarizationResult {
public:
// Add a new segment
void Add(const OfflineSpeakerDiarizationSegment &segment);
// Number of distinct speakers contained in this object at this point
int32_t NumSpeakers() const;
int32_t NumSegments() const;
// Return a list of segments sorted by segment.start time
std::vector<OfflineSpeakerDiarizationSegment> SortByStartTime() const;
// ans.size() == NumSpeakers().
// ans[i] is for speaker_i and is sorted by start time
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
const;
public:
std::vector<OfflineSpeakerDiarizationSegment> segments_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
... ...
// sherpa-onnx/csrc/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include <string>
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
namespace sherpa_onnx {
void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) {
ParseOptions po_segmentation("segmentation", po);
segmentation.Register(&po_segmentation);
ParseOptions po_embedding("embedding", po);
embedding.Register(&po_embedding);
ParseOptions po_clustering("clustering", po);
clustering.Register(&po_clustering);
po->Register("min-duration-on", &min_duration_on,
"if a segment is less than this value, then it is discarded. "
"Set it to 0 so that no segment is discarded");
po->Register("min-duration-off", &min_duration_off,
"if the gap between to segments of the same speaker is less "
"than this value, then these two segments are merged into a "
"single segment. We do it recursively.");
}
bool OfflineSpeakerDiarizationConfig::Validate() const {
if (!segmentation.Validate()) {
return false;
}
if (!embedding.Validate()) {
return false;
}
if (!clustering.Validate()) {
return false;
}
return true;
}
std::string OfflineSpeakerDiarizationConfig::ToString() const {
std::ostringstream os;
os << "OfflineSpeakerDiarizationConfig(";
os << "segmentation=" << segmentation.ToString() << ", ";
os << "embedding=" << embedding.ToString() << ", ";
os << "clustering=" << clustering.ToString() << ", ";
os << "min_duration_on=" << min_duration_on << ", ";
os << "min_duration_off=" << min_duration_off << ")";
return os.str();
}
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
int32_t OfflineSpeakerDiarization::SampleRate() const {
return impl_->SampleRate();
}
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
void *callback_arg /*= nullptr*/) const {
return impl_->Process(audio, n, callback, callback_arg);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speaker-diarization.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#include <functional>
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/fast-clustering-config.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
struct OfflineSpeakerDiarizationConfig {
OfflineSpeakerSegmentationModelConfig segmentation;
SpeakerEmbeddingExtractorConfig embedding;
FastClusteringConfig clustering;
// if a segment is shorter than this value, then it is discarded
float min_duration_on = 0.3; // in seconds
// if the gap between two segments of the same speaker is less than this value,
// then these two segments are merged into a single segment.
// We do this recursively.
float min_duration_off = 0.5; // in seconds
OfflineSpeakerDiarizationConfig() = default;
OfflineSpeakerDiarizationConfig(
const OfflineSpeakerSegmentationModelConfig &segmentation,
const SpeakerEmbeddingExtractorConfig &embedding,
const FastClusteringConfig &clustering)
: segmentation(segmentation),
embedding(embedding),
clustering(clustering) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class OfflineSpeakerDiarizationImpl;
using OfflineSpeakerDiarizationProgressCallback = std::function<int32_t(
int32_t processed_chunks, int32_t num_chunks, void *arg)>;
class OfflineSpeakerDiarization {
public:
explicit OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config);
~OfflineSpeakerDiarization();
// Expected sample rate of the input audio samples
int32_t SampleRate() const;
OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
void *callback_arg = nullptr) const;
private:
std::unique_ptr<OfflineSpeakerDiarizationImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
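Putting the declarations above to use: a minimal sketch, assuming the config field names introduced in this commit; the model and wave paths are placeholders taken from the test script, and the complete version is the binary sherpa-onnx-offline-speaker-diarization.cc later in this commit.

```cpp
// A minimal sketch, not part of the commit.
#include <iostream>
#include <vector>

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/wave-reader.h"

int main() {
  sherpa_onnx::OfflineSpeakerDiarizationConfig config;
  config.segmentation.pyannote.model =
      "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx";
  config.embedding.model =
      "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";
  config.clustering.num_clusters = 4;  // or set clustering.threshold instead

  if (!config.Validate()) {
    return -1;
  }

  sherpa_onnx::OfflineSpeakerDiarization sd(config);

  int32_t sample_rate = -1;
  bool is_ok = false;
  std::vector<float> samples =
      sherpa_onnx::ReadWave("./0-four-speakers-zh.wav", &sample_rate, &is_ok);
  if (!is_ok || sample_rate != sd.SampleRate()) {
    return -1;
  }

  auto result = sd.Process(samples.data(), samples.size()).SortByStartTime();
  for (const auto &seg : result) {
    std::cout << seg.ToString() << "\n";
  }

  return 0;
}
```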
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) {
pyannote.Register(po);
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool OfflineSpeakerSegmentationModelConfig::Validate() const {
if (num_threads < 1) {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
}
if (!pyannote.model.empty()) {
return pyannote.Validate();
}
if (pyannote.model.empty()) {
SHERPA_ONNX_LOGE(
"You have to provide at least one speaker segmentation model");
return false;
}
return true;
}
std::string OfflineSpeakerSegmentationModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineSpeakerSegmentationModelConfig(";
os << "pyannote=" << pyannote.ToString() << ", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineSpeakerSegmentationModelConfig {
OfflineSpeakerSegmentationPyannoteModelConfig pyannote;
int32_t num_threads = 1;
bool debug = false;
std::string provider = "cpu";
OfflineSpeakerSegmentationModelConfig() = default;
explicit OfflineSpeakerSegmentationModelConfig(
const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote,
int32_t num_threads, bool debug, const std::string &provider)
: pyannote(pyannote),
num_threads(num_threads),
debug(debug),
provider(provider) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) {
po->Register("pyannote-model", &model,
"Path to model.onnx of the Pyannote segmentation model.");
}
bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist",
model.c_str());
return false;
}
return true;
}
std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineSpeakerSegmentationPyannoteModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineSpeakerSegmentationPyannoteModelConfig {
std::string model;
OfflineSpeakerSegmentationPyannoteModelConfig() = default;
explicit OfflineSpeakerSegmentationPyannoteModelConfig(
const std::string &model)
: model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
#include <cstdint>
#include <string>
namespace sherpa_onnx {
// If you are not sure what each field means, please
// have a look at the Python file in the model directory that
// you have downloaded.
struct OfflineSpeakerSegmentationPyannoteModelMetaData {
int32_t sample_rate = 0;
int32_t window_size = 0; // in samples
int32_t window_shift = 0; // in samples
int32_t receptive_field_size = 0; // in samples
int32_t receptive_field_shift = 0; // in samples
int32_t num_speakers = 0;
int32_t powerset_max_classes = 0;
int32_t num_classes = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
namespace sherpa_onnx {
class OfflineSpeakerSegmentationPyannoteModel::Impl {
public:
explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.pyannote.model);
Init(buf.data(), buf.size());
}
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
const {
return meta_data_;
}
Ort::Value Forward(Ort::Value x) {
auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1,
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(out[0]);
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size");
meta_data_.window_shift =
static_cast<int32_t>(0.1 * meta_data_.window_size);
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size,
"receptive_field_size");
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift,
"receptive_field_shift");
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers");
SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes,
"powerset_max_classes");
SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes");
}
private:
OfflineSpeakerSegmentationModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_;
};
OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineSpeakerSegmentationPyannoteModel::
~OfflineSpeakerSegmentationPyannoteModel() = default;
const OfflineSpeakerSegmentationPyannoteModelMetaData &
OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const {
return impl_->GetModelMetaData();
}
Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(
Ort::Value x) const {
return impl_->Forward(std::move(x));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
namespace sherpa_onnx {
class OfflineSpeakerSegmentationPyannoteModel {
public:
explicit OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config);
~OfflineSpeakerSegmentationPyannoteModel();
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
const;
/**
* @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
* @return Return a float tensor of
* shape (batch_size, num_frames, num_speakers). Note that
* num_speakers here uses powerset encoding.
*/
Ort::Value Forward(Ort::Value x) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
... ...
... ... @@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) {
bool TensorrtConfig::Validate() const {
if (trt_max_workspace_size < 0) {
SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.",
trt_max_workspace_size);
std::ostringstream os;
os << "trt_max_workspace_size: " << trt_max_workspace_size
<< " is not valid.";
SHERPA_ONNX_LOGE("%s", os.str().c_str());
return false;
}
if (trt_max_partition_iterations < 0) {
... ...
... ... @@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
api.ReleaseStatus(status);
}
static Ort::SessionOptions GetSessionOptionsImpl(
Ort::SessionOptions GetSessionOptionsImpl(
int32_t num_threads, const std::string &provider_str,
const ProviderConfig *provider_config = nullptr) {
const ProviderConfig *provider_config /*= nullptr*/) {
Provider p = StringToProvider(provider_str);
Ort::SessionOptions sess_opts;
... ... @@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
&config.provider_config);
}
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
}
... ... @@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
}
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
#endif
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(
const OfflinePunctuationModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(
const OnlinePunctuationModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx
... ...
... ... @@ -8,53 +8,28 @@
#include <string>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
#if SHERPA_ONNX_ENABLE_TTS
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#endif
namespace sherpa_onnx {
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type);
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
Ort::SessionOptions GetSessionOptionsImpl(
int32_t num_threads, const std::string &provider_str,
const ProviderConfig *provider_config = nullptr);
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
#if SHERPA_ONNX_ENABLE_TTS
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
#endif
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config);
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
Ort::SessionOptions GetSessionOptions(
const OfflinePunctuationModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
const std::string &model_type);
Ort::SessionOptions GetSessionOptions(
const OnlinePunctuationModelConfig &config);
template <typename T>
Ort::SessionOptions GetSessionOptions(const T &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx
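The template above subsumes the per-config GetSessionOptions overloads removed from session.cc: any config type that exposes num_threads and provider members resolves to it. A minimal sketch with a hypothetical config type:

```cpp
// Hypothetical type, shown only to illustrate what the template requires.
struct MyModelConfig {
  int32_t num_threads = 2;
  std::string provider = "cpu";
};

// Resolves to the template overload declared above:
//   Ort::SessionOptions opts = sherpa_onnx::GetSessionOptions(MyModelConfig{});
```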
... ...
// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
void *arg) {
float progress = 100.0 * processed_chunks / num_chunks;
fprintf(stderr, "progress %.2f%%\n", progress);
// the return value is currently ignored
return 0;
}
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Offline/Non-streaming speaker diarization with sherpa-onnx
Usage example:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Build sherpa-onnx
Step 5. Run it
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
Since we know that there are four speakers in the test wave file, we use
--clustering.num-clusters=4 in the above example.
If we don't know the number of speakers in the given wave file, we can use
the argument --clustering.cluster-threshold. The following is an example:
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
A larger threshold leads to fewer clusters, i.e., fewer speakers;
a smaller threshold leads to more clusters, i.e., more speakers.
)usage";
sherpa_onnx::OfflineSpeakerDiarizationConfig config;
sherpa_onnx::ParseOptions po(kUsageMessage);
config.Register(&po);
po.Read(argc, argv);
std::cout << config.ToString() << "\n";
if (!config.Validate()) {
po.PrintUsage();
std::cerr << "Errors in config!\n";
return -1;
}
if (po.NumArgs() != 1) {
std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
po.PrintUsage();
return -1;
}
sherpa_onnx::OfflineSpeakerDiarization sd(config);
std::cout << "Started\n";
const auto begin = std::chrono::steady_clock::now();
const std::string wav_filename = po.GetArg(1);
int32_t sample_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
if (!is_ok) {
std::cerr << "Failed to read " << wav_filename.c_str() << "\n";
return -1;
}
if (sample_rate != sd.SampleRate()) {
std::cerr << "Expect sample rate " << sd.SampleRate()
<< ". Given: " << sample_rate << "\n";
return -1;
}
float duration = samples.size() / static_cast<float>(sample_rate);
auto result =
sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
.SortByStartTime();
for (const auto &r : result) {
std::cout << r.ToString() << "\n";
}
const auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
fprintf(stderr, "Duration : %.3f s\n", duration);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
return 0;
}
... ...
... ... @@ -9,14 +9,15 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"
int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) {
static int32_t AudioCallback(const float * /*samples*/, int32_t n,
float progress) {
printf("sample=%d, progress=%f\n", n, progress);
return 1;
}
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Offline text-to-speech with sherpa-onnx
Offline/Non-streaming text-to-speech with sherpa-onnx
Usage example:
... ... @@ -79,7 +80,7 @@ or details.
sherpa_onnx::OfflineTts tts(config);
const auto begin = std::chrono::steady_clock::now();
auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback);
auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback);
const auto end = std::chrono::steady_clock::now();
if (audio.samples.empty()) {
... ...
... ... @@ -19,7 +19,7 @@ The input text can contain English words.
Usage:
Please download the model from:
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
./bin/Release/sherpa-onnx-online-punctuation \
--cnn-bilstm=/path/to/model.onnx \
... ...
... ... @@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
bool SpeakerEmbeddingExtractorConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("Please provide --model");
SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist",
SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist",
model.c_str());
return false;
}
... ...