Showing 11 changed files with 268 additions and 62 deletions.
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
 add_executable(sherpa-onnx
   features.cc
   online-lstm-transducer-model.cc
+  online-recognizer.cc
   online-stream.cc
   online-transducer-greedy-search-decoder.cc
   online-transducer-model-config.cc
@@ -7,12 +7,23 @@
 #include <algorithm>
 #include <memory>
 #include <mutex>  // NOLINT
+#include <sstream>
 #include <vector>

 #include "kaldi-native-fbank/csrc/online-feature.h"

 namespace sherpa_onnx {

+std::string FeatureExtractorConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "FeatureExtractorConfig(";
+  os << "sampling_rate=" << sampling_rate << ", ";
+  os << "feature_dim=" << feature_dim << ")";
+
+  return os.str();
+}
+
 class FeatureExtractor::Impl {
  public:
   explicit Impl(const FeatureExtractorConfig &config) {
@@ -6,6 +6,7 @@
 #define SHERPA_ONNX_CSRC_FEATURES_H_

 #include <memory>
+#include <string>
 #include <vector>

 namespace sherpa_onnx {
@@ -13,6 +14,8 @@ namespace sherpa_onnx {
 struct FeatureExtractorConfig {
   float sampling_rate = 16000;
   int32_t feature_dim = 80;
+
+  std::string ToString() const;
 };

 class FeatureExtractor {
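As an aside (this example is not part of the change set), the new ToString() exists only for logging; with the default values shown above, a minimal sketch would print FeatureExtractorConfig(sampling_rate=16000, feature_dim=80):

#include <cstdio>

#include "sherpa-onnx/csrc/features.h"

int main() {
  sherpa_onnx::FeatureExtractorConfig config;
  // Prints the defaults: FeatureExtractorConfig(sampling_rate=16000, feature_dim=80)
  fprintf(stderr, "%s\n", config.ToString().c_str());
  return 0;
}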
sherpa-onnx/csrc/online-recognizer.cc
0 → 100644
+// sherpa-onnx/csrc/online-recognizer.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-recognizer.h"
+
+#include <assert.h>
+
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+#include "sherpa-onnx/csrc/symbol-table.h"
+
+namespace sherpa_onnx {
+
+static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
+                                      const SymbolTable &sym_table) {
+  std::string text;
+  for (auto t : src.tokens) {
+    text += sym_table[t];
+  }
+
+  OnlineRecognizerResult ans;
+  ans.text = std::move(text);
+  return ans;
+}
+
+std::string OnlineRecognizerConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OnlineRecognizerConfig(";
+  os << "feat_config=" << feat_config.ToString() << ", ";
+  os << "model_config=" << model_config.ToString() << ", ";
+  os << "tokens=\"" << tokens << "\")";
+
+  return os.str();
+}
+
+class OnlineRecognizer::Impl {
+ public:
+  explicit Impl(const OnlineRecognizerConfig &config)
+      : config_(config),
+        model_(OnlineTransducerModel::Create(config.model_config)),
+        sym_(config.tokens) {
+    decoder_ =
+        std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
+  }
+
+  std::unique_ptr<OnlineStream> CreateStream() const {
+    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
+    stream->SetResult(decoder_->GetEmptyResult());
+    stream->SetStates(model_->GetEncoderInitStates());
+    return stream;
+  }
+
+  bool IsReady(OnlineStream *s) const {
+    return s->GetNumProcessedFrames() + model_->ChunkSize() <
+           s->NumFramesReady();
+  }
+
+  void DecodeStreams(OnlineStream **ss, int32_t n) {
+    if (n != 1) {
+      fprintf(stderr, "only n == 1 is implemented\n");
+      exit(-1);
+    }
+    OnlineStream *s = ss[0];
+    assert(IsReady(s));
+
+    int32_t chunk_size = model_->ChunkSize();
+    int32_t chunk_shift = model_->ChunkShift();
+
+    int32_t feature_dim = s->FeatureDim();
+
+    std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::vector<float> features =
+        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
+
+    s->GetNumProcessedFrames() += chunk_shift;
+
+    Ort::Value x =
+        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
+                                 x_shape.data(), x_shape.size());
+
+    auto pair = model_->RunEncoder(std::move(x), s->GetStates());
+
+    s->SetStates(std::move(pair.second));
+    std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
+
+    decoder_->Decode(std::move(pair.first), &results);
+    s->SetResult(results[0]);
+  }
+
+  OnlineRecognizerResult GetResult(OnlineStream *s) {
+    OnlineTransducerDecoderResult decoder_result = s->GetResult();
+    decoder_->StripLeadingBlanks(&decoder_result);
+
+    return Convert(decoder_result, sym_);
+  }
+
+ private:
+  OnlineRecognizerConfig config_;
+  std::unique_ptr<OnlineTransducerModel> model_;
+  std::unique_ptr<OnlineTransducerDecoder> decoder_;
+  SymbolTable sym_;
+};
+
+OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+OnlineRecognizer::~OnlineRecognizer() = default;
+
+std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
+  return impl_->CreateStream();
+}
+
+bool OnlineRecognizer::IsReady(OnlineStream *s) const {
+  return impl_->IsReady(s);
+}
+
+void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
+  impl_->DecodeStreams(ss, n);
+}
+
+OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
+  return impl_->GetResult(s);
+}
+
+}  // namespace sherpa_onnx
sherpa-onnx/csrc/online-recognizer.h
0 → 100644
+// sherpa-onnx/csrc/online-recognizer.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
+#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
+
+#include <memory>
+#include <string>
+
+#include "sherpa-onnx/csrc/features.h"
+#include "sherpa-onnx/csrc/online-stream.h"
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+
+namespace sherpa_onnx {
+
+struct OnlineRecognizerResult {
+  std::string text;
+};
+
+struct OnlineRecognizerConfig {
+  FeatureExtractorConfig feat_config;
+  OnlineTransducerModelConfig model_config;
+  std::string tokens;
+
+  std::string ToString() const;
+};
+
+class OnlineRecognizer {
+ public:
+  explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
+  ~OnlineRecognizer();
+
+  /// Create a stream for decoding.
+  std::unique_ptr<OnlineStream> CreateStream() const;
+
+  /**
+   * Return true if the given stream has enough frames for decoding.
+   * Return false otherwise.
+   */
+  bool IsReady(OnlineStream *s) const;
+
+  /** Decode a single stream. */
+  void DecodeStream(OnlineStream *s) {
+    OnlineStream *ss[1] = {s};
+    DecodeStreams(ss, 1);
+  }
+
+  /** Decode multiple streams in parallel.
+   *
+   * @param ss Pointer array containing streams to be decoded.
+   * @param n  Number of streams in `ss`.
+   */
+  void DecodeStreams(OnlineStream **ss, int32_t n);
+
+  OnlineRecognizerResult GetResult(OnlineStream *s);
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
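For orientation (this example is not part of the change set), the header above is used roughly as follows; the file paths are placeholders and the silent buffer merely stands in for real audio, so treat this as a sketch that mirrors the rewritten sherpa-onnx.cc further below:

#include <cstdio>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.tokens = "tokens.txt";                           // placeholder path
  config.model_config.encoder_filename = "encoder.onnx";  // placeholder path
  config.model_config.decoder_filename = "decoder.onnx";  // placeholder path
  config.model_config.joiner_filename = "joiner.onnx";    // placeholder path
  config.model_config.num_threads = 2;

  sherpa_onnx::OnlineRecognizer recognizer(config);
  auto s = recognizer.CreateStream();

  // One second of silence stands in for real 16 kHz mono samples in [-1, 1].
  std::vector<float> samples(16000, 0.0f);
  s->AcceptWaveform(config.feat_config.sampling_rate, samples.data(),
                    samples.size());
  s->InputFinished();

  // Decode chunk by chunk while a full chunk of feature frames is available.
  while (recognizer.IsReady(s.get())) {
    recognizer.DecodeStream(s.get());
  }

  fprintf(stderr, "%s\n", recognizer.GetResult(s.get()).text.c_str());
  return 0;
}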
@@ -4,6 +4,7 @@
 #include "sherpa-onnx/csrc/online-stream.h"

 #include <memory>
+#include <utility>
 #include <vector>

 #include "sherpa-onnx/csrc/features.h"
@@ -41,10 +42,17 @@ class OnlineStream::Impl {

   int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }

+  void SetStates(std::vector<Ort::Value> states) {
+    states_ = std::move(states);
+  }
+
+  std::vector<Ort::Value> &GetStates() { return states_; }
+
  private:
   FeatureExtractor feat_extractor_;
   int32_t num_processed_frames_ = 0;  // before subsampling
   OnlineTransducerDecoderResult result_;
+  std::vector<Ort::Value> states_;
 };

 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
@@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
   return impl_->GetResult();
 }

+void OnlineStream::SetStates(std::vector<Ort::Value> states) {
+  impl_->SetStates(std::move(states));
+}
+
+std::vector<Ort::Value> &OnlineStream::GetStates() {
+  return impl_->GetStates();
+}
+
 }  // namespace sherpa_onnx
@@ -8,6 +8,7 @@
 #include <memory>
 #include <vector>

+#include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/features.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"

@@ -63,6 +64,9 @@ class OnlineStream {
   void SetResult(const OnlineTransducerDecoderResult &r);
   const OnlineTransducerDecoderResult &GetResult() const;

+  void SetStates(std::vector<Ort::Value> states);
+  std::vector<Ort::Value> &GetStates();
+
  private:
   class Impl;
   std::unique_ptr<Impl> impl_;
@@ -26,13 +26,14 @@ class OnlineTransducerDecoder {
    * to the beginning of the decoding result, which will be
    * stripped by calling `StripPrecedingBlanks()`.
    */
-  virtual OnlineTransducerDecoderResult GetEmptyResult() = 0;
+  virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0;

   /** Strip blanks added by `GetEmptyResult()`.
    *
    * @param r It is changed in-place.
    */
-  virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {}
+  virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {
+  }

   /** Run transducer beam search given the output from the encoder model.
    *
@@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
 }

 OnlineTransducerDecoderResult
-OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
+OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
   int32_t context_size = model_->ContextSize();
   int32_t blank_id = 0;  // always 0
   OnlineTransducerDecoderResult r;
@@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
 }

 void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
-    OnlineTransducerDecoderResult *r) {
+    OnlineTransducerDecoderResult *r) const {
   int32_t context_size = model_->ContextSize();

   auto start = r->tokens.begin() + context_size;
@@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
   explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
       : model_(model) {}

-  OnlineTransducerDecoderResult GetEmptyResult() override;
+  OnlineTransducerDecoderResult GetEmptyResult() const override;

-  void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override;
+  void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;

   void Decode(Ort::Value encoder_out,
               std::vector<OnlineTransducerDecoderResult> *result) override;
@@ -8,6 +8,7 @@
 #include <string>
 #include <vector>

+#include "sherpa-onnx/csrc/online-recognizer.h"
 #include "sherpa-onnx/csrc/online-stream.h"
 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
@@ -35,35 +36,26 @@ for a list of pre-trained models to download.
     return 0;
   }

-  std::string tokens = argv[1];
-  sherpa_onnx::OnlineTransducerModelConfig config;
-  config.debug = false;
-  config.encoder_filename = argv[2];
-  config.decoder_filename = argv[3];
-  config.joiner_filename = argv[4];
+  sherpa_onnx::OnlineRecognizerConfig config;
+
+  config.tokens = argv[1];
+
+  config.model_config.debug = false;
+  config.model_config.encoder_filename = argv[2];
+  config.model_config.decoder_filename = argv[3];
+  config.model_config.joiner_filename = argv[4];
+
   std::string wav_filename = argv[5];

-  config.num_threads = 2;
+  config.model_config.num_threads = 2;
   if (argc == 7) {
-    config.num_threads = atoi(argv[6]);
+    config.model_config.num_threads = atoi(argv[6]);
   }
   fprintf(stderr, "%s\n", config.ToString().c_str());

-  auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
-
-  sherpa_onnx::SymbolTable sym(tokens);
-
-  Ort::AllocatorWithDefaultOptions allocator;
-
-  int32_t chunk_size = model->ChunkSize();
-  int32_t chunk_shift = model->ChunkShift();
+  sherpa_onnx::OnlineRecognizer recognizer(config);

-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-
-  std::vector<Ort::Value> states = model->GetEncoderInitStates();
-
-  float expected_sampling_rate = 16000;
+  float expected_sampling_rate = config.feat_config.sampling_rate;

   bool is_ok = false;
   std::vector<float> samples =
@@ -82,44 +74,21 @@ for a list of pre-trained models to download.
   auto begin = std::chrono::steady_clock::now();
   fprintf(stderr, "Started\n");

-  sherpa_onnx::OnlineStream stream;
-  stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
+  auto s = recognizer.CreateStream();
+  s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());

   std::vector<float> tail_paddings(
       static_cast<int>(0.2 * expected_sampling_rate));
-  stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
-                        tail_paddings.size());
-  stream.InputFinished();
-
-  int32_t num_frames = stream.NumFramesReady();
-  int32_t feature_dim = stream.FeatureDim();
-
-  std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
-
-  sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
-  std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
-      decoder.GetEmptyResult()};
-  while (stream.NumFramesReady() - stream.GetNumProcessedFrames() >
-         chunk_size) {
-    std::vector<float> features =
-        stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size);
-    stream.GetNumProcessedFrames() += chunk_shift;
-
-    Ort::Value x =
-        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
-                                 x_shape.data(), x_shape.size());
-
-    auto pair = model->RunEncoder(std::move(x), states);
-    states = std::move(pair.second);
-    decoder.Decode(std::move(pair.first), &result);
-  }
-  decoder.StripLeadingBlanks(&result[0]);
-  const auto &hyp = result[0].tokens;
-  std::string text;
-  for (auto t : hyp) {
-    text += sym[t];
+  s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
+                    tail_paddings.size());
+  s->InputFinished();
+
+  while (recognizer.IsReady(s.get())) {
+    recognizer.DecodeStream(s.get());
   }

+  std::string text = recognizer.GetResult(s.get()).text;
+
   fprintf(stderr, "Done!\n");

   fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
@@ -131,7 +100,7 @@ for a list of pre-trained models to download.
           .count() /
       1000.;

-  fprintf(stderr, "num threads: %d\n", config.num_threads);
+  fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);

   fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
   float rtf = elapsed_seconds / duration;