正在显示
20 个修改的文件
包含
45 行增加
和
57 行删除
| @@ -26,8 +26,7 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { | @@ -26,8 +26,7 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { | ||
| 26 | po->Register("feat-dim", &feature_dim, | 26 | po->Register("feat-dim", &feature_dim, |
| 27 | "Feature dimension. Must match the one expected by the model."); | 27 | "Feature dimension. Must match the one expected by the model."); |
| 28 | 28 | ||
| 29 | - po->Register("low-freq", &low_freq, | ||
| 30 | - "Low cutoff frequency for mel bins"); | 29 | + po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins"); |
| 31 | 30 | ||
| 32 | po->Register("high-freq", &high_freq, | 31 | po->Register("high-freq", &high_freq, |
| 33 | "High cutoff frequency for mel bins " | 32 | "High cutoff frequency for mel bins " |
| @@ -67,7 +66,7 @@ class FeatureExtractor::Impl { | @@ -67,7 +66,7 @@ class FeatureExtractor::Impl { | ||
| 67 | opts_.mel_opts.num_bins = config.feature_dim; | 66 | opts_.mel_opts.num_bins = config.feature_dim; |
| 68 | 67 | ||
| 69 | opts_.mel_opts.high_freq = config.high_freq; | 68 | opts_.mel_opts.high_freq = config.high_freq; |
| 70 | - opts_.mel_opts.low_freq = config.low_freq; | 69 | + opts_.mel_opts.low_freq = config.low_freq; |
| 71 | 70 | ||
| 72 | opts_.mel_opts.is_librosa = config.is_librosa; | 71 | opts_.mel_opts.is_librosa = config.is_librosa; |
| 73 | 72 |
| @@ -15,7 +15,7 @@ void OfflineLMConfig::Register(ParseOptions *po) { | @@ -15,7 +15,7 @@ void OfflineLMConfig::Register(ParseOptions *po) { | ||
| 15 | po->Register("lm", &model, "Path to LM model."); | 15 | po->Register("lm", &model, "Path to LM model."); |
| 16 | po->Register("lm-scale", &scale, "LM scale."); | 16 | po->Register("lm-scale", &scale, "LM scale."); |
| 17 | po->Register("lm-num-threads", &lm_num_threads, | 17 | po->Register("lm-num-threads", &lm_num_threads, |
| 18 | - "Number of threads to run the neural network of LM model"); | 18 | + "Number of threads to run the neural network of LM model"); |
| 19 | po->Register("lm-provider", &lm_provider, | 19 | po->Register("lm-provider", &lm_provider, |
| 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); | 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); |
| 21 | } | 21 | } |
| @@ -80,9 +80,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -80,9 +80,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 80 | InitHotwords(); | 80 | InitHotwords(); |
| 81 | } | 81 | } |
| 82 | if (config_.decoding_method == "greedy_search") { | 82 | if (config_.decoding_method == "greedy_search") { |
| 83 | - decoder_ = | ||
| 84 | - std::make_unique<OfflineTransducerGreedySearchDecoder>( | ||
| 85 | - model_.get(), config_.blank_penalty); | 83 | + decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 84 | + model_.get(), config_.blank_penalty); | ||
| 86 | } else if (config_.decoding_method == "modified_beam_search") { | 85 | } else if (config_.decoding_method == "modified_beam_search") { |
| 87 | if (!config_.lm_config.model.empty()) { | 86 | if (!config_.lm_config.model.empty()) { |
| 88 | lm_ = OfflineLM::Create(config.lm_config); | 87 | lm_ = OfflineLM::Create(config.lm_config); |
| @@ -106,9 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -106,9 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 106 | model_(std::make_unique<OfflineTransducerModel>(mgr, | 105 | model_(std::make_unique<OfflineTransducerModel>(mgr, |
| 107 | config_.model_config)) { | 106 | config_.model_config)) { |
| 108 | if (config_.decoding_method == "greedy_search") { | 107 | if (config_.decoding_method == "greedy_search") { |
| 109 | - decoder_ = | ||
| 110 | - std::make_unique<OfflineTransducerGreedySearchDecoder>( | ||
| 111 | - model_.get(), config_.blank_penalty); | 108 | + decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>( |
| 109 | + model_.get(), config_.blank_penalty); | ||
| 112 | } else if (config_.decoding_method == "modified_beam_search") { | 110 | } else if (config_.decoding_method == "modified_beam_search") { |
| 113 | if (!config_.lm_config.model.empty()) { | 111 | if (!config_.lm_config.model.empty()) { |
| 114 | lm_ = OfflineLM::Create(mgr, config.lm_config); | 112 | lm_ = OfflineLM::Create(mgr, config.lm_config); |
| @@ -16,8 +16,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | @@ -16,8 +16,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { | ||
| 16 | public: | 16 | public: |
| 17 | explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, | 17 | explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, |
| 18 | float blank_penalty) | 18 | float blank_penalty) |
| 19 | - : model_(model), | ||
| 20 | - blank_penalty_(blank_penalty) {} | 19 | + : model_(model), blank_penalty_(blank_penalty) {} |
| 21 | 20 | ||
| 22 | std::vector<OfflineTransducerDecoderResult> Decode( | 21 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 23 | Ort::Value encoder_out, Ort::Value encoder_out_length, | 22 | Ort::Value encoder_out, Ort::Value encoder_out_length, |
| @@ -102,9 +102,9 @@ void OfflineWebsocketDecoder::Decode() { | @@ -102,9 +102,9 @@ void OfflineWebsocketDecoder::Decode() { | ||
| 102 | asio::post(server_->GetConnectionContext(), | 102 | asio::post(server_->GetConnectionContext(), |
| 103 | [this, hdl, result = ss[i]->GetResult()]() { | 103 | [this, hdl, result = ss[i]->GetResult()]() { |
| 104 | websocketpp::lib::error_code ec; | 104 | websocketpp::lib::error_code ec; |
| 105 | - server_->GetServer().send( | ||
| 106 | - hdl, result.AsJsonString(), | ||
| 107 | - websocketpp::frame::opcode::text, ec); | 105 | + server_->GetServer().send(hdl, result.AsJsonString(), |
| 106 | + websocketpp::frame::opcode::text, | ||
| 107 | + ec); | ||
| 108 | if (ec) { | 108 | if (ec) { |
| 109 | server_->GetServer().get_alog().write( | 109 | server_->GetServer().get_alog().write( |
| 110 | websocketpp::log::alevel::app, ec.message()); | 110 | websocketpp::log::alevel::app, ec.message()); |
| @@ -15,7 +15,7 @@ void OnlineLMConfig::Register(ParseOptions *po) { | @@ -15,7 +15,7 @@ void OnlineLMConfig::Register(ParseOptions *po) { | ||
| 15 | po->Register("lm", &model, "Path to LM model."); | 15 | po->Register("lm", &model, "Path to LM model."); |
| 16 | po->Register("lm-scale", &scale, "LM scale."); | 16 | po->Register("lm-scale", &scale, "LM scale."); |
| 17 | po->Register("lm-num-threads", &lm_num_threads, | 17 | po->Register("lm-num-threads", &lm_num_threads, |
| 18 | - "Number of threads to run the neural network of LM model"); | 18 | + "Number of threads to run the neural network of LM model"); |
| 19 | po->Register("lm-provider", &lm_provider, | 19 | po->Register("lm-provider", &lm_provider, |
| 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); | 20 | "Specify a provider to LM model use: cpu, cuda, coreml"); |
| 21 | } | 21 | } |
| @@ -22,7 +22,7 @@ struct OnlineLMConfig { | @@ -22,7 +22,7 @@ struct OnlineLMConfig { | ||
| 22 | OnlineLMConfig() = default; | 22 | OnlineLMConfig() = default; |
| 23 | 23 | ||
| 24 | OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, | 24 | OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, |
| 25 | - const std::string &lm_provider) | 25 | + const std::string &lm_provider) |
| 26 | : model(model), | 26 | : model(model), |
| 27 | scale(scale), | 27 | scale(scale), |
| 28 | lm_num_threads(lm_num_threads), | 28 | lm_num_threads(lm_num_threads), |
| @@ -40,8 +40,7 @@ struct OnlineModelConfig { | @@ -40,8 +40,7 @@ struct OnlineModelConfig { | ||
| 40 | const OnlineWenetCtcModelConfig &wenet_ctc, | 40 | const OnlineWenetCtcModelConfig &wenet_ctc, |
| 41 | const OnlineZipformer2CtcModelConfig &zipformer2_ctc, | 41 | const OnlineZipformer2CtcModelConfig &zipformer2_ctc, |
| 42 | const std::string &tokens, int32_t num_threads, | 42 | const std::string &tokens, int32_t num_threads, |
| 43 | - int32_t warm_up, bool debug, | ||
| 44 | - const std::string &provider, | 43 | + int32_t warm_up, bool debug, const std::string &provider, |
| 45 | const std::string &model_type) | 44 | const std::string &model_type) |
| 46 | : transducer(transducer), | 45 | : transducer(transducer), |
| 47 | paraformer(paraformer), | 46 | paraformer(paraformer), |
| @@ -30,9 +30,9 @@ | @@ -30,9 +30,9 @@ | ||
| 30 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" | 30 | #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" |
| 31 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 31 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| 32 | #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" | 32 | #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" |
| 33 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 33 | #include "sherpa-onnx/csrc/symbol-table.h" | 34 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 34 | #include "sherpa-onnx/csrc/utils.h" | 35 | #include "sherpa-onnx/csrc/utils.h" |
| 35 | -#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 36 | 36 | ||
| 37 | namespace sherpa_onnx { | 37 | namespace sherpa_onnx { |
| 38 | 38 | ||
| @@ -185,7 +185,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -185,7 +185,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 185 | } | 185 | } |
| 186 | 186 | ||
| 187 | // Warmping up engine with wp: warm_up count and max-batch-size | 187 | // Warmping up engine with wp: warm_up count and max-batch-size |
| 188 | - void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const { | 188 | + void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const override { |
| 189 | auto max_batch_size = mbs; | 189 | auto max_batch_size = mbs; |
| 190 | if (warmup <= 0 || warmup > 100) { | 190 | if (warmup <= 0 || warmup > 100) { |
| 191 | return; | 191 | return; |
| @@ -210,8 +210,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | @@ -210,8 +210,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { | ||
| 210 | for (int32_t i = 0; i != warmup; ++i) { | 210 | for (int32_t i = 0; i != warmup; ++i) { |
| 211 | auto states = model_->StackStates(states_vec); | 211 | auto states = model_->StackStates(states_vec); |
| 212 | Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), | 212 | Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), |
| 213 | - features_vec.size(), x_shape.data(), | ||
| 214 | - x_shape.size()); | 213 | + features_vec.size(), |
| 214 | + x_shape.data(), x_shape.size()); | ||
| 215 | auto x_copy = Clone(model_->Allocator(), &x); | 215 | auto x_copy = Clone(model_->Allocator(), &x); |
| 216 | auto pair = model_->RunEncoder(std::move(x), std::move(states), | 216 | auto pair = model_->RunEncoder(std::move(x), std::move(states), |
| 217 | std::move(x_copy)); | 217 | std::move(x_copy)); |
| @@ -168,7 +168,7 @@ class OnlineRecognizer { | @@ -168,7 +168,7 @@ class OnlineRecognizer { | ||
| 168 | * | 168 | * |
| 169 | * @param warmup Number of warmups. | 169 | * @param warmup Number of warmups. |
| 170 | * @param mbs : max-batch-size Max batch size for the models | 170 | * @param mbs : max-batch-size Max batch size for the models |
| 171 | - */ | 171 | + */ |
| 172 | void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const; | 172 | void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const; |
| 173 | 173 | ||
| 174 | /** Decode multiple streams in parallel | 174 | /** Decode multiple streams in parallel |
| @@ -12,8 +12,8 @@ | @@ -12,8 +12,8 @@ | ||
| 12 | #include "onnxruntime_cxx_api.h" // NOLINT | 12 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 13 | #include "sherpa-onnx/csrc/macros.h" | 13 | #include "sherpa-onnx/csrc/macros.h" |
| 14 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 15 | -#include "sherpa-onnx/csrc/text-utils.h" | ||
| 16 | #include "sherpa-onnx/csrc/session.h" | 15 | #include "sherpa-onnx/csrc/session.h" |
| 16 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 17 | 17 | ||
| 18 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| 19 | 19 | ||
| @@ -42,10 +42,9 @@ class OnlineRnnLM::Impl { | @@ -42,10 +42,9 @@ class OnlineRnnLM::Impl { | ||
| 42 | // nn_lm_scores | 42 | // nn_lm_scores |
| 43 | std::array<int64_t, 2> x_shape{1, 1}; | 43 | std::array<int64_t, 2> x_shape{1, 1}; |
| 44 | Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), | 44 | Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), |
| 45 | - x_shape.size()); | 45 | + x_shape.size()); |
| 46 | *x.GetTensorMutableData<int64_t>() = hyp->ys.back(); | 46 | *x.GetTensorMutableData<int64_t>() = hyp->ys.back(); |
| 47 | - auto lm_out = | ||
| 48 | - ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); | 47 | + auto lm_out = ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); |
| 49 | hyp->nn_lm_scores.value = std::move(lm_out.first); | 48 | hyp->nn_lm_scores.value = std::move(lm_out.first); |
| 50 | hyp->nn_lm_states = Convert(std::move(lm_out.second)); | 49 | hyp->nn_lm_states = Convert(std::move(lm_out.second)); |
| 51 | } | 50 | } |
| @@ -71,11 +71,9 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( | @@ -71,11 +71,9 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( | ||
| 71 | r->tokens = std::vector<int64_t>(start, end); | 71 | r->tokens = std::vector<int64_t>(start, end); |
| 72 | } | 72 | } |
| 73 | 73 | ||
| 74 | - | ||
| 75 | void OnlineTransducerGreedySearchDecoder::Decode( | 74 | void OnlineTransducerGreedySearchDecoder::Decode( |
| 76 | Ort::Value encoder_out, | 75 | Ort::Value encoder_out, |
| 77 | std::vector<OnlineTransducerDecoderResult> *result) { | 76 | std::vector<OnlineTransducerDecoderResult> *result) { |
| 78 | - | ||
| 79 | std::vector<int64_t> encoder_out_shape = | 77 | std::vector<int64_t> encoder_out_shape = |
| 80 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 78 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 81 | 79 | ||
| @@ -106,7 +104,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -106,7 +104,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 106 | r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 104 | r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 107 | decoder_out_shape[0] = batch_size; | 105 | decoder_out_shape[0] = batch_size; |
| 108 | decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), | 106 | decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), |
| 109 | - decoder_out_shape.data(), decoder_out_shape.size()); | 107 | + decoder_out_shape.data(), |
| 108 | + decoder_out_shape.size()); | ||
| 110 | UseCachedDecoderOut(*result, &decoder_out); | 109 | UseCachedDecoderOut(*result, &decoder_out); |
| 111 | } else { | 110 | } else { |
| 112 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); | 111 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); |
| @@ -116,8 +115,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -116,8 +115,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 116 | for (int32_t t = 0; t != num_frames; ++t) { | 115 | for (int32_t t = 0; t != num_frames; ++t) { |
| 117 | Ort::Value cur_encoder_out = | 116 | Ort::Value cur_encoder_out = |
| 118 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); | 117 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); |
| 119 | - Ort::Value logit = model_->RunJoiner( | ||
| 120 | - std::move(cur_encoder_out), View(&decoder_out)); | 118 | + Ort::Value logit = |
| 119 | + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); | ||
| 121 | 120 | ||
| 122 | float *p_logit = logit.GetTensorMutableData<float>(); | 121 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 123 | 122 | ||
| @@ -145,9 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -145,9 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 145 | 144 | ||
| 146 | // export the per-token log scores | 145 | // export the per-token log scores |
| 147 | if (y != 0 && y != unk_id_) { | 146 | if (y != 0 && y != unk_id_) { |
| 148 | - LogSoftmax(p_logit, vocab_size); // renormalize probabilities, | ||
| 149 | - // save time by doing it only for | ||
| 150 | - // emitted symbols | 147 | + LogSoftmax(p_logit, vocab_size); // renormalize probabilities, |
| 148 | + // save time by doing it only for | ||
| 149 | + // emitted symbols | ||
| 151 | const float *p_logprob = p_logit; // rename p_logit as p_logprob, | 150 | const float *p_logprob = p_logit; // rename p_logit as p_logprob, |
| 152 | // now it contains normalized | 151 | // now it contains normalized |
| 153 | // probability | 152 | // probability |
| @@ -15,8 +15,7 @@ namespace sherpa_onnx { | @@ -15,8 +15,7 @@ namespace sherpa_onnx { | ||
| 15 | class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | 15 | class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { |
| 16 | public: | 16 | public: |
| 17 | OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, | 17 | OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, |
| 18 | - int32_t unk_id, | ||
| 19 | - float blank_penalty) | 18 | + int32_t unk_id, float blank_penalty) |
| 20 | : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} | 19 | : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} |
| 21 | 20 | ||
| 22 | OnlineTransducerDecoderResult GetEmptyResult() const override; | 21 | OnlineTransducerDecoderResult GetEmptyResult() const override; |
| @@ -69,7 +69,7 @@ class OnlineTransducerModel { | @@ -69,7 +69,7 @@ class OnlineTransducerModel { | ||
| 69 | * This has to be called before GetEncoderInitStates(), so the `encoder_embed` | 69 | * This has to be called before GetEncoderInitStates(), so the `encoder_embed` |
| 70 | * init state has the correct `embed_dim` of its output. | 70 | * init state has the correct `embed_dim` of its output. |
| 71 | */ | 71 | */ |
| 72 | - virtual void SetFeatureDim(int32_t feature_dim) { } | 72 | + virtual void SetFeatureDim(int32_t feature_dim) {} |
| 73 | 73 | ||
| 74 | /** Run the encoder. | 74 | /** Run the encoder. |
| 75 | * | 75 | * |
| @@ -188,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -188,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 188 | // score of the transducer | 188 | // score of the transducer |
| 189 | // export the per-token log scores | 189 | // export the per-token log scores |
| 190 | if (new_token != 0 && new_token != unk_id_) { | 190 | if (new_token != 0 && new_token != unk_id_) { |
| 191 | - const Hypothesis& prev_i = prev[hyp_index]; | 191 | + const Hypothesis &prev_i = prev[hyp_index]; |
| 192 | // subtract 'prev[i]' path scores, which were added before | 192 | // subtract 'prev[i]' path scores, which were added before |
| 193 | // getting topk tokens | 193 | // getting topk tokens |
| 194 | float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; | 194 | float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; |
| @@ -16,10 +16,10 @@ TEST(Stack, Test1DTensors) { | @@ -16,10 +16,10 @@ TEST(Stack, Test1DTensors) { | ||
| 16 | std::array<int64_t, 1> b_shape{3}; | 16 | std::array<int64_t, 1> b_shape{3}; |
| 17 | 17 | ||
| 18 | Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(), | 18 | Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(), |
| 19 | - a_shape.size()); | 19 | + a_shape.size()); |
| 20 | 20 | ||
| 21 | Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(), | 21 | Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(), |
| 22 | - b_shape.size()); | 22 | + b_shape.size()); |
| 23 | float *pa = a.GetTensorMutableData<float>(); | 23 | float *pa = a.GetTensorMutableData<float>(); |
| 24 | float *pb = b.GetTensorMutableData<float>(); | 24 | float *pb = b.GetTensorMutableData<float>(); |
| 25 | for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) { | 25 | for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) { |
| @@ -51,11 +51,11 @@ TEST(Stack, Test2DTensorsDim0) { | @@ -51,11 +51,11 @@ TEST(Stack, Test2DTensorsDim0) { | ||
| 51 | std::array<int64_t, 2> a_shape{2, 3}; | 51 | std::array<int64_t, 2> a_shape{2, 3}; |
| 52 | std::array<int64_t, 2> b_shape{2, 3}; | 52 | std::array<int64_t, 2> b_shape{2, 3}; |
| 53 | 53 | ||
| 54 | - Ort::Value a = Ort::Value::CreateTensor<float>( | ||
| 55 | - allocator, a_shape.data(), a_shape.size()); | 54 | + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(), |
| 55 | + a_shape.size()); | ||
| 56 | 56 | ||
| 57 | - Ort::Value b = Ort::Value::CreateTensor<float>( | ||
| 58 | - allocator, b_shape.data(), b_shape.size()); | 57 | + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(), |
| 58 | + b_shape.size()); | ||
| 59 | 59 | ||
| 60 | float *pa = a.GetTensorMutableData<float>(); | 60 | float *pa = a.GetTensorMutableData<float>(); |
| 61 | float *pb = b.GetTensorMutableData<float>(); | 61 | float *pb = b.GetTensorMutableData<float>(); |
| @@ -12,10 +12,8 @@ static void PybindFeatureExtractorConfig(py::module *m) { | @@ -12,10 +12,8 @@ static void PybindFeatureExtractorConfig(py::module *m) { | ||
| 12 | using PyClass = FeatureExtractorConfig; | 12 | using PyClass = FeatureExtractorConfig; |
| 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") | 13 | py::class_<PyClass>(*m, "FeatureExtractorConfig") |
| 14 | .def(py::init<int32_t, int32_t, float, float, float>(), | 14 | .def(py::init<int32_t, int32_t, float, float, float>(), |
| 15 | - py::arg("sampling_rate") = 16000, | ||
| 16 | - py::arg("feature_dim") = 80, | ||
| 17 | - py::arg("low_freq") = 20.0f, | ||
| 18 | - py::arg("high_freq") = -400.0f, | 15 | + py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, |
| 16 | + py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f, | ||
| 19 | py::arg("dither") = 0.0f) | 17 | py::arg("dither") = 0.0f) |
| 20 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) | 18 | .def_readwrite("sampling_rate", &PyClass::sampling_rate) |
| 21 | .def_readwrite("feature_dim", &PyClass::feature_dim) | 19 | .def_readwrite("feature_dim", &PyClass::feature_dim) |
| @@ -23,8 +23,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -23,8 +23,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), | 23 | py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), |
| 24 | py::arg("decoding_method") = "greedy_search", | 24 | py::arg("decoding_method") = "greedy_search", |
| 25 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 25 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 26 | - py::arg("hotwords_score") = 1.5, | ||
| 27 | - py::arg("blank_penalty") = 0.0) | 26 | + py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0) |
| 28 | .def_readwrite("feat_config", &PyClass::feat_config) | 27 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 29 | .def_readwrite("model_config", &PyClass::model_config) | 28 | .def_readwrite("model_config", &PyClass::model_config) |
| 30 | .def_readwrite("lm_config", &PyClass::lm_config) | 29 | .def_readwrite("lm_config", &PyClass::lm_config) |
| @@ -4,7 +4,6 @@ | @@ -4,7 +4,6 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 5 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" |
| 6 | 6 | ||
| 7 | - | ||
| 8 | #include <string> | 7 | #include <string> |
| 9 | #include <vector> | 8 | #include <vector> |
| 10 | 9 | ||
| @@ -16,7 +15,7 @@ void PybindOfflineTransducerModelConfig(py::module *m) { | @@ -16,7 +15,7 @@ void PybindOfflineTransducerModelConfig(py::module *m) { | ||
| 16 | using PyClass = OfflineTransducerModelConfig; | 15 | using PyClass = OfflineTransducerModelConfig; |
| 17 | py::class_<PyClass>(*m, "OfflineTransducerModelConfig") | 16 | py::class_<PyClass>(*m, "OfflineTransducerModelConfig") |
| 18 | .def(py::init<const std::string &, const std::string &, | 17 | .def(py::init<const std::string &, const std::string &, |
| 19 | - const std::string &>(), | 18 | + const std::string &>(), |
| 20 | py::arg("encoder_filename"), py::arg("decoder_filename"), | 19 | py::arg("encoder_filename"), py::arg("decoder_filename"), |
| 21 | py::arg("joiner_filename")) | 20 | py::arg("joiner_filename")) |
| 22 | .def_readwrite("encoder_filename", &PyClass::encoder_filename) | 21 | .def_readwrite("encoder_filename", &PyClass::encoder_filename) |
| @@ -27,9 +27,9 @@ void PybindOnlineModelConfig(py::module *m) { | @@ -27,9 +27,9 @@ void PybindOnlineModelConfig(py::module *m) { | ||
| 27 | .def(py::init<const OnlineTransducerModelConfig &, | 27 | .def(py::init<const OnlineTransducerModelConfig &, |
| 28 | const OnlineParaformerModelConfig &, | 28 | const OnlineParaformerModelConfig &, |
| 29 | const OnlineWenetCtcModelConfig &, | 29 | const OnlineWenetCtcModelConfig &, |
| 30 | - const OnlineZipformer2CtcModelConfig &, | ||
| 31 | - const std::string &, int32_t, int32_t, | ||
| 32 | - bool, const std::string &, const std::string &>(), | 30 | + const OnlineZipformer2CtcModelConfig &, const std::string &, |
| 31 | + int32_t, int32_t, bool, const std::string &, | ||
| 32 | + const std::string &>(), | ||
| 33 | py::arg("transducer") = OnlineTransducerModelConfig(), | 33 | py::arg("transducer") = OnlineTransducerModelConfig(), |
| 34 | py::arg("paraformer") = OnlineParaformerModelConfig(), | 34 | py::arg("paraformer") = OnlineParaformerModelConfig(), |
| 35 | py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), | 35 | py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), |
-
请 注册 或 登录 后发表评论