正在显示
3 个修改的文件
包含
82 行增加
和
44 行删除
| @@ -6,52 +6,50 @@ | @@ -6,52 +6,50 @@ | ||
| 6 | 6 | ||
| 7 | #include <algorithm> | 7 | #include <algorithm> |
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | +#include <mutex> // NOLINT | ||
| 9 | #include <vector> | 10 | #include <vector> |
| 10 | 11 | ||
| 12 | +#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 13 | + | ||
| 11 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 12 | 15 | ||
| 13 | -FeatureExtractor::FeatureExtractor() { | 16 | +class FeatureExtractor::Impl { |
| 17 | + public: | ||
| 18 | + Impl(int32_t sampling_rate, int32_t feature_dim) { | ||
| 14 | opts_.frame_opts.dither = 0; | 19 | opts_.frame_opts.dither = 0; |
| 15 | opts_.frame_opts.snip_edges = false; | 20 | opts_.frame_opts.snip_edges = false; |
| 16 | - opts_.frame_opts.samp_freq = 16000; | 21 | + opts_.frame_opts.samp_freq = sampling_rate; |
| 17 | 22 | ||
| 18 | // cache 100 seconds of feature frames, which is more than enough | 23 | // cache 100 seconds of feature frames, which is more than enough |
| 19 | // for real needs | 24 | // for real needs |
| 20 | opts_.frame_opts.max_feature_vectors = 100 * 100; | 25 | opts_.frame_opts.max_feature_vectors = 100 * 100; |
| 21 | 26 | ||
| 22 | - opts_.mel_opts.num_bins = 80; // feature dim | ||
| 23 | - | ||
| 24 | - fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | ||
| 25 | -} | 27 | + opts_.mel_opts.num_bins = feature_dim; |
| 26 | 28 | ||
| 27 | -FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts) | ||
| 28 | - : opts_(opts) { | ||
| 29 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 29 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 30 | -} | 30 | + } |
| 31 | 31 | ||
| 32 | -void FeatureExtractor::AcceptWaveform(float sampling_rate, | ||
| 33 | - const float *waveform, int32_t n) { | 32 | + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { |
| 34 | std::lock_guard<std::mutex> lock(mutex_); | 33 | std::lock_guard<std::mutex> lock(mutex_); |
| 35 | fbank_->AcceptWaveform(sampling_rate, waveform, n); | 34 | fbank_->AcceptWaveform(sampling_rate, waveform, n); |
| 36 | -} | 35 | + } |
| 37 | 36 | ||
| 38 | -void FeatureExtractor::InputFinished() { | 37 | + void InputFinished() { |
| 39 | std::lock_guard<std::mutex> lock(mutex_); | 38 | std::lock_guard<std::mutex> lock(mutex_); |
| 40 | fbank_->InputFinished(); | 39 | fbank_->InputFinished(); |
| 41 | -} | 40 | + } |
| 42 | 41 | ||
| 43 | -int32_t FeatureExtractor::NumFramesReady() const { | 42 | + int32_t NumFramesReady() const { |
| 44 | std::lock_guard<std::mutex> lock(mutex_); | 43 | std::lock_guard<std::mutex> lock(mutex_); |
| 45 | return fbank_->NumFramesReady(); | 44 | return fbank_->NumFramesReady(); |
| 46 | -} | 45 | + } |
| 47 | 46 | ||
| 48 | -bool FeatureExtractor::IsLastFrame(int32_t frame) const { | 47 | + bool IsLastFrame(int32_t frame) const { |
| 49 | std::lock_guard<std::mutex> lock(mutex_); | 48 | std::lock_guard<std::mutex> lock(mutex_); |
| 50 | return fbank_->IsLastFrame(frame); | 49 | return fbank_->IsLastFrame(frame); |
| 51 | -} | 50 | + } |
| 52 | 51 | ||
| 53 | -std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | ||
| 54 | - int32_t n) const { | 52 | + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const { |
| 55 | if (frame_index + n > NumFramesReady()) { | 53 | if (frame_index + n > NumFramesReady()) { |
| 56 | fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); | 54 | fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); |
| 57 | exit(-1); | 55 | exit(-1); |
| @@ -70,10 +68,46 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | @@ -70,10 +68,46 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | ||
| 70 | } | 68 | } |
| 71 | 69 | ||
| 72 | return features; | 70 | return features; |
| 71 | + } | ||
| 72 | + | ||
| 73 | + void Reset() { fbank_ = std::make_unique<knf::OnlineFbank>(opts_); } | ||
| 74 | + | ||
| 75 | + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | ||
| 76 | + | ||
| 77 | + private: | ||
| 78 | + std::unique_ptr<knf::OnlineFbank> fbank_; | ||
| 79 | + knf::FbankOptions opts_; | ||
| 80 | + mutable std::mutex mutex_; | ||
| 81 | +}; | ||
| 82 | + | ||
| 83 | +FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/, | ||
| 84 | + int32_t feature_dim /*=80*/) | ||
| 85 | + : impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {} | ||
| 86 | + | ||
| 87 | +FeatureExtractor::~FeatureExtractor() = default; | ||
| 88 | + | ||
| 89 | +void FeatureExtractor::AcceptWaveform(float sampling_rate, | ||
| 90 | + const float *waveform, int32_t n) { | ||
| 91 | + impl_->AcceptWaveform(sampling_rate, waveform, n); | ||
| 73 | } | 92 | } |
| 74 | 93 | ||
| 75 | -void FeatureExtractor::Reset() { | ||
| 76 | - fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 94 | +void FeatureExtractor::InputFinished() { impl_->InputFinished(); } |
| 95 | + | ||
| 96 | +int32_t FeatureExtractor::NumFramesReady() const { | ||
| 97 | + return impl_->NumFramesReady(); | ||
| 98 | +} | ||
| 99 | + | ||
| 100 | +bool FeatureExtractor::IsLastFrame(int32_t frame) const { | ||
| 101 | + return impl_->IsLastFrame(frame); | ||
| 102 | +} | ||
| 103 | + | ||
| 104 | +std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, | ||
| 105 | + int32_t n) const { | ||
| 106 | + return impl_->GetFrames(frame_index, n); | ||
| 77 | } | 107 | } |
| 78 | 108 | ||
| 109 | +void FeatureExtractor::Reset() { impl_->Reset(); } | ||
| 110 | + | ||
| 111 | +int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } | ||
| 112 | + | ||
| 79 | } // namespace sherpa_onnx | 113 | } // namespace sherpa_onnx |
| @@ -6,17 +6,19 @@ | @@ -6,17 +6,19 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_FEATURES_H_ | 6 | #define SHERPA_ONNX_CSRC_FEATURES_H_ |
| 7 | 7 | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | -#include <mutex> // NOLINT | ||
| 10 | #include <vector> | 9 | #include <vector> |
| 11 | 10 | ||
| 12 | -#include "kaldi-native-fbank/csrc/online-feature.h" | ||
| 13 | - | ||
| 14 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 15 | 12 | ||
| 16 | class FeatureExtractor { | 13 | class FeatureExtractor { |
| 17 | public: | 14 | public: |
| 18 | - FeatureExtractor(); | ||
| 19 | - explicit FeatureExtractor(const knf::FbankOptions &fbank_opts); | 15 | + /** |
| 16 | + * @param sampling_rate Sampling rate of the data used to train the model. | ||
| 17 | + * @param feature_dim Dimension of the features used to train the model. | ||
| 18 | + */ | ||
| 19 | + explicit FeatureExtractor(int32_t sampling_rate = 16000, | ||
| 20 | + int32_t feature_dim = 80); | ||
| 21 | + ~FeatureExtractor(); | ||
| 20 | 22 | ||
| 21 | /** | 23 | /** |
| 22 | @param sampling_rate The sampling_rate of the input waveform. Should match | 24 | @param sampling_rate The sampling_rate of the input waveform. Should match |
| @@ -48,12 +50,13 @@ class FeatureExtractor { | @@ -48,12 +50,13 @@ class FeatureExtractor { | ||
| 48 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; | 50 | std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; |
| 49 | 51 | ||
| 50 | void Reset(); | 52 | void Reset(); |
| 51 | - int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } | 53 | + |
| 54 | + /// Return feature dim of this extractor | ||
| 55 | + int32_t FeatureDim() const; | ||
| 52 | 56 | ||
| 53 | private: | 57 | private: |
| 54 | - std::unique_ptr<knf::OnlineFbank> fbank_; | ||
| 55 | - knf::FbankOptions opts_; | ||
| 56 | - mutable std::mutex mutex_; | 58 | + class Impl; |
| 59 | + std::unique_ptr<Impl> impl_; | ||
| 57 | }; | 60 | }; |
| 58 | 61 | ||
| 59 | } // namespace sherpa_onnx | 62 | } // namespace sherpa_onnx |
| @@ -2,8 +2,9 @@ | @@ -2,8 +2,9 @@ | ||
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2022-2023 Xiaomi Corporation | 3 | // Copyright (c) 2022-2023 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | +#include <stdio.h> | ||
| 6 | + | ||
| 5 | #include <chrono> // NOLINT | 7 | #include <chrono> // NOLINT |
| 6 | -#include <iostream> | ||
| 7 | #include <string> | 8 | #include <string> |
| 8 | #include <vector> | 9 | #include <vector> |
| 9 | 10 | ||
| @@ -30,14 +31,14 @@ Please refer to | @@ -30,14 +31,14 @@ Please refer to | ||
| 30 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 31 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| 31 | for a list of pre-trained models to download. | 32 | for a list of pre-trained models to download. |
| 32 | )usage"; | 33 | )usage"; |
| 33 | - std::cerr << usage << "\n"; | 34 | + fprintf(stderr, "%s\n", usage); |
| 34 | 35 | ||
| 35 | return 0; | 36 | return 0; |
| 36 | } | 37 | } |
| 37 | 38 | ||
| 38 | std::string tokens = argv[1]; | 39 | std::string tokens = argv[1]; |
| 39 | sherpa_onnx::OnlineTransducerModelConfig config; | 40 | sherpa_onnx::OnlineTransducerModelConfig config; |
| 40 | - config.debug = true; | 41 | + config.debug = false; |
| 41 | config.encoder_filename = argv[2]; | 42 | config.encoder_filename = argv[2]; |
| 42 | config.decoder_filename = argv[3]; | 43 | config.decoder_filename = argv[3]; |
| 43 | config.joiner_filename = argv[4]; | 44 | config.joiner_filename = argv[4]; |
| @@ -47,7 +48,7 @@ for a list of pre-trained models to download. | @@ -47,7 +48,7 @@ for a list of pre-trained models to download. | ||
| 47 | if (argc == 7) { | 48 | if (argc == 7) { |
| 48 | config.num_threads = atoi(argv[6]); | 49 | config.num_threads = atoi(argv[6]); |
| 49 | } | 50 | } |
| 50 | - std::cout << config.ToString().c_str() << "\n"; | 51 | + fprintf(stderr, "%s\n", config.ToString().c_str()); |
| 51 | 52 | ||
| 52 | auto model = sherpa_onnx::OnlineTransducerModel::Create(config); | 53 | auto model = sherpa_onnx::OnlineTransducerModel::Create(config); |
| 53 | 54 | ||
| @@ -72,17 +73,17 @@ for a list of pre-trained models to download. | @@ -72,17 +73,17 @@ for a list of pre-trained models to download. | ||
| 72 | sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); | 73 | sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); |
| 73 | 74 | ||
| 74 | if (!is_ok) { | 75 | if (!is_ok) { |
| 75 | - std::cerr << "Failed to read " << wav_filename << "\n"; | 76 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); |
| 76 | return -1; | 77 | return -1; |
| 77 | } | 78 | } |
| 78 | 79 | ||
| 79 | - const float duration = samples.size() / expected_sampling_rate; | 80 | + float duration = samples.size() / static_cast<float>(expected_sampling_rate); |
| 80 | 81 | ||
| 81 | - std::cout << "wav filename: " << wav_filename << "\n"; | ||
| 82 | - std::cout << "wav duration (s): " << duration << "\n"; | 82 | + fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); |
| 83 | + fprintf(stderr, "wav duration (s): %.3f\n", duration); | ||
| 83 | 84 | ||
| 84 | auto begin = std::chrono::steady_clock::now(); | 85 | auto begin = std::chrono::steady_clock::now(); |
| 85 | - std::cout << "Started!\n"; | 86 | + fprintf(stderr, "Started\n"); |
| 86 | 87 | ||
| 87 | sherpa_onnx::FeatureExtractor feat_extractor; | 88 | sherpa_onnx::FeatureExtractor feat_extractor; |
| 88 | feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), | 89 | feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), |
| @@ -115,10 +116,10 @@ for a list of pre-trained models to download. | @@ -115,10 +116,10 @@ for a list of pre-trained models to download. | ||
| 115 | text += sym[hyp[i]]; | 116 | text += sym[hyp[i]]; |
| 116 | } | 117 | } |
| 117 | 118 | ||
| 118 | - std::cout << "Done!\n"; | 119 | + fprintf(stderr, "Done!\n"); |
| 119 | 120 | ||
| 120 | - std::cout << "Recognition result for " << wav_filename << "\n" | ||
| 121 | - << text << "\n"; | 121 | + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), |
| 122 | + text.c_str()); | ||
| 122 | 123 | ||
| 123 | auto end = std::chrono::steady_clock::now(); | 124 | auto end = std::chrono::steady_clock::now(); |
| 124 | float elapsed_seconds = | 125 | float elapsed_seconds = |
| @@ -126,7 +127,7 @@ for a list of pre-trained models to download. | @@ -126,7 +127,7 @@ for a list of pre-trained models to download. | ||
| 126 | .count() / | 127 | .count() / |
| 127 | 1000.; | 128 | 1000.; |
| 128 | 129 | ||
| 129 | - std::cout << "num threads: " << config.num_threads << "\n"; | 130 | + fprintf(stderr, "num threads: %d\n", config.num_threads); |
| 130 | 131 | ||
| 131 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | 132 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); |
| 132 | float rtf = elapsed_seconds / duration; | 133 | float rtf = elapsed_seconds / duration; |
-
请 注册 或 登录 后发表评论