Fangjun Kuang
Committed by GitHub

Refactor feature extractor (#26)

@@ -6,74 +6,108 @@ @@ -6,74 +6,108 @@
6 6
7 #include <algorithm> 7 #include <algorithm>
8 #include <memory> 8 #include <memory>
  9 +#include <mutex> // NOLINT
9 #include <vector> 10 #include <vector>
10 11
  12 +#include "kaldi-native-fbank/csrc/online-feature.h"
  13 +
11 namespace sherpa_onnx { 14 namespace sherpa_onnx {
12 15
13 -FeatureExtractor::FeatureExtractor() {  
14 - opts_.frame_opts.dither = 0;  
15 - opts_.frame_opts.snip_edges = false;  
16 - opts_.frame_opts.samp_freq = 16000; 16 +class FeatureExtractor::Impl {
  17 + public:
  18 + Impl(int32_t sampling_rate, int32_t feature_dim) {
  19 + opts_.frame_opts.dither = 0;
  20 + opts_.frame_opts.snip_edges = false;
  21 + opts_.frame_opts.samp_freq = sampling_rate;
17 22
18 - // cache 100 seconds of feature frames, which is more than enough  
19 - // for real needs  
20 - opts_.frame_opts.max_feature_vectors = 100 * 100; 23 + // cache 100 seconds of feature frames, which is more than enough
  24 + // for real needs
  25 + opts_.frame_opts.max_feature_vectors = 100 * 100;
21 26
22 - opts_.mel_opts.num_bins = 80; // feature dim 27 + opts_.mel_opts.num_bins = feature_dim;
23 28
24 - fbank_ = std::make_unique<knf::OnlineFbank>(opts_);  
25 -} 29 + fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
  30 + }
26 31
27 -FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)  
28 - : opts_(opts) {  
29 - fbank_ = std::make_unique<knf::OnlineFbank>(opts_);  
30 -} 32 + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
  33 + std::lock_guard<std::mutex> lock(mutex_);
  34 + fbank_->AcceptWaveform(sampling_rate, waveform, n);
  35 + }
  36 +
  37 + void InputFinished() {
  38 + std::lock_guard<std::mutex> lock(mutex_);
  39 + fbank_->InputFinished();
  40 + }
  41 +
  42 + int32_t NumFramesReady() const {
  43 + std::lock_guard<std::mutex> lock(mutex_);
  44 + return fbank_->NumFramesReady();
  45 + }
  46 +
  47 + bool IsLastFrame(int32_t frame) const {
  48 + std::lock_guard<std::mutex> lock(mutex_);
  49 + return fbank_->IsLastFrame(frame);
  50 + }
  51 +
  52 + std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
  53 + if (frame_index + n > NumFramesReady()) {
  54 + fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
  55 + exit(-1);
  56 + }
  57 + std::lock_guard<std::mutex> lock(mutex_);
  58 +
  59 + int32_t feature_dim = fbank_->Dim();
  60 + std::vector<float> features(feature_dim * n);
  61 +
  62 + float *p = features.data();
  63 +
  64 + for (int32_t i = 0; i != n; ++i) {
  65 + const float *f = fbank_->GetFrame(i + frame_index);
  66 + std::copy(f, f + feature_dim, p);
  67 + p += feature_dim;
  68 + }
  69 +
  70 + return features;
  71 + }
  72 +
  73 + void Reset() { fbank_ = std::make_unique<knf::OnlineFbank>(opts_); }
  74 +
  75 + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
  76 +
  77 + private:
  78 + std::unique_ptr<knf::OnlineFbank> fbank_;
  79 + knf::FbankOptions opts_;
  80 + mutable std::mutex mutex_;
  81 +};
  82 +
  83 +FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/,
  84 + int32_t feature_dim /*=80*/)
  85 + : impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {}
  86 +
  87 +FeatureExtractor::~FeatureExtractor() = default;
31 88
32 void FeatureExtractor::AcceptWaveform(float sampling_rate, 89 void FeatureExtractor::AcceptWaveform(float sampling_rate,
33 const float *waveform, int32_t n) { 90 const float *waveform, int32_t n) {
34 - std::lock_guard<std::mutex> lock(mutex_);  
35 - fbank_->AcceptWaveform(sampling_rate, waveform, n); 91 + impl_->AcceptWaveform(sampling_rate, waveform, n);
36 } 92 }
37 93
38 -void FeatureExtractor::InputFinished() {  
39 - std::lock_guard<std::mutex> lock(mutex_);  
40 - fbank_->InputFinished();  
41 -} 94 +void FeatureExtractor::InputFinished() { impl_->InputFinished(); }
42 95
43 int32_t FeatureExtractor::NumFramesReady() const { 96 int32_t FeatureExtractor::NumFramesReady() const {
44 - std::lock_guard<std::mutex> lock(mutex_);  
45 - return fbank_->NumFramesReady(); 97 + return impl_->NumFramesReady();
46 } 98 }
47 99
48 bool FeatureExtractor::IsLastFrame(int32_t frame) const { 100 bool FeatureExtractor::IsLastFrame(int32_t frame) const {
49 - std::lock_guard<std::mutex> lock(mutex_);  
50 - return fbank_->IsLastFrame(frame); 101 + return impl_->IsLastFrame(frame);
51 } 102 }
52 103
53 std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, 104 std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
54 int32_t n) const { 105 int32_t n) const {
55 - if (frame_index + n > NumFramesReady()) {  
56 - fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());  
57 - exit(-1);  
58 - }  
59 - std::lock_guard<std::mutex> lock(mutex_);  
60 -  
61 - int32_t feature_dim = fbank_->Dim();  
62 - std::vector<float> features(feature_dim * n);  
63 -  
64 - float *p = features.data();  
65 -  
66 - for (int32_t i = 0; i != n; ++i) {  
67 - const float *f = fbank_->GetFrame(i + frame_index);  
68 - std::copy(f, f + feature_dim, p);  
69 - p += feature_dim;  
70 - }  
71 -  
72 - return features; 106 + return impl_->GetFrames(frame_index, n);
73 } 107 }
74 108
75 -void FeatureExtractor::Reset() {  
76 - fbank_ = std::make_unique<knf::OnlineFbank>(opts_);  
77 -} 109 +void FeatureExtractor::Reset() { impl_->Reset(); }
  110 +
  111 +int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); }
78 112
79 } // namespace sherpa_onnx 113 } // namespace sherpa_onnx
@@ -6,17 +6,19 @@ @@ -6,17 +6,19 @@
6 #define SHERPA_ONNX_CSRC_FEATURES_H_ 6 #define SHERPA_ONNX_CSRC_FEATURES_H_
7 7
8 #include <memory> 8 #include <memory>
9 -#include <mutex> // NOLINT  
10 #include <vector> 9 #include <vector>
11 10
12 -#include "kaldi-native-fbank/csrc/online-feature.h"  
13 -  
14 namespace sherpa_onnx { 11 namespace sherpa_onnx {
15 12
16 class FeatureExtractor { 13 class FeatureExtractor {
17 public: 14 public:
18 - FeatureExtractor();  
19 - explicit FeatureExtractor(const knf::FbankOptions &fbank_opts); 15 + /**
  16 + * @param sampling_rate Sampling rate of the data used to train the model.
  17 + * @param feature_dim Dimension of the features used to train the model.
  18 + */
  19 + explicit FeatureExtractor(int32_t sampling_rate = 16000,
  20 + int32_t feature_dim = 80);
  21 + ~FeatureExtractor();
20 22
21 /** 23 /**
22 @param sampling_rate The sampling_rate of the input waveform. Should match 24 @param sampling_rate The sampling_rate of the input waveform. Should match
@@ -48,12 +50,13 @@ class FeatureExtractor { @@ -48,12 +50,13 @@ class FeatureExtractor {
48 std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; 50 std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
49 51
50 void Reset(); 52 void Reset();
51 - int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } 53 +
  54 + /// Return feature dim of this extractor
  55 + int32_t FeatureDim() const;
52 56
53 private: 57 private:
54 - std::unique_ptr<knf::OnlineFbank> fbank_;  
55 - knf::FbankOptions opts_;  
56 - mutable std::mutex mutex_; 58 + class Impl;
  59 + std::unique_ptr<Impl> impl_;
57 }; 60 };
58 61
59 } // namespace sherpa_onnx 62 } // namespace sherpa_onnx
@@ -2,8 +2,9 @@ @@ -2,8 +2,9 @@
2 // 2 //
3 // Copyright (c) 2022-2023 Xiaomi Corporation 3 // Copyright (c) 2022-2023 Xiaomi Corporation
4 4
  5 +#include <stdio.h>
  6 +
5 #include <chrono> // NOLINT 7 #include <chrono> // NOLINT
6 -#include <iostream>  
7 #include <string> 8 #include <string>
8 #include <vector> 9 #include <vector>
9 10
@@ -30,14 +31,14 @@ Please refer to @@ -30,14 +31,14 @@ Please refer to
30 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 31 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
31 for a list of pre-trained models to download. 32 for a list of pre-trained models to download.
32 )usage"; 33 )usage";
33 - std::cerr << usage << "\n"; 34 + fprintf(stderr, "%s\n", usage);
34 35
35 return 0; 36 return 0;
36 } 37 }
37 38
38 std::string tokens = argv[1]; 39 std::string tokens = argv[1];
39 sherpa_onnx::OnlineTransducerModelConfig config; 40 sherpa_onnx::OnlineTransducerModelConfig config;
40 - config.debug = true; 41 + config.debug = false;
41 config.encoder_filename = argv[2]; 42 config.encoder_filename = argv[2];
42 config.decoder_filename = argv[3]; 43 config.decoder_filename = argv[3];
43 config.joiner_filename = argv[4]; 44 config.joiner_filename = argv[4];
@@ -47,7 +48,7 @@ for a list of pre-trained models to download. @@ -47,7 +48,7 @@ for a list of pre-trained models to download.
47 if (argc == 7) { 48 if (argc == 7) {
48 config.num_threads = atoi(argv[6]); 49 config.num_threads = atoi(argv[6]);
49 } 50 }
50 - std::cout << config.ToString().c_str() << "\n"; 51 + fprintf(stderr, "%s\n", config.ToString().c_str());
51 52
52 auto model = sherpa_onnx::OnlineTransducerModel::Create(config); 53 auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
53 54
@@ -72,17 +73,17 @@ for a list of pre-trained models to download. @@ -72,17 +73,17 @@ for a list of pre-trained models to download.
72 sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); 73 sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
73 74
74 if (!is_ok) { 75 if (!is_ok) {
75 - std::cerr << "Failed to read " << wav_filename << "\n"; 76 + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
76 return -1; 77 return -1;
77 } 78 }
78 79
79 - const float duration = samples.size() / expected_sampling_rate; 80 + float duration = samples.size() / static_cast<float>(expected_sampling_rate);
80 81
81 - std::cout << "wav filename: " << wav_filename << "\n";  
82 - std::cout << "wav duration (s): " << duration << "\n"; 82 + fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
  83 + fprintf(stderr, "wav duration (s): %.3f\n", duration);
83 84
84 auto begin = std::chrono::steady_clock::now(); 85 auto begin = std::chrono::steady_clock::now();
85 - std::cout << "Started!\n"; 86 + fprintf(stderr, "Started\n");
86 87
87 sherpa_onnx::FeatureExtractor feat_extractor; 88 sherpa_onnx::FeatureExtractor feat_extractor;
88 feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), 89 feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
@@ -115,10 +116,10 @@ for a list of pre-trained models to download. @@ -115,10 +116,10 @@ for a list of pre-trained models to download.
115 text += sym[hyp[i]]; 116 text += sym[hyp[i]];
116 } 117 }
117 118
118 - std::cout << "Done!\n"; 119 + fprintf(stderr, "Done!\n");
119 120
120 - std::cout << "Recognition result for " << wav_filename << "\n"  
121 - << text << "\n"; 121 + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
  122 + text.c_str());
122 123
123 auto end = std::chrono::steady_clock::now(); 124 auto end = std::chrono::steady_clock::now();
124 float elapsed_seconds = 125 float elapsed_seconds =
@@ -126,7 +127,7 @@ for a list of pre-trained models to download. @@ -126,7 +127,7 @@ for a list of pre-trained models to download.
126 .count() / 127 .count() /
127 1000.; 128 1000.;
128 129
129 - std::cout << "num threads: " << config.num_threads << "\n"; 130 + fprintf(stderr, "num threads: %d\n", config.num_threads);
130 131
131 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); 132 fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
132 float rtf = elapsed_seconds / duration; 133 float rtf = elapsed_seconds / duration;