Showing 20 changed files with 31 additions and 43 deletions
@@ -26,8 +26,7 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
   po->Register("feat-dim", &feature_dim,
                "Feature dimension. Must match the one expected by the model.");
 
-  po->Register("low-freq", &low_freq,
-               "Low cutoff frequency for mel bins");
+  po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
 
   po->Register("high-freq", &high_freq,
                "High cutoff frequency for mel bins "
@@ -80,8 +80,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
       InitHotwords();
     }
     if (config_.decoding_method == "greedy_search") {
-      decoder_ =
-          std::make_unique<OfflineTransducerGreedySearchDecoder>(
+      decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
           model_.get(), config_.blank_penalty);
     } else if (config_.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
@@ -106,8 +105,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
         model_(std::make_unique<OfflineTransducerModel>(mgr,
                                                         config_.model_config)) {
     if (config_.decoding_method == "greedy_search") {
-      decoder_ =
-          std::make_unique<OfflineTransducerGreedySearchDecoder>(
+      decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
           model_.get(), config_.blank_penalty);
     } else if (config_.decoding_method == "modified_beam_search") {
       if (!config_.lm_config.model.empty()) {
@@ -16,8 +16,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
  public:
  explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
                                                float blank_penalty)
-      : model_(model),
-        blank_penalty_(blank_penalty) {}
+      : model_(model), blank_penalty_(blank_penalty) {}
 
  std::vector<OfflineTransducerDecoderResult> Decode(
      Ort::Value encoder_out, Ort::Value encoder_out_length,
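
Note on the hunks above: besides the reflow, they show a blank_penalty being threaded into the offline greedy-search decoder's constructor. As a point of reference, a blank penalty is usually applied by subtracting a constant from the blank token's logit before taking the argmax; the snippet below is a minimal, self-contained sketch of that idea (GreedyStep and the assumption that index 0 is the blank token are illustrative, not code from this repository).

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

// Sketch: apply a blank penalty to one frame of joiner logits, then take the
// argmax, as a greedy-search step typically does.
int32_t GreedyStep(std::vector<float> logits, float blank_penalty) {
  logits[0] -= blank_penalty;  // assumes token 0 is blank
  auto best = std::max_element(logits.begin(), logits.end());
  return static_cast<int32_t>(std::distance(logits.begin(), best));
}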
@@ -102,9 +102,9 @@ void OfflineWebsocketDecoder::Decode() {
     asio::post(server_->GetConnectionContext(),
                [this, hdl, result = ss[i]->GetResult()]() {
                  websocketpp::lib::error_code ec;
-                 server_->GetServer().send(
-                     hdl, result.AsJsonString(),
-                     websocketpp::frame::opcode::text, ec);
+                 server_->GetServer().send(hdl, result.AsJsonString(),
+                                           websocketpp::frame::opcode::text,
+                                           ec);
                  if (ec) {
                    server_->GetServer().get_alog().write(
                        websocketpp::log::alevel::app, ec.message());
@@ -40,8 +40,7 @@ struct OnlineModelConfig {
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
                     const std::string &tokens, int32_t num_threads,
-                    int32_t warm_up, bool debug,
-                    const std::string &provider,
+                    int32_t warm_up, bool debug, const std::string &provider,
                     const std::string &model_type)
       : transducer(transducer),
         paraformer(paraformer),
@@ -30,9 +30,9 @@
 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/online-transducer-model.h"
 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 #include "sherpa-onnx/csrc/utils.h"
-#include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
 
@@ -185,7 +185,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
   }
 
   // Warmping up engine with wp: warm_up count and max-batch-size
-  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const override {
     auto max_batch_size = mbs;
     if (warmup <= 0 || warmup > 100) {
       return;
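
The only semantic change in this hunk is the added override specifier on WarmpUpRecognizer. A short self-contained sketch of what that buys (Base, Good, and Bad are hypothetical names): with override, any mismatch with the base-class virtual signature becomes a compile error instead of silently declaring a new, unrelated function.

#include <cstdint>

struct Base {
  virtual void WarmUp(int32_t warmup, int32_t mbs) const {}
  virtual ~Base() = default;
};

struct Good : Base {
  // Matches the base signature (including const), so `override` compiles.
  void WarmUp(int32_t warmup, int32_t mbs) const override {}
};

struct Bad : Base {
  // Missing `const`: this compiles but does NOT override Base::WarmUp.
  // Adding `override` here would turn the mistake into a compile error.
  void WarmUp(int32_t warmup, int32_t mbs) {}
};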
@@ -210,8 +210,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
     for (int32_t i = 0; i != warmup; ++i) {
       auto states = model_->StackStates(states_vec);
       Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
-                                              features_vec.size(), x_shape.data(),
-                                              x_shape.size());
+                                              features_vec.size(),
+                                              x_shape.data(), x_shape.size());
       auto x_copy = Clone(model_->Allocator(), &x);
       auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                      std::move(x_copy));
@@ -12,8 +12,8 @@
 #include "onnxruntime_cxx_api.h"  // NOLINT
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
-#include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
 
 namespace sherpa_onnx {
 
@@ -44,8 +44,7 @@ class OnlineRnnLM::Impl {
     Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
                                                      x_shape.size());
     *x.GetTensorMutableData<int64_t>() = hyp->ys.back();
-    auto lm_out =
-        ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
+    auto lm_out = ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
     hyp->nn_lm_scores.value = std::move(lm_out.first);
     hyp->nn_lm_states = Convert(std::move(lm_out.second));
   }
@@ -71,11 +71,9 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
   r->tokens = std::vector<int64_t>(start, end);
 }
 
-
 void OnlineTransducerGreedySearchDecoder::Decode(
     Ort::Value encoder_out,
     std::vector<OnlineTransducerDecoderResult> *result) {
-
   std::vector<int64_t> encoder_out_shape =
       encoder_out.GetTensorTypeAndShapeInfo().GetShape();
 
@@ -106,7 +104,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
         r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
     decoder_out_shape[0] = batch_size;
     decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
-        decoder_out_shape.data(), decoder_out_shape.size());
+                                                  decoder_out_shape.data(),
+                                                  decoder_out_shape.size());
     UseCachedDecoderOut(*result, &decoder_out);
   } else {
     Ort::Value decoder_input = model_->BuildDecoderInput(*result);
@@ -116,8 +115,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
   for (int32_t t = 0; t != num_frames; ++t) {
     Ort::Value cur_encoder_out =
         GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
-    Ort::Value logit = model_->RunJoiner(
-        std::move(cur_encoder_out), View(&decoder_out));
+    Ort::Value logit =
+        model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
 
     float *p_logit = logit.GetTensorMutableData<float>();
 
@@ -15,8 +15,7 @@ namespace sherpa_onnx {
 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
  public:
   OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
-                                      int32_t unk_id,
-                                      float blank_penalty)
+                                      int32_t unk_id, float blank_penalty)
       : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
 
   OnlineTransducerDecoderResult GetEmptyResult() const override;
@@ -69,7 +69,7 @@ class OnlineTransducerModel {
    * This has to be called before GetEncoderInitStates(), so the `encoder_embed`
    * init state has the correct `embed_dim` of its output.
    */
-  virtual void SetFeatureDim(int32_t feature_dim) { }
+  virtual void SetFeatureDim(int32_t feature_dim) {}
 
   /** Run the encoder.
    *
@@ -188,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
         // score of the transducer
         // export the per-token log scores
         if (new_token != 0 && new_token != unk_id_) {
-          const Hypothesis& prev_i = prev[hyp_index];
+          const Hypothesis &prev_i = prev[hyp_index];
           // subtract 'prev[i]' path scores, which were added before
           // getting topk tokens
           float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
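
For context, the surrounding lines recover the per-token transducer score by subtracting the previous hypothesis' accumulated acoustic and LM scores from the combined path score; the one-character change itself only moves the & next to the variable name. A tiny numeric illustration of that subtraction (all values are made up):

#include <cstdio>

int main() {
  float prev_log_prob = -3.2f;     // accumulated transducer score of the prefix
  float prev_lm_log_prob = -1.1f;  // accumulated LM score of the prefix
  float p_logprob_k = -4.9f;       // combined path score after the new token
  float y_prob = p_logprob_k - prev_log_prob - prev_lm_log_prob;
  std::printf("per-token log score: %.2f\n", y_prob);  // prints -0.60
  return 0;
}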
@@ -51,11 +51,11 @@ TEST(Stack, Test2DTensorsDim0) {
   std::array<int64_t, 2> a_shape{2, 3};
   std::array<int64_t, 2> b_shape{2, 3};
 
-  Ort::Value a = Ort::Value::CreateTensor<float>(
-      allocator, a_shape.data(), a_shape.size());
+  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
+                                                 a_shape.size());
 
-  Ort::Value b = Ort::Value::CreateTensor<float>(
-      allocator, b_shape.data(), b_shape.size());
+  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
+                                                 b_shape.size());
 
   float *pa = a.GetTensorMutableData<float>();
   float *pb = b.GetTensorMutableData<float>();
@@ -12,10 +12,8 @@ static void PybindFeatureExtractorConfig(py::module *m) {
   using PyClass = FeatureExtractorConfig;
   py::class_<PyClass>(*m, "FeatureExtractorConfig")
       .def(py::init<int32_t, int32_t, float, float, float>(),
-           py::arg("sampling_rate") = 16000,
-           py::arg("feature_dim") = 80,
-           py::arg("low_freq") = 20.0f,
-           py::arg("high_freq") = -400.0f,
+           py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
+           py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f,
            py::arg("dither") = 0.0f)
       .def_readwrite("sampling_rate", &PyClass::sampling_rate)
       .def_readwrite("feature_dim", &PyClass::feature_dim)
@@ -23,8 +23,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
            py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
            py::arg("decoding_method") = "greedy_search",
            py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
-           py::arg("hotwords_score") = 1.5,
-           py::arg("blank_penalty") = 0.0)
+           py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("lm_config", &PyClass::lm_config)
@@ -27,9 +27,9 @@ void PybindOnlineModelConfig(py::module *m) {
       .def(py::init<const OnlineTransducerModelConfig &,
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
-                    const OnlineZipformer2CtcModelConfig &,
-                    const std::string &, int32_t, int32_t,
-                    bool, const std::string &, const std::string &>(),
+                    const OnlineZipformer2CtcModelConfig &, const std::string &,
+                    int32_t, int32_t, bool, const std::string &,
+                    const std::string &>(),
            py::arg("transducer") = OnlineTransducerModelConfig(),
            py::arg("paraformer") = OnlineParaformerModelConfig(),
            py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
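
This last hunk only reflows the py::init<...> template arguments, but it is a reminder that those template arguments must mirror the bound C++ constructor exactly, with each py::arg attaching to the corresponding positional parameter. Below is a minimal sketch using a simplified, hypothetical DemoModelConfig (not a type from this project):

#include <cstdint>
#include <string>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct DemoModelConfig {
  std::string tokens;
  int32_t num_threads;
  DemoModelConfig(const std::string &tokens, int32_t num_threads)
      : tokens(tokens), num_threads(num_threads) {}
};

PYBIND11_MODULE(demo, m) {
  py::class_<DemoModelConfig>(m, "DemoModelConfig")
      // Template args match the constructor; py::arg binds names and defaults.
      .def(py::init<const std::string &, int32_t>(),
           py::arg("tokens") = "", py::arg("num_threads") = 1)
      .def_readwrite("tokens", &DemoModelConfig::tokens)
      .def_readwrite("num_threads", &DemoModelConfig::num_threads);
}

From Python this could then be constructed as DemoModelConfig(tokens="tokens.txt", num_threads=2), with any omitted arguments falling back to the py::arg defaults.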