online-stream.cc
3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// sherpa-onnx/csrc/online-stream.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-stream.h"
#include <memory>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/features.h"
namespace sherpa_onnx {
// Private implementation behind OnlineStream. It owns the feature extractor
// together with the per-stream bookkeeping: two frame counters, the decoder
// result and a vector of Ort::Value states.
class OnlineStream::Impl {
 public:
  // Constructs the feature extractor from `config`.
  explicit Impl(const FeatureExtractorConfig &config)
      : feat_extractor_(config) {}

  // Passes all three arguments unchanged to the feature extractor.
  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
    feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
  }

  // Forwards the call to the feature extractor.
  void InputFinished() { feat_extractor_.InputFinished(); }

  // Returns the extractor's ready-frame count minus the offset accumulated by
  // Reset().
  int32_t NumFramesReady() const {
    const int32_t total_ready = feat_extractor_.NumFramesReady();
    return total_ready - start_frame_index_;
  }

  // Asks the feature extractor whether `frame` is the last frame.
  //
  // NOTE(review): `frame` is forwarded as-is, whereas GetFrames() shifts its
  // index by start_frame_index_ -- confirm which index space callers use here.
  bool IsLastFrame(int32_t frame) const {
    const bool is_last = feat_extractor_.IsLastFrame(frame);
    return is_last;
  }

  // Returns what the extractor yields for `n` frames beginning at
  // `frame_index`, where `frame_index` is shifted by the offset accumulated
  // by Reset().
  std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
    const int32_t shifted_index = frame_index + start_frame_index_;
    return feat_extractor_.GetFrames(shifted_index, n);
  }

  // Folds the processed-frame counter into the start offset and zeroes the
  // counter. The feature extractor itself is deliberately left untouched.
  void Reset() {
    start_frame_index_ += std::exchange(num_processed_frames_, 0);
  }

  // Mutable access to the processed-frame counter.
  int32_t &GetNumProcessedFrames() { return num_processed_frames_; }

  // Stores a copy of `r` as the current decoder result.
  void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }

  // Mutable access to the stored decoder result.
  OnlineTransducerDecoderResult &GetResult() { return result_; }

  // Returns whatever the feature extractor reports as its feature dimension.
  int32_t FeatureDim() const {
    const int32_t dim = feat_extractor_.FeatureDim();
    return dim;
  }

  // Takes ownership of `states`, replacing any previously stored states.
  void SetStates(std::vector<Ort::Value> states) {
    states_ = std::move(states);
  }

  // Mutable access to the stored states.
  std::vector<Ort::Value> &GetStates() { return states_; }

 private:
  FeatureExtractor feat_extractor_;
  int32_t num_processed_frames_{0};  // before subsampling
  int32_t start_frame_index_{0};     // only ever grows; Reset() adds to it
  OnlineTransducerDecoderResult result_;
  std::vector<Ort::Value> states_;
};
// Creates a stream whose implementation object is constructed from `config`.
// The commented-out `= {}` mirrors a default argument that is presumably
// declared in the header -- it cannot be repeated in this definition.
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
: impl_(std::make_unique<Impl>(config)) {}
// Defaulted here, after Impl has been fully defined in this file, rather than
// in the header. NOTE(review): presumably required because impl_ is a
// std::unique_ptr<Impl>, which needs a complete type to destroy -- confirm
// against the header.
OnlineStream::~OnlineStream() = default;
void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform,
int32_t n) {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}
void OnlineStream::InputFinished() { impl_->InputFinished(); }
// Returns the ready-frame count reported by the implementation object.
int32_t OnlineStream::NumFramesReady() const {
  const int32_t num_ready = impl_->NumFramesReady();
  return num_ready;
}
// Returns the implementation object's answer to whether `frame` is the last
// frame. The index is forwarded unchanged.
bool OnlineStream::IsLastFrame(int32_t frame) const {
  const bool is_last = impl_->IsLastFrame(frame);
  return is_last;
}
// Returns the vector produced by the implementation object for `n` frames
// starting at `frame_index`. Both arguments are forwarded unchanged.
std::vector<float> OnlineStream::GetFrames(int32_t frame_index,
                                           int32_t n) const {
  std::vector<float> frames = impl_->GetFrames(frame_index, n);
  return frames;
}
void OnlineStream::Reset() { impl_->Reset(); }
// Returns the feature dimension reported by the implementation object.
int32_t OnlineStream::FeatureDim() const {
  const int32_t dim = impl_->FeatureDim();
  return dim;
}
// Returns a mutable reference to the processed-frame counter held by the
// implementation object, so writes through it update the stream.
int32_t &OnlineStream::GetNumProcessedFrames() {
  int32_t &num_processed = impl_->GetNumProcessedFrames();
  return num_processed;
}
void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
impl_->SetResult(r);
}
// Returns a mutable reference to the decoder result held by the
// implementation object.
OnlineTransducerDecoderResult &OnlineStream::GetResult() {
  OnlineTransducerDecoderResult &result = impl_->GetResult();
  return result;
}
// Hands ownership of `states` to the implementation object. The by-value
// parameter is moved onward rather than copied.
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
  Impl &impl = *impl_;
  impl.SetStates(std::move(states));
}
// Returns a mutable reference to the states held by the implementation
// object.
std::vector<Ort::Value> &OnlineStream::GetStates() {
  std::vector<Ort::Value> &states = impl_->GetStates();
  return states;
}
} // namespace sherpa_onnx