正在显示 5 个修改的文件,包含 191 行增加和 29 行删除
| @@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR}) | @@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR}) | ||
| 3 | add_executable(sherpa-onnx | 3 | add_executable(sherpa-onnx |
| 4 | features.cc | 4 | features.cc |
| 5 | online-lstm-transducer-model.cc | 5 | online-lstm-transducer-model.cc |
| 6 | + online-stream.cc | ||
| 6 | online-transducer-greedy-search-decoder.cc | 7 | online-transducer-greedy-search-decoder.cc |
| 7 | online-transducer-model-config.cc | 8 | online-transducer-model-config.cc |
| 8 | online-transducer-model.cc | 9 | online-transducer-model.cc |
| @@ -11,16 +11,12 @@ | @@ -11,16 +11,12 @@ | ||
| 11 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 12 | 12 | ||
| 13 | struct FeatureExtractorConfig { | 13 | struct FeatureExtractorConfig { |
| 14 | - int32_t sampling_rate = 16000; | 14 | + float sampling_rate = 16000; |
| 15 | int32_t feature_dim = 80; | 15 | int32_t feature_dim = 80; |
| 16 | }; | 16 | }; |
| 17 | 17 | ||
| 18 | class FeatureExtractor { | 18 | class FeatureExtractor { |
| 19 | public: | 19 | public: |
| 20 | - /** | ||
| 21 | - * @param sampling_rate Sampling rate of the data used to train the model. | ||
| 22 | - * @param feature_dim Dimension of the features used to train the model. | ||
| 23 | - */ | ||
| 24 | explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); | 20 | explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); |
| 25 | ~FeatureExtractor(); | 21 | ~FeatureExtractor(); |
| 26 | 22 | ||
| @@ -32,16 +28,19 @@ class FeatureExtractor { | @@ -32,16 +28,19 @@ class FeatureExtractor { | ||
| 32 | */ | 28 | */ |
| 33 | void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); | 29 | void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); |
| 34 | 30 | ||
| 35 | - // InputFinished() tells the class you won't be providing any | ||
| 36 | - // more waveform. This will help flush out the last frame or two | ||
| 37 | - // of features, in the case where snip-edges == false; it also | ||
| 38 | - // affects the return value of IsLastFrame(). | 31 | + /** |
| 32 | + * InputFinished() tells the class you won't be providing any | ||
| 33 | + * more waveform. This will help flush out the last frame or two | ||
| 34 | + * of features, in the case where snip-edges == false; it also | ||
| 35 | + * affects the return value of IsLastFrame(). | ||
| 36 | + */ | ||
| 39 | void InputFinished(); | 37 | void InputFinished(); |
| 40 | 38 | ||
| 41 | int32_t NumFramesReady() const; | 39 | int32_t NumFramesReady() const; |
| 42 | 40 | ||
| 43 | - // Note: IsLastFrame() will only ever return true if you have called | ||
| 44 | - // InputFinished() (and this frame is the last frame). | 41 | + /** Note: IsLastFrame() will only ever return true if you have called |
| 42 | + * InputFinished() (and this frame is the last frame). | ||
| 43 | + */ | ||
| 45 | bool IsLastFrame(int32_t frame) const; | 44 | bool IsLastFrame(int32_t frame) const; |
| 46 | 45 | ||
| 47 | /** Get n frames starting from the given frame index. | 46 | /** Get n frames starting from the given frame index. |
sherpa-onnx/csrc/online-stream.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-stream.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 5 | + | ||
| 6 | +#include <memory> | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/features.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +class OnlineStream::Impl { | ||
| 14 | + public: | ||
| 15 | + explicit Impl(const FeatureExtractorConfig &config) | ||
| 16 | + : feat_extractor_(config) {} | ||
| 17 | + | ||
| 18 | + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { | ||
| 19 | + feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); | ||
| 20 | + } | ||
| 21 | + | ||
| 22 | + void InputFinished() { feat_extractor_.InputFinished(); } | ||
| 23 | + | ||
| 24 | + int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); } | ||
| 25 | + | ||
| 26 | + bool IsLastFrame(int32_t frame) const { | ||
| 27 | + return feat_extractor_.IsLastFrame(frame); | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { | ||
| 31 | + return feat_extractor_.GetFrames(frame_index, n); | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + void Reset() { feat_extractor_.Reset(); } | ||
| 35 | + | ||
| 36 | + int32_t &GetNumProcessedFrames() { return num_processed_frames_; } | ||
| 37 | + | ||
| 38 | + void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } | ||
| 39 | + | ||
| 40 | + const OnlineTransducerDecoderResult &GetResult() const { return result_; } | ||
| 41 | + | ||
| 42 | + int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } | ||
| 43 | + | ||
| 44 | + private: | ||
| 45 | + FeatureExtractor feat_extractor_; | ||
| 46 | + int32_t num_processed_frames_ = 0; // before subsampling | ||
| 47 | + OnlineTransducerDecoderResult result_; | ||
| 48 | +}; | ||
| 49 | + | ||
| 50 | +OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) | ||
| 51 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 52 | + | ||
| 53 | +OnlineStream::~OnlineStream() = default; | ||
| 54 | + | ||
| 55 | +void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform, | ||
| 56 | + int32_t n) { | ||
| 57 | + impl_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +void OnlineStream::InputFinished() { impl_->InputFinished(); } | ||
| 61 | + | ||
| 62 | +int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } | ||
| 63 | + | ||
| 64 | +bool OnlineStream::IsLastFrame(int32_t frame) const { | ||
| 65 | + return impl_->IsLastFrame(frame); | ||
| 66 | +} | ||
| 67 | + | ||
| 68 | +std::vector<float> OnlineStream::GetFrames(int32_t frame_index, | ||
| 69 | + int32_t n) const { | ||
| 70 | + return impl_->GetFrames(frame_index, n); | ||
| 71 | +} | ||
| 72 | + | ||
| 73 | +void OnlineStream::Reset() { impl_->Reset(); } | ||
| 74 | + | ||
| 75 | +int32_t OnlineStream::FeatureDim() const { return impl_->FeatureDim(); } | ||
| 76 | + | ||
| 77 | +int32_t &OnlineStream::GetNumProcessedFrames() { | ||
| 78 | + return impl_->GetNumProcessedFrames(); | ||
| 79 | +} | ||
| 80 | + | ||
| 81 | +void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { | ||
| 82 | + impl_->SetResult(r); | ||
| 83 | +} | ||
| 84 | + | ||
| 85 | +const OnlineTransducerDecoderResult &OnlineStream::GetResult() const { | ||
| 86 | + return impl_->GetResult(); | ||
| 87 | +} | ||
| 88 | + | ||
| 89 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-stream.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-stream.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/features.h" | ||
| 12 | +#include "sherpa-onnx/csrc/online-transducer-decoder.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OnlineStream { | ||
| 17 | + public: | ||
| 18 | + explicit OnlineStream(const FeatureExtractorConfig &config = {}); | ||
| 19 | + ~OnlineStream(); | ||
| 20 | + | ||
| 21 | + /** | ||
| 22 | + @param sampling_rate The sampling_rate of the input waveform. Should match | ||
| 23 | + the one expected by the feature extractor. | ||
| 24 | + @param waveform Pointer to a 1-D array of size n | ||
| 25 | + @param n Number of entries in waveform | ||
| 26 | + */ | ||
| 27 | + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); | ||
| 28 | + | ||
| 29 | + /** | ||
| 30 | + * InputFinished() tells the class you won't be providing any | ||
| 31 | + * more waveform. This will help flush out the last frame or two | ||
| 32 | + * of features, in the case where snip-edges == false; it also | ||
| 33 | + * affects the return value of IsLastFrame(). | ||
| 34 | + */ | ||
| 35 | + void InputFinished(); | ||
| 36 | + | ||
| 37 | + int32_t NumFramesReady() const; | ||
| 38 | + | ||
| 39 | + /** Note: IsLastFrame() will only ever return true if you have called | ||
| 40 | + * InputFinished() (and this frame is the last frame). | ||
| 41 | + */ | ||
| 42 | + bool IsLastFrame(int32_t frame) const; | ||
| 43 | + | ||
| 44 | + /** Get n frames starting from the given frame index. | ||
| 45 | + * | ||
| 46 | + * @param frame_index The starting frame index | ||
| 47 | + * @param n Number of frames to get. | ||
| 48 | + * @return Return a 2-D tensor of shape (n, feature_dim). | ||
| 49 | + * which is flattened into a 1-D vector (flattened in in row major) | ||
| 50 | + */ | ||
| 51 | + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; | ||
| 52 | + | ||
| 53 | + void Reset(); | ||
| 54 | + | ||
| 55 | + int32_t FeatureDim() const; | ||
| 56 | + | ||
| 57 | + // Return a reference to the number of processed frames so far. | ||
| 58 | + // Initially, it is 0. It is always less than NumFramesReady(). | ||
| 59 | + // | ||
| 60 | + // The returned reference is valid as long as this object is alive. | ||
| 61 | + int32_t &GetNumProcessedFrames(); | ||
| 62 | + | ||
| 63 | + void SetResult(const OnlineTransducerDecoderResult &r); | ||
| 64 | + const OnlineTransducerDecoderResult &GetResult() const; | ||
| 65 | + | ||
| 66 | + private: | ||
| 67 | + class Impl; | ||
| 68 | + std::unique_ptr<Impl> impl_; | ||
| 69 | +}; | ||
| 70 | + | ||
| 71 | +} // namespace sherpa_onnx | ||
| 72 | + | ||
| 73 | +#endif // SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ |
| @@ -8,8 +8,7 @@ | @@ -8,8 +8,7 @@ | ||
| 8 | #include <string> | 8 | #include <string> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | -#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 12 | -#include "sherpa-onnx/csrc/features.h" | 11 | +#include "sherpa-onnx/csrc/online-stream.h" |
| 13 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" | 12 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" |
| 14 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 13 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 15 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 14 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| @@ -64,7 +63,7 @@ for a list of pre-trained models to download. | @@ -64,7 +63,7 @@ for a list of pre-trained models to download. | ||
| 64 | 63 | ||
| 65 | std::vector<Ort::Value> states = model->GetEncoderInitStates(); | 64 | std::vector<Ort::Value> states = model->GetEncoderInitStates(); |
| 66 | 65 | ||
| 67 | - int32_t expected_sampling_rate = 16000; | 66 | + float expected_sampling_rate = 16000; |
| 68 | 67 | ||
| 69 | bool is_ok = false; | 68 | bool is_ok = false; |
| 70 | std::vector<float> samples = | 69 | std::vector<float> samples = |
| @@ -75,7 +74,7 @@ for a list of pre-trained models to download. | @@ -75,7 +74,7 @@ for a list of pre-trained models to download. | ||
| 75 | return -1; | 74 | return -1; |
| 76 | } | 75 | } |
| 77 | 76 | ||
| 78 | - float duration = samples.size() / static_cast<float>(expected_sampling_rate); | 77 | + float duration = samples.size() / expected_sampling_rate; |
| 79 | 78 | ||
| 80 | fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); | 79 | fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); |
| 81 | fprintf(stderr, "wav duration (s): %.3f\n", duration); | 80 | fprintf(stderr, "wav duration (s): %.3f\n", duration); |
| @@ -83,32 +82,33 @@ for a list of pre-trained models to download. | @@ -83,32 +82,33 @@ for a list of pre-trained models to download. | ||
| 83 | auto begin = std::chrono::steady_clock::now(); | 82 | auto begin = std::chrono::steady_clock::now(); |
| 84 | fprintf(stderr, "Started\n"); | 83 | fprintf(stderr, "Started\n"); |
| 85 | 84 | ||
| 86 | - sherpa_onnx::FeatureExtractor feat_extractor; | ||
| 87 | - feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), | ||
| 88 | - samples.size()); | 85 | + sherpa_onnx::OnlineStream stream; |
| 86 | + stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); | ||
| 89 | 87 | ||
| 90 | std::vector<float> tail_paddings( | 88 | std::vector<float> tail_paddings( |
| 91 | static_cast<int>(0.2 * expected_sampling_rate)); | 89 | static_cast<int>(0.2 * expected_sampling_rate)); |
| 92 | - feat_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), | ||
| 93 | - tail_paddings.size()); | ||
| 94 | - feat_extractor.InputFinished(); | 90 | + stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(), |
| 91 | + tail_paddings.size()); | ||
| 92 | + stream.InputFinished(); | ||
| 95 | 93 | ||
| 96 | - int32_t num_frames = feat_extractor.NumFramesReady(); | ||
| 97 | - int32_t feature_dim = feat_extractor.FeatureDim(); | 94 | + int32_t num_frames = stream.NumFramesReady(); |
| 95 | + int32_t feature_dim = stream.FeatureDim(); | ||
| 98 | 96 | ||
| 99 | std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; | 97 | std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; |
| 100 | 98 | ||
| 101 | sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); | 99 | sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); |
| 102 | std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = { | 100 | std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = { |
| 103 | decoder.GetEmptyResult()}; | 101 | decoder.GetEmptyResult()}; |
| 104 | - | ||
| 105 | - for (int32_t start = 0; start + chunk_size < num_frames; | ||
| 106 | - start += chunk_shift) { | ||
| 107 | - std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); | 102 | + while (stream.NumFramesReady() - stream.GetNumProcessedFrames() > |
| 103 | + chunk_size) { | ||
| 104 | + std::vector<float> features = | ||
| 105 | + stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size); | ||
| 106 | + stream.GetNumProcessedFrames() += chunk_shift; | ||
| 108 | 107 | ||
| 109 | Ort::Value x = | 108 | Ort::Value x = |
| 110 | Ort::Value::CreateTensor(memory_info, features.data(), features.size(), | 109 | Ort::Value::CreateTensor(memory_info, features.data(), features.size(), |
| 111 | x_shape.data(), x_shape.size()); | 110 | x_shape.data(), x_shape.size()); |
| 111 | + | ||
| 112 | auto pair = model->RunEncoder(std::move(x), states); | 112 | auto pair = model->RunEncoder(std::move(x), states); |
| 113 | states = std::move(pair.second); | 113 | states = std::move(pair.second); |
| 114 | decoder.Decode(std::move(pair.first), &result); | 114 | decoder.Decode(std::move(pair.first), &result); |
| @@ -116,8 +116,8 @@ for a list of pre-trained models to download. | @@ -116,8 +116,8 @@ for a list of pre-trained models to download. | ||
| 116 | decoder.StripLeadingBlanks(&result[0]); | 116 | decoder.StripLeadingBlanks(&result[0]); |
| 117 | const auto &hyp = result[0].tokens; | 117 | const auto &hyp = result[0].tokens; |
| 118 | std::string text; | 118 | std::string text; |
| 119 | - for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { | ||
| 120 | - text += sym[hyp[i]]; | 119 | + for (auto t : hyp) { |
| 120 | + text += sym[t]; | ||
| 121 | } | 121 | } |
| 122 | 122 | ||
| 123 | fprintf(stderr, "Done!\n"); | 123 | fprintf(stderr, "Done!\n"); |
-
请注册或登录后发表评论