正在显示
10 个修改的文件
包含
96 行增加
和
26 行删除
| @@ -78,8 +78,6 @@ def get_args(): | @@ -78,8 +78,6 @@ def get_args(): | ||
| 78 | 78 | ||
| 79 | 79 | ||
| 80 | def main(): | 80 | def main(): |
| 81 | - sample_rate = 16000 | ||
| 82 | - | ||
| 83 | args = get_args() | 81 | args = get_args() |
| 84 | assert_file_exists(args.encoder) | 82 | assert_file_exists(args.encoder) |
| 85 | assert_file_exists(args.decoder) | 83 | assert_file_exists(args.decoder) |
| @@ -95,12 +93,16 @@ def main(): | @@ -95,12 +93,16 @@ def main(): | ||
| 95 | decoder=args.decoder, | 93 | decoder=args.decoder, |
| 96 | joiner=args.joiner, | 94 | joiner=args.joiner, |
| 97 | num_threads=args.num_threads, | 95 | num_threads=args.num_threads, |
| 98 | - sample_rate=sample_rate, | 96 | + sample_rate=16000, |
| 99 | feature_dim=80, | 97 | feature_dim=80, |
| 100 | decoding_method=args.decoding_method, | 98 | decoding_method=args.decoding_method, |
| 101 | ) | 99 | ) |
| 102 | with wave.open(args.wave_filename) as f: | 100 | with wave.open(args.wave_filename) as f: |
| 103 | - assert f.getframerate() == sample_rate, f.getframerate() | 101 | + # If the wave file has a different sampling rate from the one |
| 102 | + # expected by the model (16 kHz in our case), we will do | ||
| 103 | + # resampling inside sherpa-onnx | ||
| 104 | + wave_file_sample_rate = f.getframerate() | ||
| 105 | + | ||
| 104 | assert f.getnchannels() == 1, f.getnchannels() | 106 | assert f.getnchannels() == 1, f.getnchannels() |
| 105 | assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes | 107 | assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes |
| 106 | num_samples = f.getnframes() | 108 | num_samples = f.getnframes() |
| @@ -110,17 +112,17 @@ def main(): | @@ -110,17 +112,17 @@ def main(): | ||
| 110 | 112 | ||
| 111 | samples_float32 = samples_float32 / 32768 | 113 | samples_float32 = samples_float32 / 32768 |
| 112 | 114 | ||
| 113 | - duration = len(samples_float32) / sample_rate | 115 | + duration = len(samples_float32) / wave_file_sample_rate |
| 114 | 116 | ||
| 115 | start_time = time.time() | 117 | start_time = time.time() |
| 116 | print("Started!") | 118 | print("Started!") |
| 117 | 119 | ||
| 118 | stream = recognizer.create_stream() | 120 | stream = recognizer.create_stream() |
| 119 | 121 | ||
| 120 | - stream.accept_waveform(sample_rate, samples_float32) | 122 | + stream.accept_waveform(wave_file_sample_rate, samples_float32) |
| 121 | 123 | ||
| 122 | - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) | ||
| 123 | - stream.accept_waveform(sample_rate, tail_paddings) | 124 | + tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32) |
| 125 | + stream.accept_waveform(wave_file_sample_rate, tail_paddings) | ||
| 124 | 126 | ||
| 125 | stream.input_finished() | 127 | stream.input_finished() |
| 126 | 128 |
| @@ -100,7 +100,9 @@ def main(): | @@ -100,7 +100,9 @@ def main(): | ||
| 100 | recognizer = create_recognizer() | 100 | recognizer = create_recognizer() |
| 101 | print("Started! Please speak") | 101 | print("Started! Please speak") |
| 102 | 102 | ||
| 103 | - sample_rate = 16000 | 103 | + # The model is using 16 kHz, we use 48 kHz here to demonstrate that |
| 104 | + # sherpa-onnx will do resampling inside. | ||
| 105 | + sample_rate = 48000 | ||
| 104 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | 106 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms |
| 105 | last_result = "" | 107 | last_result = "" |
| 106 | stream = recognizer.create_stream() | 108 | stream = recognizer.create_stream() |
| @@ -92,9 +92,12 @@ def create_recognizer(): | @@ -92,9 +92,12 @@ def create_recognizer(): | ||
| 92 | 92 | ||
| 93 | 93 | ||
| 94 | def main(): | 94 | def main(): |
| 95 | - print("Started! Please speak") | ||
| 96 | recognizer = create_recognizer() | 95 | recognizer = create_recognizer() |
| 97 | - sample_rate = 16000 | 96 | + print("Started! Please speak") |
| 97 | + | ||
| 98 | + # The model is using 16 kHz, we use 48 kHz here to demonstrate that | ||
| 99 | + # sherpa-onnx will do resampling inside. | ||
| 100 | + sample_rate = 48000 | ||
| 98 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms | 101 | samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms |
| 99 | last_result = "" | 102 | last_result = "" |
| 100 | stream = recognizer.create_stream() | 103 | stream = recognizer.create_stream() |
| @@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream); | @@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream); | ||
| 115 | /// decoding. | 115 | /// decoding. |
| 116 | /// | 116 | /// |
| 117 | /// @param stream A pointer returned by CreateOnlineStream(). | 117 | /// @param stream A pointer returned by CreateOnlineStream(). |
| 118 | -/// @param sample_rate Sampler rate of the input samples. It has to be 16 kHz | ||
| 119 | -/// for models from icefall. | 118 | +/// @param sample_rate Sample rate of the input samples. If it is different |
| 119 | +/// from config.feat_config.sample_rate, we will do | ||
| 120 | +/// resampling inside sherpa-onnx. | ||
| 120 | /// @param samples A pointer to a 1-D array containing audio samples. | 121 | /// @param samples A pointer to a 1-D array containing audio samples. |
| 121 | /// The range of samples has to be normalized to [-1, 1]. | 122 | /// The range of samples has to be normalized to [-1, 1]. |
| 122 | /// @param n Number of elements in the samples array. | 123 | /// @param n Number of elements in the samples array. |
| @@ -11,6 +11,8 @@ | @@ -11,6 +11,8 @@ | ||
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | #include "kaldi-native-fbank/csrc/online-feature.h" | 13 | #include "kaldi-native-fbank/csrc/online-feature.h" |
| 14 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 15 | +#include "sherpa-onnx/csrc/resample.h" | ||
| 14 | 16 | ||
| 15 | namespace sherpa_onnx { | 17 | namespace sherpa_onnx { |
| 16 | 18 | ||
| @@ -50,6 +52,46 @@ class FeatureExtractor::Impl { | @@ -50,6 +52,46 @@ class FeatureExtractor::Impl { | ||
| 50 | 52 | ||
| 51 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 53 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 52 | std::lock_guard<std::mutex> lock(mutex_); | 54 | std::lock_guard<std::mutex> lock(mutex_); |
| 55 | + | ||
| 56 | + if (resampler_) { | ||
| 57 | + if (sampling_rate != resampler_->GetInputSamplingRate()) { | ||
| 58 | + SHERPA_ONNX_LOGE( | ||
| 59 | + "You changed the input sampling rate!! Expected: %d, given: " | ||
| 60 | + "%d", | ||
| 61 | + resampler_->GetInputSamplingRate(), sampling_rate); | ||
| 62 | + exit(-1); | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + std::vector<float> samples; | ||
| 66 | + resampler_->Resample(waveform, n, false, &samples); | ||
| 67 | + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), | ||
| 68 | + samples.size()); | ||
| 69 | + return; | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + if (sampling_rate != opts_.frame_opts.samp_freq) { | ||
| 73 | + SHERPA_ONNX_LOGE( | ||
| 74 | + "Creating a resampler:\n" | ||
| 75 | + " in_sample_rate: %d\n" | ||
| 76 | + " output_sample_rate: %d\n", | ||
| 77 | + sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq)); | ||
| 78 | + | ||
| 79 | + float min_freq = | ||
| 80 | + std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq); | ||
| 81 | + float lowpass_cutoff = 0.99 * 0.5 * min_freq; | ||
| 82 | + | ||
| 83 | + int32_t lowpass_filter_width = 6; | ||
| 84 | + resampler_ = std::make_unique<LinearResample>( | ||
| 85 | + sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff, | ||
| 86 | + lowpass_filter_width); | ||
| 87 | + | ||
| 88 | + std::vector<float> samples; | ||
| 89 | + resampler_->Resample(waveform, n, false, &samples); | ||
| 90 | + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), | ||
| 91 | + samples.size()); | ||
| 92 | + return; | ||
| 93 | + } | ||
| 94 | + | ||
| 53 | fbank_->AcceptWaveform(sampling_rate, waveform, n); | 95 | fbank_->AcceptWaveform(sampling_rate, waveform, n); |
| 54 | } | 96 | } |
| 55 | 97 | ||
| @@ -100,6 +142,7 @@ class FeatureExtractor::Impl { | @@ -100,6 +142,7 @@ class FeatureExtractor::Impl { | ||
| 100 | std::unique_ptr<knf::OnlineFbank> fbank_; | 142 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 101 | knf::FbankOptions opts_; | 143 | knf::FbankOptions opts_; |
| 102 | mutable std::mutex mutex_; | 144 | mutable std::mutex mutex_; |
| 145 | + std::unique_ptr<LinearResample> resampler_; | ||
| 103 | }; | 146 | }; |
| 104 | 147 | ||
| 105 | FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) | 148 | FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) |
| @@ -29,9 +29,11 @@ class FeatureExtractor { | @@ -29,9 +29,11 @@ class FeatureExtractor { | ||
| 29 | ~FeatureExtractor(); | 29 | ~FeatureExtractor(); |
| 30 | 30 | ||
| 31 | /** | 31 | /** |
| 32 | - @param sampling_rate The sampling_rate of the input waveform. Should match | ||
| 33 | - the one expected by the feature extractor. | ||
| 34 | - @param waveform Pointer to a 1-D array of size n | 32 | + @param sampling_rate The sampling_rate of the input waveform. If it is |
| 33 | + not equal to config.sampling_rate, we will do | ||
| 34 | + resampling inside. | ||
| 35 | + @param waveform Pointer to a 1-D array of size n. It must be normalized to | ||
| 36 | + the range [-1, 1]. | ||
| 35 | @param n Number of entries in waveform | 37 | @param n Number of entries in waveform |
| 36 | */ | 38 | */ |
| 37 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); | 39 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); |
| @@ -16,7 +16,7 @@ class OnlineStream::Impl { | @@ -16,7 +16,7 @@ class OnlineStream::Impl { | ||
| 16 | explicit Impl(const FeatureExtractorConfig &config) | 16 | explicit Impl(const FeatureExtractorConfig &config) |
| 17 | : feat_extractor_(config) {} | 17 | : feat_extractor_(config) {} |
| 18 | 18 | ||
| 19 | - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { | 19 | + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 20 | feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); | 20 | feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); |
| 21 | } | 21 | } |
| 22 | 22 | ||
| @@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) | @@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) | ||
| 67 | 67 | ||
| 68 | OnlineStream::~OnlineStream() = default; | 68 | OnlineStream::~OnlineStream() = default; |
| 69 | 69 | ||
| 70 | -void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform, | 70 | +void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, |
| 71 | int32_t n) { | 71 | int32_t n) { |
| 72 | impl_->AcceptWaveform(sampling_rate, waveform, n); | 72 | impl_->AcceptWaveform(sampling_rate, waveform, n); |
| 73 | } | 73 | } |
| @@ -20,12 +20,14 @@ class OnlineStream { | @@ -20,12 +20,14 @@ class OnlineStream { | ||
| 20 | ~OnlineStream(); | 20 | ~OnlineStream(); |
| 21 | 21 | ||
| 22 | /** | 22 | /** |
| 23 | - @param sampling_rate The sampling_rate of the input waveform. Should match | ||
| 24 | - the one expected by the feature extractor. | ||
| 25 | - @param waveform Pointer to a 1-D array of size n | 23 | + @param sampling_rate The sampling_rate of the input waveform. If it is |
| 24 | + not equal to config.sampling_rate, we will do | ||
| 25 | + resampling inside. | ||
| 26 | + @param waveform Pointer to a 1-D array of size n. It must be normalized to | ||
| 27 | + the range [-1, 1]. | ||
| 26 | @param n Number of entries in waveform | 28 | @param n Number of entries in waveform |
| 27 | */ | 29 | */ |
| 28 | - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); | 30 | + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); |
| 29 | 31 | ||
| 30 | /** | 32 | /** |
| 31 | * InputFinished() tells the class you won't be providing any | 33 | * InputFinished() tells the class you won't be providing any |
| @@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | @@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | ||
| 76 | std::vector<int64_t> blanks(context_size, blank_id); | 76 | std::vector<int64_t> blanks(context_size, blank_id); |
| 77 | Hypotheses blank_hyp({{blanks, 0}}); | 77 | Hypotheses blank_hyp({{blanks, 0}}); |
| 78 | r.hyps = std::move(blank_hyp); | 78 | r.hyps = std::move(blank_hyp); |
| 79 | + r.tokens = std::move(blanks); | ||
| 79 | return r; | 80 | return r; |
| 80 | } | 81 | } |
| 81 | 82 |
| @@ -8,13 +8,27 @@ | @@ -8,13 +8,27 @@ | ||
| 8 | 8 | ||
| 9 | namespace sherpa_onnx { | 9 | namespace sherpa_onnx { |
| 10 | 10 | ||
| 11 | +constexpr const char *kAcceptWaveformUsage = R"( | ||
| 12 | +Process audio samples. | ||
| 13 | + | ||
| 14 | +Args: | ||
| 15 | + sample_rate: | ||
| 16 | + Sample rate of the input samples. If it is different from the one | ||
| 17 | + expected by the model, we will do resampling inside. | ||
| 18 | + waveform: | ||
| 19 | + A 1-D float32 tensor containing audio samples. It must be normalized | ||
| 20 | + to the range [-1, 1]. | ||
| 21 | +)"; | ||
| 22 | + | ||
| 11 | void PybindOnlineStream(py::module *m) { | 23 | void PybindOnlineStream(py::module *m) { |
| 12 | using PyClass = OnlineStream; | 24 | using PyClass = OnlineStream; |
| 13 | py::class_<PyClass>(*m, "OnlineStream") | 25 | py::class_<PyClass>(*m, "OnlineStream") |
| 14 | - .def("accept_waveform", | ||
| 15 | - [](PyClass &self, float sample_rate, py::array_t<float> waveform) { | ||
| 16 | - self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); | ||
| 17 | - }) | 26 | + .def( |
| 27 | + "accept_waveform", | ||
| 28 | + [](PyClass &self, float sample_rate, py::array_t<float> waveform) { | ||
| 29 | + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); | ||
| 30 | + }, | ||
| 31 | + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) | ||
| 18 | .def("input_finished", &PyClass::InputFinished); | 32 | .def("input_finished", &PyClass::InputFinished); |
| 19 | } | 33 | } |
| 20 | 34 |
-
请 注册 或 登录 后发表评论