Fangjun Kuang
Committed by GitHub

Add online stream. (#28)

@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
3 add_executable(sherpa-onnx 3 add_executable(sherpa-onnx
4 features.cc 4 features.cc
5 online-lstm-transducer-model.cc 5 online-lstm-transducer-model.cc
  6 + online-stream.cc
6 online-transducer-greedy-search-decoder.cc 7 online-transducer-greedy-search-decoder.cc
7 online-transducer-model-config.cc 8 online-transducer-model-config.cc
8 online-transducer-model.cc 9 online-transducer-model.cc
@@ -11,16 +11,12 @@
11 namespace sherpa_onnx { 11 namespace sherpa_onnx {
12 12
13 struct FeatureExtractorConfig { 13 struct FeatureExtractorConfig {
14 - int32_t sampling_rate = 16000; 14 + float sampling_rate = 16000;
15 int32_t feature_dim = 80; 15 int32_t feature_dim = 80;
16 }; 16 };
17 17
18 class FeatureExtractor { 18 class FeatureExtractor {
19 public: 19 public:
20 - /**  
21 - * @param sampling_rate Sampling rate of the data used to train the model.  
22 - * @param feature_dim Dimension of the features used to train the model.  
23 - */  
24 explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); 20 explicit FeatureExtractor(const FeatureExtractorConfig &config = {});
25 ~FeatureExtractor(); 21 ~FeatureExtractor();
26 22
@@ -32,16 +28,19 @@ class FeatureExtractor {
32 */ 28 */
33 void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); 29 void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
34 30
35 - // InputFinished() tells the class you won't be providing any  
36 - // more waveform. This will help flush out the last frame or two  
37 - // of features, in the case where snip-edges == false; it also  
38 - // affects the return value of IsLastFrame(). 31 + /**
  32 + * InputFinished() tells the class you won't be providing any
  33 + * more waveform. This will help flush out the last frame or two
  34 + * of features, in the case where snip-edges == false; it also
  35 + * affects the return value of IsLastFrame().
  36 + */
39 void InputFinished(); 37 void InputFinished();
40 38
41 int32_t NumFramesReady() const; 39 int32_t NumFramesReady() const;
42 40
43 - // Note: IsLastFrame() will only ever return true if you have called  
44 - // InputFinished() (and this frame is the last frame). 41 + /** Note: IsLastFrame() will only ever return true if you have called
  42 + * InputFinished() (and this frame is the last frame).
  43 + */
45 bool IsLastFrame(int32_t frame) const; 44 bool IsLastFrame(int32_t frame) const;
46 45
47 /** Get n frames starting from the given frame index. 46 /** Get n frames starting from the given frame index.
  1 +// sherpa-onnx/csrc/online-stream.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/online-stream.h"
  5 +
  6 +#include <memory>
  7 +#include <vector>
  8 +
  9 +#include "sherpa-onnx/csrc/features.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +class OnlineStream::Impl {
  14 + public:
  15 + explicit Impl(const FeatureExtractorConfig &config)
  16 + : feat_extractor_(config) {}
  17 +
  18 + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
  19 + feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
  20 + }
  21 +
  22 + void InputFinished() { feat_extractor_.InputFinished(); }
  23 +
  24 + int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
  25 +
  26 + bool IsLastFrame(int32_t frame) const {
  27 + return feat_extractor_.IsLastFrame(frame);
  28 + }
  29 +
  30 + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
  31 + return feat_extractor_.GetFrames(frame_index, n);
  32 + }
  33 +
  34 + void Reset() { feat_extractor_.Reset(); }
  35 +
  36 + int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
  37 +
  38 + void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
  39 +
  40 + const OnlineTransducerDecoderResult &GetResult() const { return result_; }
  41 +
  42 + int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
  43 +
  44 + private:
  45 + FeatureExtractor feat_extractor_;
  46 + int32_t num_processed_frames_ = 0; // before subsampling
  47 + OnlineTransducerDecoderResult result_;
  48 +};
  49 +
  50 +OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
  51 + : impl_(std::make_unique<Impl>(config)) {}
  52 +
  53 +OnlineStream::~OnlineStream() = default;
  54 +
  55 +void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform,
  56 + int32_t n) {
  57 + impl_->AcceptWaveform(sampling_rate, waveform, n);
  58 +}
  59 +
  60 +void OnlineStream::InputFinished() { impl_->InputFinished(); }
  61 +
  62 +int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); }
  63 +
  64 +bool OnlineStream::IsLastFrame(int32_t frame) const {
  65 + return impl_->IsLastFrame(frame);
  66 +}
  67 +
  68 +std::vector<float> OnlineStream::GetFrames(int32_t frame_index,
  69 + int32_t n) const {
  70 + return impl_->GetFrames(frame_index, n);
  71 +}
  72 +
  73 +void OnlineStream::Reset() { impl_->Reset(); }
  74 +
  75 +int32_t OnlineStream::FeatureDim() const { return impl_->FeatureDim(); }
  76 +
  77 +int32_t &OnlineStream::GetNumProcessedFrames() {
  78 + return impl_->GetNumProcessedFrames();
  79 +}
  80 +
  81 +void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
  82 + impl_->SetResult(r);
  83 +}
  84 +
  85 +const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
  86 + return impl_->GetResult();
  87 +}
  88 +
  89 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-stream.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_STREAM_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_STREAM_H_
  7 +
  8 +#include <memory>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/features.h"
  12 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OnlineStream {
  17 + public:
  18 + explicit OnlineStream(const FeatureExtractorConfig &config = {});
  19 + ~OnlineStream();
  20 +
  21 + /**
  22 + @param sampling_rate The sampling_rate of the input waveform. Should match
  23 + the one expected by the feature extractor.
  24 + @param waveform Pointer to a 1-D array of size n
  25 + @param n Number of entries in waveform
  26 + */
  27 + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
  28 +
  29 + /**
  30 + * InputFinished() tells the class you won't be providing any
  31 + * more waveform. This will help flush out the last frame or two
  32 + * of features, in the case where snip-edges == false; it also
  33 + * affects the return value of IsLastFrame().
  34 + */
  35 + void InputFinished();
  36 +
  37 + int32_t NumFramesReady() const;
  38 +
  39 + /** Note: IsLastFrame() will only ever return true if you have called
  40 + * InputFinished() (and this frame is the last frame).
  41 + */
  42 + bool IsLastFrame(int32_t frame) const;
  43 +
  44 + /** Get n frames starting from the given frame index.
  45 + *
  46 + * @param frame_index The starting frame index
  47 + * @param n Number of frames to get.
  48 + * @return Return a 2-D tensor of shape (n, feature_dim).
  49 + * which is flattened into a 1-D vector (flattened in in row major)
  50 + */
  51 + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
  52 +
  53 + void Reset();
  54 +
  55 + int32_t FeatureDim() const;
  56 +
  57 + // Return a reference to the number of processed frames so far.
  58 + // Initially, it is 0. It is always less than NumFramesReady().
  59 + //
  60 + // The returned reference is valid as long as this object is alive.
  61 + int32_t &GetNumProcessedFrames();
  62 +
  63 + void SetResult(const OnlineTransducerDecoderResult &r);
  64 + const OnlineTransducerDecoderResult &GetResult() const;
  65 +
  66 + private:
  67 + class Impl;
  68 + std::unique_ptr<Impl> impl_;
  69 +};
  70 +
  71 +} // namespace sherpa_onnx
  72 +
  73 +#endif // SHERPA_ONNX_CSRC_ONLINE_STREAM_H_
@@ -8,8 +8,7 @@
8 #include <string> 8 #include <string>
9 #include <vector> 9 #include <vector>
10 10
11 -#include "kaldi-native-fbank/csrc/online-feature.h"  
12 -#include "sherpa-onnx/csrc/features.h" 11 +#include "sherpa-onnx/csrc/online-stream.h"
13 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" 12 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
14 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 13 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
15 #include "sherpa-onnx/csrc/online-transducer-model.h" 14 #include "sherpa-onnx/csrc/online-transducer-model.h"
@@ -64,7 +63,7 @@ for a list of pre-trained models to download.
64 63
65 std::vector<Ort::Value> states = model->GetEncoderInitStates(); 64 std::vector<Ort::Value> states = model->GetEncoderInitStates();
66 65
67 - int32_t expected_sampling_rate = 16000; 66 + float expected_sampling_rate = 16000;
68 67
69 bool is_ok = false; 68 bool is_ok = false;
70 std::vector<float> samples = 69 std::vector<float> samples =
@@ -75,7 +74,7 @@ for a list of pre-trained models to download.
75 return -1; 74 return -1;
76 } 75 }
77 76
78 - float duration = samples.size() / static_cast<float>(expected_sampling_rate); 77 + float duration = samples.size() / expected_sampling_rate;
79 78
80 fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); 79 fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
81 fprintf(stderr, "wav duration (s): %.3f\n", duration); 80 fprintf(stderr, "wav duration (s): %.3f\n", duration);
@@ -83,32 +82,33 @@ for a list of pre-trained models to download.
83 auto begin = std::chrono::steady_clock::now(); 82 auto begin = std::chrono::steady_clock::now();
84 fprintf(stderr, "Started\n"); 83 fprintf(stderr, "Started\n");
85 84
86 - sherpa_onnx::FeatureExtractor feat_extractor;  
87 - feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),  
88 - samples.size()); 85 + sherpa_onnx::OnlineStream stream;
  86 + stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
89 87
90 std::vector<float> tail_paddings( 88 std::vector<float> tail_paddings(
91 static_cast<int>(0.2 * expected_sampling_rate)); 89 static_cast<int>(0.2 * expected_sampling_rate));
92 - feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),  
93 - tail_paddings.size());  
94 - feat_extractor.InputFinished(); 90 + stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
  91 + tail_paddings.size());
  92 + stream.InputFinished();
95 93
96 - int32_t num_frames = feat_extractor.NumFramesReady();  
97 - int32_t feature_dim = feat_extractor.FeatureDim(); 94 + int32_t num_frames = stream.NumFramesReady();
  95 + int32_t feature_dim = stream.FeatureDim();
98 96
99 std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; 97 std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
100 98
101 sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); 99 sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
102 std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = { 100 std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
103 decoder.GetEmptyResult()}; 101 decoder.GetEmptyResult()};
104 -  
105 - for (int32_t start = 0; start + chunk_size < num_frames;  
106 - start += chunk_shift) {  
107 - std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); 102 + while (stream.NumFramesReady() - stream.GetNumProcessedFrames() >
  103 + chunk_size) {
  104 + std::vector<float> features =
  105 + stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size);
  106 + stream.GetNumProcessedFrames() += chunk_shift;
108 107
109 Ort::Value x = 108 Ort::Value x =
110 Ort::Value::CreateTensor(memory_info, features.data(), features.size(), 109 Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
111 x_shape.data(), x_shape.size()); 110 x_shape.data(), x_shape.size());
  111 +
112 auto pair = model->RunEncoder(std::move(x), states); 112 auto pair = model->RunEncoder(std::move(x), states);
113 states = std::move(pair.second); 113 states = std::move(pair.second);
114 decoder.Decode(std::move(pair.first), &result); 114 decoder.Decode(std::move(pair.first), &result);
@@ -116,8 +116,8 @@ for a list of pre-trained models to download. @@ -116,8 +116,8 @@ for a list of pre-trained models to download.
116 decoder.StripLeadingBlanks(&result[0]); 116 decoder.StripLeadingBlanks(&result[0]);
117 const auto &hyp = result[0].tokens; 117 const auto &hyp = result[0].tokens;
118 std::string text; 118 std::string text;
119 - for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {  
120 - text += sym[hyp[i]]; 119 + for (auto t : hyp) {
  120 + text += sym[t];
121 } 121 }
122 122
123 fprintf(stderr, "Done!\n"); 123 fprintf(stderr, "Done!\n");