正在显示
3 个修改的文件
包含
105 行增加
和
67 行删除
| @@ -6,74 +6,108 @@ | @@ -6,74 +6,108 @@ | ||
| 6 | 6 | ||
| 7 | #include <algorithm> | 7 | #include <algorithm> |
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <mutex> // NOLINT | ||
| 9 | #include <vector> | 10 | #include <vector> |
| 10 | 11 | ||
| 12 | +#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 13 | + | ||
| 11 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 12 | 15 | ||
| 13 | -FeatureExtractor::FeatureExtractor() { | ||
| 14 | - opts_.frame_opts.dither = 0; | ||
| 15 | - opts_.frame_opts.snip_edges = false; | ||
| 16 | - opts_.frame_opts.samp_freq = 16000; | 16 | +class FeatureExtractor::Impl { |
| 17 | + public: | ||
| 18 | + Impl(int32_t sampling_rate, int32_t feature_dim) { | ||
| 19 | + opts_.frame_opts.dither = 0; | ||
| 20 | + opts_.frame_opts.snip_edges = false; | ||
| 21 | + opts_.frame_opts.samp_freq = sampling_rate; | ||
| 17 | 22 | ||
| 18 | - // cache 100 seconds of feature frames, which is more than enough | ||
| 19 | - // for real needs | ||
| 20 | - opts_.frame_opts.max_feature_vectors = 100 * 100; | 23 | + // cache 100 seconds of feature frames, which is more than enough |
| 24 | + // for real needs | ||
| 25 | + opts_.frame_opts.max_feature_vectors = 100 * 100; | ||
| 21 | 26 | ||
| 22 | - opts_.mel_opts.num_bins = 80; // feature dim | 27 | + opts_.mel_opts.num_bins = feature_dim; |
| 23 | 28 | ||
| 24 | - fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | ||
| 25 | -} | 29 | + fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 30 | + } | ||
| 26 | 31 | ||
| 27 | -FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts) | ||
| 28 | - : opts_(opts) { | ||
| 29 | - fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | ||
| 30 | -} | 32 | + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { |
| 33 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 34 | + fbank_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + void InputFinished() { | ||
| 38 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 39 | + fbank_->InputFinished(); | ||
| 40 | + } | ||
| 41 | + | ||
| 42 | + int32_t NumFramesReady() const { | ||
| 43 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 44 | + return fbank_->NumFramesReady(); | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + bool IsLastFrame(int32_t frame) const { | ||
| 48 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 49 | + return fbank_->IsLastFrame(frame); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { | ||
| 53 | + if (frame_index + n > NumFramesReady()) { | ||
| 54 | + fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); | ||
| 55 | + exit(-1); | ||
| 56 | + } | ||
| 57 | + std::lock_guard<std::mutex> lock(mutex_); | ||
| 58 | + | ||
| 59 | + int32_t feature_dim = fbank_->Dim(); | ||
| 60 | + std::vector<float> features(feature_dim * n); | ||
| 61 | + | ||
| 62 | + float *p = features.data(); | ||
| 63 | + | ||
| 64 | + for (int32_t i = 0; i != n; ++i) { | ||
| 65 | + const float *f = fbank_->GetFrame(i + frame_index); | ||
| 66 | + std::copy(f, f + feature_dim, p); | ||
| 67 | + p += feature_dim; | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + return features; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + void Reset() { fbank_ = std::make_unique<knf::OnlineFbank>(opts_); } | ||
| 74 | + | ||
| 75 | + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | ||
| 76 | + | ||
| 77 | + private: | ||
| 78 | + std::unique_ptr<knf::OnlineFbank> fbank_; | ||
| 79 | + knf::FbankOptions opts_; | ||
| 80 | + mutable std::mutex mutex_; | ||
| 81 | +}; | ||
| 82 | + | ||
| 83 | +FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/, | ||
| 84 | + int32_t feature_dim /*=80*/) | ||
| 85 | + : impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {} | ||
| 86 | + | ||
| 87 | +FeatureExtractor::~FeatureExtractor() = default; | ||
| 31 | 88 | ||
| 32 | void FeatureExtractor::AcceptWaveform(float sampling_rate, | 89 | void FeatureExtractor::AcceptWaveform(float sampling_rate, |
| 33 | const float *waveform, int32_t n) { | 90 | const float *waveform, int32_t n) { |
| 34 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 35 | - fbank_->AcceptWaveform(sampling_rate, waveform, n); | 91 | + impl_->AcceptWaveform(sampling_rate, waveform, n); |
| 36 | } | 92 | } |
| 37 | 93 | ||
| 38 | -void FeatureExtractor::InputFinished() { | ||
| 39 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 40 | - fbank_->InputFinished(); | ||
| 41 | -} | 94 | +void FeatureExtractor::InputFinished() { impl_->InputFinished(); } |
| 42 | 95 | ||
| 43 | int32_t FeatureExtractor::NumFramesReady() const { | 96 | int32_t FeatureExtractor::NumFramesReady() const { |
| 44 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 45 | - return fbank_->NumFramesReady(); | 97 | + return impl_->NumFramesReady(); |
| 46 | } | 98 | } |
| 47 | 99 | ||
| 48 | bool FeatureExtractor::IsLastFrame(int32_t frame) const { | 100 | bool FeatureExtractor::IsLastFrame(int32_t frame) const { |
| 49 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 50 | - return fbank_->IsLastFrame(frame); | 101 | + return impl_->IsLastFrame(frame); |
| 51 | } | 102 | } |
| 52 | 103 | ||
| 53 | std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | 104 | std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, |
| 54 | int32_t n) const { | 105 | int32_t n) const { |
| 55 | - if (frame_index + n > NumFramesReady()) { | ||
| 56 | - fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); | ||
| 57 | - exit(-1); | ||
| 58 | - } | ||
| 59 | - std::lock_guard<std::mutex> lock(mutex_); | ||
| 60 | - | ||
| 61 | - int32_t feature_dim = fbank_->Dim(); | ||
| 62 | - std::vector<float> features(feature_dim * n); | ||
| 63 | - | ||
| 64 | - float *p = features.data(); | ||
| 65 | - | ||
| 66 | - for (int32_t i = 0; i != n; ++i) { | ||
| 67 | - const float *f = fbank_->GetFrame(i + frame_index); | ||
| 68 | - std::copy(f, f + feature_dim, p); | ||
| 69 | - p += feature_dim; | ||
| 70 | - } | ||
| 71 | - | ||
| 72 | - return features; | 106 | + return impl_->GetFrames(frame_index, n); |
| 73 | } | 107 | } |
| 74 | 108 | ||
| 75 | -void FeatureExtractor::Reset() { | ||
| 76 | - fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | ||
| 77 | -} | 109 | +void FeatureExtractor::Reset() { impl_->Reset(); } |
| 110 | + | ||
| 111 | +int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } | ||
| 78 | 112 | ||
| 79 | } // namespace sherpa_onnx | 113 | } // namespace sherpa_onnx |
| @@ -6,17 +6,19 @@ | @@ -6,17 +6,19 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_FEATURES_H_ | 6 | #define SHERPA_ONNX_CSRC_FEATURES_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | -#include <mutex> // NOLINT | ||
| 10 | #include <vector> | 9 | #include <vector> |
| 11 | 10 | ||
| 12 | -#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 13 | - | ||
| 14 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 15 | 12 | ||
| 16 | class FeatureExtractor { | 13 | class FeatureExtractor { |
| 17 | public: | 14 | public: |
| 18 | - FeatureExtractor(); | ||
| 19 | - explicit FeatureExtractor(const knf::FbankOptions &fbank_opts); | 15 | + /** |
| 16 | + * @param sampling_rate Sampling rate of the data used to train the model. | ||
| 17 | + * @param feature_dim Dimension of the features used to train the model. | ||
| 18 | + */ | ||
| 19 | + explicit FeatureExtractor(int32_t sampling_rate = 16000, | ||
| 20 | + int32_t feature_dim = 80); | ||
| 21 | + ~FeatureExtractor(); | ||
| 20 | 22 | ||
| 21 | /** | 23 | /** |
| 22 | @param sampling_rate The sampling_rate of the input waveform. Should match | 24 | @param sampling_rate The sampling_rate of the input waveform. Should match |
| @@ -48,12 +50,13 @@ class FeatureExtractor { | @@ -48,12 +50,13 @@ class FeatureExtractor { | ||
| 48 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; | 50 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; |
| 49 | 51 | ||
| 50 | void Reset(); | 52 | void Reset(); |
| 51 | - int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | 53 | + |
| 54 | + /// Return feature dim of this extractor | ||
| 55 | + int32_t FeatureDim() const; | ||
| 52 | 56 | ||
| 53 | private: | 57 | private: |
| 54 | - std::unique_ptr<knf::OnlineFbank> fbank_; | ||
| 55 | - knf::FbankOptions opts_; | ||
| 56 | - mutable std::mutex mutex_; | 58 | + class Impl; |
| 59 | + std::unique_ptr<Impl> impl_; | ||
| 57 | }; | 60 | }; |
| 58 | 61 | ||
| 59 | } // namespace sherpa_onnx | 62 | } // namespace sherpa_onnx |
| @@ -2,8 +2,9 @@ | @@ -2,8 +2,9 @@ | ||
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2022-2023 Xiaomi Corporation | 3 | // Copyright (c) 2022-2023 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | +#include <stdio.h> | ||
| 6 | + | ||
| 5 | #include <chrono> // NOLINT | 7 | #include <chrono> // NOLINT |
| 6 | -#include <iostream> | ||
| 7 | #include <string> | 8 | #include <string> |
| 8 | #include <vector> | 9 | #include <vector> |
| 9 | 10 | ||
| @@ -30,14 +31,14 @@ Please refer to | @@ -30,14 +31,14 @@ Please refer to | ||
| 30 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 31 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| 31 | for a list of pre-trained models to download. | 32 | for a list of pre-trained models to download. |
| 32 | )usage"; | 33 | )usage"; |
| 33 | - std::cerr << usage << "\n"; | 34 | + fprintf(stderr, "%s\n", usage); |
| 34 | 35 | ||
| 35 | return 0; | 36 | return 0; |
| 36 | } | 37 | } |
| 37 | 38 | ||
| 38 | std::string tokens = argv[1]; | 39 | std::string tokens = argv[1]; |
| 39 | sherpa_onnx::OnlineTransducerModelConfig config; | 40 | sherpa_onnx::OnlineTransducerModelConfig config; |
| 40 | - config.debug = true; | 41 | + config.debug = false; |
| 41 | config.encoder_filename = argv[2]; | 42 | config.encoder_filename = argv[2]; |
| 42 | config.decoder_filename = argv[3]; | 43 | config.decoder_filename = argv[3]; |
| 43 | config.joiner_filename = argv[4]; | 44 | config.joiner_filename = argv[4]; |
| @@ -47,7 +48,7 @@ for a list of pre-trained models to download. | @@ -47,7 +48,7 @@ for a list of pre-trained models to download. | ||
| 47 | if (argc == 7) { | 48 | if (argc == 7) { |
| 48 | config.num_threads = atoi(argv[6]); | 49 | config.num_threads = atoi(argv[6]); |
| 49 | } | 50 | } |
| 50 | - std::cout << config.ToString().c_str() << "\n"; | 51 | + fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 51 | 52 | ||
| 52 | auto model = sherpa_onnx::OnlineTransducerModel::Create(config); | 53 | auto model = sherpa_onnx::OnlineTransducerModel::Create(config); |
| 53 | 54 | ||
| @@ -72,17 +73,17 @@ for a list of pre-trained models to download. | @@ -72,17 +73,17 @@ for a list of pre-trained models to download. | ||
| 72 | sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); | 73 | sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); |
| 73 | 74 | ||
| 74 | if (!is_ok) { | 75 | if (!is_ok) { |
| 75 | - std::cerr << "Failed to read " << wav_filename << "\n"; | 76 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); |
| 76 | return -1; | 77 | return -1; |
| 77 | } | 78 | } |
| 78 | 79 | ||
| 79 | - const float duration = samples.size() / expected_sampling_rate; | 80 | + float duration = samples.size() / static_cast<float>(expected_sampling_rate); |
| 80 | 81 | ||
| 81 | - std::cout << "wav filename: " << wav_filename << "\n"; | ||
| 82 | - std::cout << "wav duration (s): " << duration << "\n"; | 82 | + fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); |
| 83 | + fprintf(stderr, "wav duration (s): %.3f\n", duration); | ||
| 83 | 84 | ||
| 84 | auto begin = std::chrono::steady_clock::now(); | 85 | auto begin = std::chrono::steady_clock::now(); |
| 85 | - std::cout << "Started!\n"; | 86 | + fprintf(stderr, "Started\n"); |
| 86 | 87 | ||
| 87 | sherpa_onnx::FeatureExtractor feat_extractor; | 88 | sherpa_onnx::FeatureExtractor feat_extractor; |
| 88 | feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), | 89 | feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), |
| @@ -115,10 +116,10 @@ for a list of pre-trained models to download. | @@ -115,10 +116,10 @@ for a list of pre-trained models to download. | ||
| 115 | text += sym[hyp[i]]; | 116 | text += sym[hyp[i]]; |
| 116 | } | 117 | } |
| 117 | 118 | ||
| 118 | - std::cout << "Done!\n"; | 119 | + fprintf(stderr, "Done!\n"); |
| 119 | 120 | ||
| 120 | - std::cout << "Recognition result for " << wav_filename << "\n" | ||
| 121 | - << text << "\n"; | 121 | + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), |
| 122 | + text.c_str()); | ||
| 122 | 123 | ||
| 123 | auto end = std::chrono::steady_clock::now(); | 124 | auto end = std::chrono::steady_clock::now(); |
| 124 | float elapsed_seconds = | 125 | float elapsed_seconds = |
| @@ -126,7 +127,7 @@ for a list of pre-trained models to download. | @@ -126,7 +127,7 @@ for a list of pre-trained models to download. | ||
| 126 | .count() / | 127 | .count() / |
| 127 | 1000.; | 128 | 1000.; |
| 128 | 129 | ||
| 129 | - std::cout << "num threads: " << config.num_threads << "\n"; | 130 | + fprintf(stderr, "num threads: %d\n", config.num_threads); |
| 130 | 131 | ||
| 131 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | 132 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); |
| 132 | float rtf = elapsed_seconds / duration; | 133 | float rtf = elapsed_seconds / duration; |
-
请 注册 或 登录 后发表评论