Fangjun Kuang
Committed by GitHub

add online-recognizer (#29)

@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR}) @@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
3 add_executable(sherpa-onnx 3 add_executable(sherpa-onnx
4 features.cc 4 features.cc
5 online-lstm-transducer-model.cc 5 online-lstm-transducer-model.cc
  6 + online-recognizer.cc
6 online-stream.cc 7 online-stream.cc
7 online-transducer-greedy-search-decoder.cc 8 online-transducer-greedy-search-decoder.cc
8 online-transducer-model-config.cc 9 online-transducer-model-config.cc
@@ -7,12 +7,23 @@ @@ -7,12 +7,23 @@
7 #include <algorithm> 7 #include <algorithm>
8 #include <memory> 8 #include <memory>
9 #include <mutex> // NOLINT 9 #include <mutex> // NOLINT
  10 +#include <sstream>
10 #include <vector> 11 #include <vector>
11 12
12 #include "kaldi-native-fbank/csrc/online-feature.h" 13 #include "kaldi-native-fbank/csrc/online-feature.h"
13 14
14 namespace sherpa_onnx { 15 namespace sherpa_onnx {
15 16
  17 +std::string FeatureExtractorConfig::ToString() const {
  18 + std::ostringstream os;
  19 +
  20 + os << "FeatureExtractorConfig(";
  21 + os << "sampling_rate=" << sampling_rate << ", ";
  22 + os << "feature_dim=" << feature_dim << ")";
  23 +
  24 + return os.str();
  25 +}
  26 +
16 class FeatureExtractor::Impl { 27 class FeatureExtractor::Impl {
17 public: 28 public:
18 explicit Impl(const FeatureExtractorConfig &config) { 29 explicit Impl(const FeatureExtractorConfig &config) {
@@ -6,6 +6,7 @@ @@ -6,6 +6,7 @@
6 #define SHERPA_ONNX_CSRC_FEATURES_H_ 6 #define SHERPA_ONNX_CSRC_FEATURES_H_
7 7
8 #include <memory> 8 #include <memory>
  9 +#include <string>
9 #include <vector> 10 #include <vector>
10 11
11 namespace sherpa_onnx { 12 namespace sherpa_onnx {
@@ -13,6 +14,8 @@ namespace sherpa_onnx { @@ -13,6 +14,8 @@ namespace sherpa_onnx {
13 struct FeatureExtractorConfig { 14 struct FeatureExtractorConfig {
14 float sampling_rate = 16000; 15 float sampling_rate = 16000;
15 int32_t feature_dim = 80; 16 int32_t feature_dim = 80;
  17 +
  18 + std::string ToString() const;
16 }; 19 };
17 20
18 class FeatureExtractor { 21 class FeatureExtractor {
  1 +// sherpa-onnx/csrc/online-recognizer.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-recognizer.h"
  6 +
  7 +#include <assert.h>
  8 +
  9 +#include <memory>
  10 +#include <sstream>
  11 +#include <utility>
  12 +#include <vector>
  13 +
  14 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  15 +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
  16 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  17 +#include "sherpa-onnx/csrc/symbol-table.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
  22 + const SymbolTable &sym_table) {
  23 + std::string text;
  24 + for (auto t : src.tokens) {
  25 + text += sym_table[t];
  26 + }
  27 +
  28 + OnlineRecognizerResult ans;
  29 + ans.text = std::move(text);
  30 + return ans;
  31 +}
  32 +
  33 +std::string OnlineRecognizerConfig::ToString() const {
  34 + std::ostringstream os;
  35 +
  36 + os << "OnlineRecognizerConfig(";
  37 + os << "feat_config=" << feat_config.ToString() << ", ";
  38 + os << "model_config=" << model_config.ToString() << ", ";
  39 + os << "tokens=\"" << tokens << "\")";
  40 +
  41 + return os.str();
  42 +}
  43 +
  44 +class OnlineRecognizer::Impl {
  45 + public:
  46 + explicit Impl(const OnlineRecognizerConfig &config)
  47 + : config_(config),
  48 + model_(OnlineTransducerModel::Create(config.model_config)),
  49 + sym_(config.tokens) {
  50 + decoder_ =
  51 + std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
  52 + }
  53 +
  54 + std::unique_ptr<OnlineStream> CreateStream() const {
  55 + auto stream = std::make_unique<OnlineStream>(config_.feat_config);
  56 + stream->SetResult(decoder_->GetEmptyResult());
  57 + stream->SetStates(model_->GetEncoderInitStates());
  58 + return stream;
  59 + }
  60 +
  61 + bool IsReady(OnlineStream *s) const {
  62 + return s->GetNumProcessedFrames() + model_->ChunkSize() <
  63 + s->NumFramesReady();
  64 + }
  65 +
  66 + void DecodeStreams(OnlineStream **ss, int32_t n) {
  67 + if (n != 1) {
  68 + fprintf(stderr, "only n == 1 is implemented\n");
  69 + exit(-1);
  70 + }
  71 + OnlineStream *s = ss[0];
  72 + assert(IsReady(s));
  73 +
  74 + int32_t chunk_size = model_->ChunkSize();
  75 + int32_t chunk_shift = model_->ChunkShift();
  76 +
  77 + int32_t feature_dim = s->FeatureDim();
  78 +
  79 + std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
  80 +
  81 + auto memory_info =
  82 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  83 +
  84 + std::vector<float> features =
  85 + s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
  86 +
  87 + s->GetNumProcessedFrames() += chunk_shift;
  88 +
  89 + Ort::Value x =
  90 + Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
  91 + x_shape.data(), x_shape.size());
  92 +
  93 + auto pair = model_->RunEncoder(std::move(x), s->GetStates());
  94 +
  95 + s->SetStates(std::move(pair.second));
  96 + std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
  97 +
  98 + decoder_->Decode(std::move(pair.first), &results);
  99 + s->SetResult(results[0]);
  100 + }
  101 +
  102 + OnlineRecognizerResult GetResult(OnlineStream *s) {
  103 + OnlineTransducerDecoderResult decoder_result = s->GetResult();
  104 + decoder_->StripLeadingBlanks(&decoder_result);
  105 +
  106 + return Convert(decoder_result, sym_);
  107 + }
  108 +
  109 + private:
  110 + OnlineRecognizerConfig config_;
  111 + std::unique_ptr<OnlineTransducerModel> model_;
  112 + std::unique_ptr<OnlineTransducerDecoder> decoder_;
  113 + SymbolTable sym_;
  114 +};
  115 +
  116 +OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
  117 + : impl_(std::make_unique<Impl>(config)) {}
  118 +OnlineRecognizer::~OnlineRecognizer() = default;
  119 +
  120 +std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
  121 + return impl_->CreateStream();
  122 +}
  123 +
  124 +bool OnlineRecognizer::IsReady(OnlineStream *s) const {
  125 + return impl_->IsReady(s);
  126 +}
  127 +
  128 +void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
  129 + impl_->DecodeStreams(ss, n);
  130 +}
  131 +
  132 +OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
  133 + return impl_->GetResult(s);
  134 +}
  135 +
  136 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-recognizer.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
  7 +
  8 +#include <memory>
  9 +#include <string>
  10 +
  11 +#include "sherpa-onnx/csrc/features.h"
  12 +#include "sherpa-onnx/csrc/online-stream.h"
  13 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +struct OnlineRecognizerResult {
  18 + std::string text;
  19 +};
  20 +
  21 +struct OnlineRecognizerConfig {
  22 + FeatureExtractorConfig feat_config;
  23 + OnlineTransducerModelConfig model_config;
  24 + std::string tokens;
  25 +
  26 + std::string ToString() const;
  27 +};
  28 +
  29 +class OnlineRecognizer {
  30 + public:
  31 + explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
  32 + ~OnlineRecognizer();
  33 +
  34 + /// Create a stream for decoding.
  35 + std::unique_ptr<OnlineStream> CreateStream() const;
  36 +
  37 + /**
  38 + * Return true if the given stream has enough frames for decoding.
  39 + * Return false otherwise
  40 + */
  41 + bool IsReady(OnlineStream *s) const;
  42 +
  43 + /** Decode a single stream. */
  44 + void DecodeStream(OnlineStream *s) {
  45 + OnlineStream *ss[1] = {s};
  46 + DecodeStreams(ss, 1);
  47 + }
  48 +
  49 + /** Decode multiple streams in parallel
  50 + *
  51 + * @param ss Pointer array containing streams to be decoded.
  52 + * @param n Number of streams in `ss`.
  53 + */
  54 + void DecodeStreams(OnlineStream **ss, int32_t n);
  55 +
  56 + OnlineRecognizerResult GetResult(OnlineStream *s);
  57 +
  58 + private:
  59 + class Impl;
  60 + std::unique_ptr<Impl> impl_;
  61 +};
  62 +
  63 +} // namespace sherpa_onnx
  64 +
  65 +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 #include "sherpa-onnx/csrc/online-stream.h" 4 #include "sherpa-onnx/csrc/online-stream.h"
5 5
6 #include <memory> 6 #include <memory>
  7 +#include <utility>
7 #include <vector> 8 #include <vector>
8 9
9 #include "sherpa-onnx/csrc/features.h" 10 #include "sherpa-onnx/csrc/features.h"
@@ -41,10 +42,17 @@ class OnlineStream::Impl { @@ -41,10 +42,17 @@ class OnlineStream::Impl {
41 42
42 int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } 43 int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
43 44
  45 + void SetStates(std::vector<Ort::Value> states) {
  46 + states_ = std::move(states);
  47 + }
  48 +
  49 + std::vector<Ort::Value> &GetStates() { return states_; }
  50 +
44 private: 51 private:
45 FeatureExtractor feat_extractor_; 52 FeatureExtractor feat_extractor_;
46 int32_t num_processed_frames_ = 0; // before subsampling 53 int32_t num_processed_frames_ = 0; // before subsampling
47 OnlineTransducerDecoderResult result_; 54 OnlineTransducerDecoderResult result_;
  55 + std::vector<Ort::Value> states_;
48 }; 56 };
49 57
50 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) 58 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
@@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { @@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
86 return impl_->GetResult(); 94 return impl_->GetResult();
87 } 95 }
88 96
  97 +void OnlineStream::SetStates(std::vector<Ort::Value> states) {
  98 + impl_->SetStates(std::move(states));
  99 +}
  100 +
  101 +std::vector<Ort::Value> &OnlineStream::GetStates() {
  102 + return impl_->GetStates();
  103 +}
  104 +
89 } // namespace sherpa_onnx 105 } // namespace sherpa_onnx
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <memory> 8 #include <memory>
9 #include <vector> 9 #include <vector>
10 10
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
11 #include "sherpa-onnx/csrc/features.h" 12 #include "sherpa-onnx/csrc/features.h"
12 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 13 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
13 14
@@ -63,6 +64,9 @@ class OnlineStream { @@ -63,6 +64,9 @@ class OnlineStream {
63 void SetResult(const OnlineTransducerDecoderResult &r); 64 void SetResult(const OnlineTransducerDecoderResult &r);
64 const OnlineTransducerDecoderResult &GetResult() const; 65 const OnlineTransducerDecoderResult &GetResult() const;
65 66
  67 + void SetStates(std::vector<Ort::Value> states);
  68 + std::vector<Ort::Value> &GetStates();
  69 +
66 private: 70 private:
67 class Impl; 71 class Impl;
68 std::unique_ptr<Impl> impl_; 72 std::unique_ptr<Impl> impl_;
@@ -26,13 +26,14 @@ class OnlineTransducerDecoder { @@ -26,13 +26,14 @@ class OnlineTransducerDecoder {
26 * to the beginning of the decoding result, which will be 26 * to the beginning of the decoding result, which will be
27 * stripped by calling `StripLeadingBlanks()`. 27 * stripped by calling `StripLeadingBlanks()`.
28 */ 28 */
29 - virtual OnlineTransducerDecoderResult GetEmptyResult() = 0; 29 + virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0;
30 30
31 /** Strip blanks added by `GetEmptyResult()`. 31 /** Strip blanks added by `GetEmptyResult()`.
32 * 32 *
33 * @param r It is changed in-place. 33 * @param r It is changed in-place.
34 */ 34 */
35 - virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {} 35 + virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {
  36 + }
36 37
37 /** Run transducer beam search given the output from the encoder model. 38 /** Run transducer beam search given the output from the encoder model.
38 * 39 *
@@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { @@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
33 } 33 }
34 34
35 OnlineTransducerDecoderResult 35 OnlineTransducerDecoderResult
36 -OnlineTransducerGreedySearchDecoder::GetEmptyResult() { 36 +OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
37 int32_t context_size = model_->ContextSize(); 37 int32_t context_size = model_->ContextSize();
38 int32_t blank_id = 0; // always 0 38 int32_t blank_id = 0; // always 0
39 OnlineTransducerDecoderResult r; 39 OnlineTransducerDecoderResult r;
@@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() { @@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
43 } 43 }
44 44
45 void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( 45 void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
46 - OnlineTransducerDecoderResult *r) { 46 + OnlineTransducerDecoderResult *r) const {
47 int32_t context_size = model_->ContextSize(); 47 int32_t context_size = model_->ContextSize();
48 48
49 auto start = r->tokens.begin() + context_size; 49 auto start = r->tokens.begin() + context_size;
@@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { @@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
17 explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) 17 explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
18 : model_(model) {} 18 : model_(model) {}
19 19
20 - OnlineTransducerDecoderResult GetEmptyResult() override; 20 + OnlineTransducerDecoderResult GetEmptyResult() const override;
21 21
22 - void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override; 22 + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;
23 23
24 void Decode(Ort::Value encoder_out, 24 void Decode(Ort::Value encoder_out,
25 std::vector<OnlineTransducerDecoderResult> *result) override; 25 std::vector<OnlineTransducerDecoderResult> *result) override;
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <string> 8 #include <string>
9 #include <vector> 9 #include <vector>
10 10
  11 +#include "sherpa-onnx/csrc/online-recognizer.h"
11 #include "sherpa-onnx/csrc/online-stream.h" 12 #include "sherpa-onnx/csrc/online-stream.h"
12 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" 13 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
13 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 14 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
@@ -35,35 +36,26 @@ for a list of pre-trained models to download. @@ -35,35 +36,26 @@ for a list of pre-trained models to download.
35 return 0; 36 return 0;
36 } 37 }
37 38
38 - std::string tokens = argv[1];  
39 - sherpa_onnx::OnlineTransducerModelConfig config;  
40 - config.debug = false;  
41 - config.encoder_filename = argv[2];  
42 - config.decoder_filename = argv[3];  
43 - config.joiner_filename = argv[4]; 39 + sherpa_onnx::OnlineRecognizerConfig config;
  40 +
  41 + config.tokens = argv[1];
  42 +
  43 + config.model_config.debug = false;
  44 + config.model_config.encoder_filename = argv[2];
  45 + config.model_config.decoder_filename = argv[3];
  46 + config.model_config.joiner_filename = argv[4];
  47 +
44 std::string wav_filename = argv[5]; 48 std::string wav_filename = argv[5];
45 49
46 - config.num_threads = 2; 50 + config.model_config.num_threads = 2;
47 if (argc == 7) { 51 if (argc == 7) {
48 - config.num_threads = atoi(argv[6]); 52 + config.model_config.num_threads = atoi(argv[6]);
49 } 53 }
50 fprintf(stderr, "%s\n", config.ToString().c_str()); 54 fprintf(stderr, "%s\n", config.ToString().c_str());
51 55
52 - auto model = sherpa_onnx::OnlineTransducerModel::Create(config);  
53 -  
54 - sherpa_onnx::SymbolTable sym(tokens);  
55 -  
56 - Ort::AllocatorWithDefaultOptions allocator;  
57 -  
58 - int32_t chunk_size = model->ChunkSize();  
59 - int32_t chunk_shift = model->ChunkShift(); 56 + sherpa_onnx::OnlineRecognizer recognizer(config);
60 57
61 - auto memory_info =  
62 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
63 -  
64 - std::vector<Ort::Value> states = model->GetEncoderInitStates();  
65 -  
66 - float expected_sampling_rate = 16000; 58 + float expected_sampling_rate = config.feat_config.sampling_rate;
67 59
68 bool is_ok = false; 60 bool is_ok = false;
69 std::vector<float> samples = 61 std::vector<float> samples =
@@ -82,44 +74,21 @@ for a list of pre-trained models to download. @@ -82,44 +74,21 @@ for a list of pre-trained models to download.
82 auto begin = std::chrono::steady_clock::now(); 74 auto begin = std::chrono::steady_clock::now();
83 fprintf(stderr, "Started\n"); 75 fprintf(stderr, "Started\n");
84 76
85 - sherpa_onnx::OnlineStream stream;  
86 - stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); 77 + auto s = recognizer.CreateStream();
  78 + s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
87 79
88 std::vector<float> tail_paddings( 80 std::vector<float> tail_paddings(
89 static_cast<int>(0.2 * expected_sampling_rate)); 81 static_cast<int>(0.2 * expected_sampling_rate));
90 - stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),  
91 - tail_paddings.size());  
92 - stream.InputFinished();  
93 -  
94 - int32_t num_frames = stream.NumFramesReady();  
95 - int32_t feature_dim = stream.FeatureDim();  
96 -  
97 - std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};  
98 -  
99 - sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());  
100 - std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {  
101 - decoder.GetEmptyResult()};  
102 - while (stream.NumFramesReady() - stream.GetNumProcessedFrames() >  
103 - chunk_size) {  
104 - std::vector<float> features =  
105 - stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size);  
106 - stream.GetNumProcessedFrames() += chunk_shift;  
107 -  
108 - Ort::Value x =  
109 - Ort::Value::CreateTensor(memory_info, features.data(), features.size(),  
110 - x_shape.data(), x_shape.size());  
111 -  
112 - auto pair = model->RunEncoder(std::move(x), states);  
113 - states = std::move(pair.second);  
114 - decoder.Decode(std::move(pair.first), &result);  
115 - }  
116 - decoder.StripLeadingBlanks(&result[0]);  
117 - const auto &hyp = result[0].tokens;  
118 - std::string text;  
119 - for (auto t : hyp) {  
120 - text += sym[t]; 82 + s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
  83 + tail_paddings.size());
  84 + s->InputFinished();
  85 +
  86 + while (recognizer.IsReady(s.get())) {
  87 + recognizer.DecodeStream(s.get());
121 } 88 }
122 89
  90 + std::string text = recognizer.GetResult(s.get()).text;
  91 +
123 fprintf(stderr, "Done!\n"); 92 fprintf(stderr, "Done!\n");
124 93
125 fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), 94 fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
@@ -131,7 +100,7 @@ for a list of pre-trained models to download. @@ -131,7 +100,7 @@ for a list of pre-trained models to download.
131 .count() / 100 .count() /
132 1000.; 101 1000.;
133 102
134 - fprintf(stderr, "num threads: %d\n", config.num_threads); 103 + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
135 104
136 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); 105 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
137 float rtf = elapsed_seconds / duration; 106 float rtf = elapsed_seconds / duration;