Fangjun Kuang
Committed by GitHub

Support resampling (#77)

@@ -78,8 +78,6 @@ def get_args(): @@ -78,8 +78,6 @@ def get_args():
78 78
79 79
80 def main(): 80 def main():
81 - sample_rate = 16000  
82 -  
83 args = get_args() 81 args = get_args()
84 assert_file_exists(args.encoder) 82 assert_file_exists(args.encoder)
85 assert_file_exists(args.decoder) 83 assert_file_exists(args.decoder)
@@ -95,12 +93,16 @@ def main(): @@ -95,12 +93,16 @@ def main():
95 decoder=args.decoder, 93 decoder=args.decoder,
96 joiner=args.joiner, 94 joiner=args.joiner,
97 num_threads=args.num_threads, 95 num_threads=args.num_threads,
98 - sample_rate=sample_rate, 96 + sample_rate=16000,
99 feature_dim=80, 97 feature_dim=80,
100 decoding_method=args.decoding_method, 98 decoding_method=args.decoding_method,
101 ) 99 )
102 with wave.open(args.wave_filename) as f: 100 with wave.open(args.wave_filename) as f:
103 - assert f.getframerate() == sample_rate, f.getframerate() 101 + # If the wave file has a different sampling rate from the one
  102 + # expected by the model (16 kHz in our case), we will do
  103 + # resampling inside sherpa-onnx
  104 + wave_file_sample_rate = f.getframerate()
  105 +
104 assert f.getnchannels() == 1, f.getnchannels() 106 assert f.getnchannels() == 1, f.getnchannels()
105 assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes 107 assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
106 num_samples = f.getnframes() 108 num_samples = f.getnframes()
@@ -110,17 +112,17 @@ def main(): @@ -110,17 +112,17 @@ def main():
110 112
111 samples_float32 = samples_float32 / 32768 113 samples_float32 = samples_float32 / 32768
112 114
113 - duration = len(samples_float32) / sample_rate 115 + duration = len(samples_float32) / wave_file_sample_rate
114 116
115 start_time = time.time() 117 start_time = time.time()
116 print("Started!") 118 print("Started!")
117 119
118 stream = recognizer.create_stream() 120 stream = recognizer.create_stream()
119 121
120 - stream.accept_waveform(sample_rate, samples_float32) 122 + stream.accept_waveform(wave_file_sample_rate, samples_float32)
121 123
122 - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)  
123 - stream.accept_waveform(sample_rate, tail_paddings) 124 + tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32)
  125 + stream.accept_waveform(wave_file_sample_rate, tail_paddings)
124 126
125 stream.input_finished() 127 stream.input_finished()
126 128
@@ -100,7 +100,9 @@ def main(): @@ -100,7 +100,9 @@ def main():
100 recognizer = create_recognizer() 100 recognizer = create_recognizer()
101 print("Started! Please speak") 101 print("Started! Please speak")
102 102
103 - sample_rate = 16000 103 + # The model is using 16 kHz, we use 48 kHz here to demonstrate that
  104 + # sherpa-onnx will do resampling inside.
  105 + sample_rate = 48000
104 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 106 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
105 last_result = "" 107 last_result = ""
106 stream = recognizer.create_stream() 108 stream = recognizer.create_stream()
@@ -92,9 +92,12 @@ def create_recognizer(): @@ -92,9 +92,12 @@ def create_recognizer():
92 92
93 93
94 def main(): 94 def main():
95 - print("Started! Please speak")  
96 recognizer = create_recognizer() 95 recognizer = create_recognizer()
97 - sample_rate = 16000 96 + print("Started! Please speak")
  97 +
  98 + # The model is using 16 kHz, we use 48 kHz here to demonstrate that
  99 + # sherpa-onnx will do resampling inside.
  100 + sample_rate = 48000
98 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms 101 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
99 last_result = "" 102 last_result = ""
100 stream = recognizer.create_stream() 103 stream = recognizer.create_stream()
@@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream); @@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
115 /// decoding. 115 /// decoding.
116 /// 116 ///
117 /// @param stream A pointer returned by CreateOnlineStream(). 117 /// @param stream A pointer returned by CreateOnlineStream().
118 -/// @param sample_rate Sampler rate of the input samples. It has to be 16 kHz  
119 -/// for models from icefall. 118 +/// @param sample_rate Sample rate of the input samples. If it is different
  119 +/// from config.feat_config.sample_rate, we will do
  120 +/// resampling inside sherpa-onnx.
120 /// @param samples A pointer to a 1-D array containing audio samples. 121 /// @param samples A pointer to a 1-D array containing audio samples.
121 /// The range of samples has to be normalized to [-1, 1]. 122 /// The range of samples has to be normalized to [-1, 1].
122 /// @param n Number of elements in the samples array. 123 /// @param n Number of elements in the samples array.
@@ -11,6 +11,8 @@ @@ -11,6 +11,8 @@
11 #include <vector> 11 #include <vector>
12 12
13 #include "kaldi-native-fbank/csrc/online-feature.h" 13 #include "kaldi-native-fbank/csrc/online-feature.h"
  14 +#include "sherpa-onnx/csrc/macros.h"
  15 +#include "sherpa-onnx/csrc/resample.h"
14 16
15 namespace sherpa_onnx { 17 namespace sherpa_onnx {
16 18
@@ -50,6 +52,46 @@ class FeatureExtractor::Impl { @@ -50,6 +52,46 @@ class FeatureExtractor::Impl {
50 52
51 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { 53 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
52 std::lock_guard<std::mutex> lock(mutex_); 54 std::lock_guard<std::mutex> lock(mutex_);
  55 +
  56 + if (resampler_) {
  57 + if (sampling_rate != resampler_->GetInputSamplingRate()) {
  58 + SHERPA_ONNX_LOGE(
  59 + "You changed the input sampling rate!! Expected: %d, given: "
  60 + "%d",
  61 + resampler_->GetInputSamplingRate(), sampling_rate);
  62 + exit(-1);
  63 + }
  64 +
  65 + std::vector<float> samples;
  66 + resampler_->Resample(waveform, n, false, &samples);
  67 + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
  68 + samples.size());
  69 + return;
  70 + }
  71 +
  72 + if (sampling_rate != opts_.frame_opts.samp_freq) {
  73 + SHERPA_ONNX_LOGE(
  74 + "Creating a resampler:\n"
  75 + " in_sample_rate: %d\n"
  76 + " output_sample_rate: %d\n",
  77 + sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
  78 +
  79 + float min_freq =
  80 + std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
  81 + float lowpass_cutoff = 0.99 * 0.5 * min_freq;
  82 +
  83 + int32_t lowpass_filter_width = 6;
  84 + resampler_ = std::make_unique<LinearResample>(
  85 + sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
  86 + lowpass_filter_width);
  87 +
  88 + std::vector<float> samples;
  89 + resampler_->Resample(waveform, n, false, &samples);
  90 + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
  91 + samples.size());
  92 + return;
  93 + }
  94 +
53 fbank_->AcceptWaveform(sampling_rate, waveform, n); 95 fbank_->AcceptWaveform(sampling_rate, waveform, n);
54 } 96 }
55 97
@@ -100,6 +142,7 @@ class FeatureExtractor::Impl { @@ -100,6 +142,7 @@ class FeatureExtractor::Impl {
100 std::unique_ptr<knf::OnlineFbank> fbank_; 142 std::unique_ptr<knf::OnlineFbank> fbank_;
101 knf::FbankOptions opts_; 143 knf::FbankOptions opts_;
102 mutable std::mutex mutex_; 144 mutable std::mutex mutex_;
  145 + std::unique_ptr<LinearResample> resampler_;
103 }; 146 };
104 147
105 FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) 148 FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
@@ -29,9 +29,11 @@ class FeatureExtractor { @@ -29,9 +29,11 @@ class FeatureExtractor {
29 ~FeatureExtractor(); 29 ~FeatureExtractor();
30 30
31 /** 31 /**
32 - @param sampling_rate The sampling_rate of the input waveform. Should match  
33 - the one expected by the feature extractor.  
34 - @param waveform Pointer to a 1-D array of size n 32 + @param sampling_rate The sampling_rate of the input waveform. If it does
  33 + not equal to config.sampling_rate, we will do
  34 + resampling inside.
  35 + @param waveform Pointer to a 1-D array of size n. It must be normalized to
  36 + the range [-1, 1].
35 @param n Number of entries in waveform 37 @param n Number of entries in waveform
36 */ 38 */
37 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); 39 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
@@ -16,7 +16,7 @@ class OnlineStream::Impl { @@ -16,7 +16,7 @@ class OnlineStream::Impl {
16 explicit Impl(const FeatureExtractorConfig &config) 16 explicit Impl(const FeatureExtractorConfig &config)
17 : feat_extractor_(config) {} 17 : feat_extractor_(config) {}
18 18
19 - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { 19 + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
20 feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); 20 feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
21 } 21 }
22 22
@@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) @@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
67 67
68 OnlineStream::~OnlineStream() = default; 68 OnlineStream::~OnlineStream() = default;
69 69
70 -void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform, 70 +void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
71 int32_t n) { 71 int32_t n) {
72 impl_->AcceptWaveform(sampling_rate, waveform, n); 72 impl_->AcceptWaveform(sampling_rate, waveform, n);
73 } 73 }
@@ -20,12 +20,14 @@ class OnlineStream { @@ -20,12 +20,14 @@ class OnlineStream {
20 ~OnlineStream(); 20 ~OnlineStream();
21 21
22 /** 22 /**
23 - @param sampling_rate The sampling_rate of the input waveform. Should match  
24 - the one expected by the feature extractor.  
25 - @param waveform Pointer to a 1-D array of size n 23 + @param sampling_rate The sampling_rate of the input waveform. If it does
  24 + not equal to config.sampling_rate, we will do
  25 + resampling inside.
  26 + @param waveform Pointer to a 1-D array of size n. It must be normalized to
  27 + the range [-1, 1].
26 @param n Number of entries in waveform 28 @param n Number of entries in waveform
27 */ 29 */
28 - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); 30 + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
29 31
30 /** 32 /**
31 * InputFinished() tells the class you won't be providing any 33 * InputFinished() tells the class you won't be providing any
@@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { @@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
76 std::vector<int64_t> blanks(context_size, blank_id); 76 std::vector<int64_t> blanks(context_size, blank_id);
77 Hypotheses blank_hyp({{blanks, 0}}); 77 Hypotheses blank_hyp({{blanks, 0}});
78 r.hyps = std::move(blank_hyp); 78 r.hyps = std::move(blank_hyp);
  79 + r.tokens = std::move(blanks);
79 return r; 80 return r;
80 } 81 }
81 82
@@ -8,13 +8,27 @@ @@ -8,13 +8,27 @@
8 8
9 namespace sherpa_onnx { 9 namespace sherpa_onnx {
10 10
  11 +constexpr const char *kAcceptWaveformUsage = R"(
  12 +Process audio samples.
  13 +
  14 +Args:
  15 + sample_rate:
  16 + Sample rate of the input samples. If it is different from the one
  17 + expected by the model, we will do resampling inside.
  18 + waveform:
  19 + A 1-D float32 tensor containing audio samples. It must be normalized
  20 + to the range [-1, 1].
  21 +)";
  22 +
11 void PybindOnlineStream(py::module *m) { 23 void PybindOnlineStream(py::module *m) {
12 using PyClass = OnlineStream; 24 using PyClass = OnlineStream;
13 py::class_<PyClass>(*m, "OnlineStream") 25 py::class_<PyClass>(*m, "OnlineStream")
14 - .def("accept_waveform",  
15 - [](PyClass &self, float sample_rate, py::array_t<float> waveform) {  
16 - self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());  
17 - }) 26 + .def(
  27 + "accept_waveform",
  28 + [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
  29 + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
  30 + },
  31 + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
18 .def("input_finished", &PyClass::InputFinished); 32 .def("input_finished", &PyClass::InputFinished);
19 } 33 }
20 34