offline-speaker-diarization-result.cc
3.0 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// sherpa-onnx/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include <algorithm>
#include <array>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Builds a segment covering [start, end] that is attributed to `speaker` and
// optionally carries `text`.
//
// A segment whose start lies after its end is rejected: an error is logged
// and SHERPA_ONNX_EXIT(-1) is invoked. start == end is accepted.
OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment(
    float start, float end, int32_t speaker, const std::string &text /*= {}*/) {
  const bool is_reversed = end < start;
  if (is_reversed) {
    SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end);
    SHERPA_ONNX_EXIT(-1);
  }

  // The interval is valid at this point; store all fields.
  text_ = text;
  speaker_ = speaker;
  end_ = end;
  start_ = start;
}
// Tries to merge this segment with `other` into a single segment.
//
// @param other  The segment to merge with. It must belong to the same speaker.
// @param gap    Maximum distance allowed between the end of the earlier
//               segment and the start of the later one.
// @return The merged segment spanning both inputs, or std::nullopt when
//         - the speakers differ (an error is logged in that case), or
//         - the segments are further apart than `gap`, or
//         - the segments touch or overlap (the comparisons below are strict,
//           so only segments separated by a positive distance are merged).
//
// Note: the merged segment is built from (start, end, speaker) only, so the
// text of neither input is passed on to it.
std::optional<OfflineSpeakerDiarizationSegment>
OfflineSpeakerDiarizationSegment::Merge(
    const OfflineSpeakerDiarizationSegment &other, float gap) const {
  if (speaker_ != other.speaker_) {
    SHERPA_ONNX_LOGE(
        "The two segments should have the same speaker. this->speaker: %d, "
        "other.speaker: %d",
        speaker_, other.speaker_);
    return std::nullopt;
  }

  // Case 1: `other` starts after this segment ends, within `gap`.
  const bool other_follows_this =
      end_ < other.start_ && end_ + gap >= other.start_;
  if (other_follows_this) {
    return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_);
  }

  // Case 2: `other` ends before this segment starts, within `gap`.
  const bool other_precedes_this =
      other.end_ < start_ && other.end_ + gap >= start_;
  if (other_precedes_this) {
    return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_);
  }

  return std::nullopt;
}
// Formats the segment as
//
//   "<start> -- <end> speaker_<id>"
//
// with start/end printed with three decimals and the speaker id zero-padded
// to at least two digits. If the segment has a non-empty text, the text is
// appended after a single space.
std::string OfflineSpeakerDiarizationSegment::ToString() const {
  // snprintf() never writes more than header.size() bytes and always
  // NUL-terminates, so the fixed-size buffer cannot overflow.
  std::array<char, 128> header{};
  snprintf(header.data(), header.size(), "%.3f -- %.3f speaker_%02d", start_,
           end_, speaker_);

  std::string ans(header.data());
  if (text_.empty()) {
    return ans;
  }

  ans += " ";
  ans += text_;
  return ans;
}
// Appends a copy of `segment` to the end of the result. Segments are kept in
// insertion order; no sorting or merging happens here.
void OfflineSpeakerDiarizationResult::Add(
    const OfflineSpeakerDiarizationSegment &segment) {
  segments_.emplace_back(segment);
}
int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const {
std::unordered_set<int32_t> count;
for (const auto &s : segments_) {
count.insert(s.Speaker());
}
return count.size();
}
// Returns how many segments have been added to this result so far.
int32_t OfflineSpeakerDiarizationResult::NumSegments() const {
  const auto num_segments = segments_.size();
  return static_cast<int32_t>(num_segments);
}
// Return a list of segments sorted by segment.start time
std::vector<OfflineSpeakerDiarizationSegment>
OfflineSpeakerDiarizationResult::SortByStartTime() const {
auto ans = segments_;
std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) {
return (a.Start() < b.Start()) ||
((a.Start() == b.Start()) && (a.Speaker() < b.Speaker()));
});
return ans;
}
std::vector<std::vector<OfflineSpeakerDiarizationSegment>>
OfflineSpeakerDiarizationResult::SortBySpeaker() const {
auto tmp = segments_;
std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) {
return (a.Speaker() < b.Speaker()) ||
((a.Speaker() == b.Speaker()) && (a.Start() < b.Start()));
});
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> ans(NumSpeakers());
for (auto &s : tmp) {
ans[s.Speaker()].push_back(std::move(s));
}
return ans;
}
} // namespace sherpa_onnx