Fangjun Kuang
Committed by GitHub

Fix code style issues (#774)

@@ -26,8 +26,7 @@ void FeatureExtractorConfig::Register(ParseOptions *po) { @@ -26,8 +26,7 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
26 po->Register("feat-dim", &feature_dim, 26 po->Register("feat-dim", &feature_dim,
27 "Feature dimension. Must match the one expected by the model."); 27 "Feature dimension. Must match the one expected by the model.");
28 28
29 - po->Register("low-freq", &low_freq,  
30 - "Low cutoff frequency for mel bins"); 29 + po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
31 30
32 po->Register("high-freq", &high_freq, 31 po->Register("high-freq", &high_freq,
33 "High cutoff frequency for mel bins " 32 "High cutoff frequency for mel bins "
@@ -67,7 +66,7 @@ class FeatureExtractor::Impl { @@ -67,7 +66,7 @@ class FeatureExtractor::Impl {
67 opts_.mel_opts.num_bins = config.feature_dim; 66 opts_.mel_opts.num_bins = config.feature_dim;
68 67
69 opts_.mel_opts.high_freq = config.high_freq; 68 opts_.mel_opts.high_freq = config.high_freq;
70 - opts_.mel_opts.low_freq = config.low_freq; 69 + opts_.mel_opts.low_freq = config.low_freq;
71 70
72 opts_.mel_opts.is_librosa = config.is_librosa; 71 opts_.mel_opts.is_librosa = config.is_librosa;
73 72
@@ -15,7 +15,7 @@ void OfflineLMConfig::Register(ParseOptions *po) { @@ -15,7 +15,7 @@ void OfflineLMConfig::Register(ParseOptions *po) {
15 po->Register("lm", &model, "Path to LM model."); 15 po->Register("lm", &model, "Path to LM model.");
16 po->Register("lm-scale", &scale, "LM scale."); 16 po->Register("lm-scale", &scale, "LM scale.");
17 po->Register("lm-num-threads", &lm_num_threads, 17 po->Register("lm-num-threads", &lm_num_threads,
18 - "Number of threads to run the neural network of LM model"); 18 + "Number of threads to run the neural network of LM model");
19 po->Register("lm-provider", &lm_provider, 19 po->Register("lm-provider", &lm_provider,
20 "Specify a provider to LM model use: cpu, cuda, coreml"); 20 "Specify a provider to LM model use: cpu, cuda, coreml");
21 } 21 }
@@ -80,9 +80,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -80,9 +80,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
80 InitHotwords(); 80 InitHotwords();
81 } 81 }
82 if (config_.decoding_method == "greedy_search") { 82 if (config_.decoding_method == "greedy_search") {
83 - decoder_ =  
84 - std::make_unique<OfflineTransducerGreedySearchDecoder>(  
85 - model_.get(), config_.blank_penalty); 83 + decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
  84 + model_.get(), config_.blank_penalty);
86 } else if (config_.decoding_method == "modified_beam_search") { 85 } else if (config_.decoding_method == "modified_beam_search") {
87 if (!config_.lm_config.model.empty()) { 86 if (!config_.lm_config.model.empty()) {
88 lm_ = OfflineLM::Create(config.lm_config); 87 lm_ = OfflineLM::Create(config.lm_config);
@@ -106,9 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -106,9 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
106 model_(std::make_unique<OfflineTransducerModel>(mgr, 105 model_(std::make_unique<OfflineTransducerModel>(mgr,
107 config_.model_config)) { 106 config_.model_config)) {
108 if (config_.decoding_method == "greedy_search") { 107 if (config_.decoding_method == "greedy_search") {
109 - decoder_ =  
110 - std::make_unique<OfflineTransducerGreedySearchDecoder>(  
111 - model_.get(), config_.blank_penalty); 108 + decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
  109 + model_.get(), config_.blank_penalty);
112 } else if (config_.decoding_method == "modified_beam_search") { 110 } else if (config_.decoding_method == "modified_beam_search") {
113 if (!config_.lm_config.model.empty()) { 111 if (!config_.lm_config.model.empty()) {
114 lm_ = OfflineLM::Create(mgr, config.lm_config); 112 lm_ = OfflineLM::Create(mgr, config.lm_config);
@@ -16,8 +16,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { @@ -16,8 +16,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
16 public: 16 public:
17 explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, 17 explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
18 float blank_penalty) 18 float blank_penalty)
19 - : model_(model),  
20 - blank_penalty_(blank_penalty) {} 19 + : model_(model), blank_penalty_(blank_penalty) {}
21 20
22 std::vector<OfflineTransducerDecoderResult> Decode( 21 std::vector<OfflineTransducerDecoderResult> Decode(
23 Ort::Value encoder_out, Ort::Value encoder_out_length, 22 Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -102,9 +102,9 @@ void OfflineWebsocketDecoder::Decode() { @@ -102,9 +102,9 @@ void OfflineWebsocketDecoder::Decode() {
102 asio::post(server_->GetConnectionContext(), 102 asio::post(server_->GetConnectionContext(),
103 [this, hdl, result = ss[i]->GetResult()]() { 103 [this, hdl, result = ss[i]->GetResult()]() {
104 websocketpp::lib::error_code ec; 104 websocketpp::lib::error_code ec;
105 - server_->GetServer().send(  
106 - hdl, result.AsJsonString(),  
107 - websocketpp::frame::opcode::text, ec); 105 + server_->GetServer().send(hdl, result.AsJsonString(),
  106 + websocketpp::frame::opcode::text,
  107 + ec);
108 if (ec) { 108 if (ec) {
109 server_->GetServer().get_alog().write( 109 server_->GetServer().get_alog().write(
110 websocketpp::log::alevel::app, ec.message()); 110 websocketpp::log::alevel::app, ec.message());
@@ -15,7 +15,7 @@ void OnlineLMConfig::Register(ParseOptions *po) { @@ -15,7 +15,7 @@ void OnlineLMConfig::Register(ParseOptions *po) {
15 po->Register("lm", &model, "Path to LM model."); 15 po->Register("lm", &model, "Path to LM model.");
16 po->Register("lm-scale", &scale, "LM scale."); 16 po->Register("lm-scale", &scale, "LM scale.");
17 po->Register("lm-num-threads", &lm_num_threads, 17 po->Register("lm-num-threads", &lm_num_threads,
18 - "Number of threads to run the neural network of LM model"); 18 + "Number of threads to run the neural network of LM model");
19 po->Register("lm-provider", &lm_provider, 19 po->Register("lm-provider", &lm_provider,
20 "Specify a provider to LM model use: cpu, cuda, coreml"); 20 "Specify a provider to LM model use: cpu, cuda, coreml");
21 } 21 }
@@ -22,7 +22,7 @@ struct OnlineLMConfig { @@ -22,7 +22,7 @@ struct OnlineLMConfig {
22 OnlineLMConfig() = default; 22 OnlineLMConfig() = default;
23 23
24 OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, 24 OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
25 - const std::string &lm_provider) 25 + const std::string &lm_provider)
26 : model(model), 26 : model(model),
27 scale(scale), 27 scale(scale),
28 lm_num_threads(lm_num_threads), 28 lm_num_threads(lm_num_threads),
@@ -40,8 +40,7 @@ struct OnlineModelConfig { @@ -40,8 +40,7 @@ struct OnlineModelConfig {
40 const OnlineWenetCtcModelConfig &wenet_ctc, 40 const OnlineWenetCtcModelConfig &wenet_ctc,
41 const OnlineZipformer2CtcModelConfig &zipformer2_ctc, 41 const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
42 const std::string &tokens, int32_t num_threads, 42 const std::string &tokens, int32_t num_threads,
43 - int32_t warm_up, bool debug,  
44 - const std::string &provider, 43 + int32_t warm_up, bool debug, const std::string &provider,
45 const std::string &model_type) 44 const std::string &model_type)
46 : transducer(transducer), 45 : transducer(transducer),
47 paraformer(paraformer), 46 paraformer(paraformer),
@@ -30,9 +30,9 @@ @@ -30,9 +30,9 @@
30 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" 30 #include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
31 #include "sherpa-onnx/csrc/online-transducer-model.h" 31 #include "sherpa-onnx/csrc/online-transducer-model.h"
32 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h" 32 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
  33 +#include "sherpa-onnx/csrc/onnx-utils.h"
33 #include "sherpa-onnx/csrc/symbol-table.h" 34 #include "sherpa-onnx/csrc/symbol-table.h"
34 #include "sherpa-onnx/csrc/utils.h" 35 #include "sherpa-onnx/csrc/utils.h"
35 -#include "sherpa-onnx/csrc/onnx-utils.h"  
36 36
37 namespace sherpa_onnx { 37 namespace sherpa_onnx {
38 38
@@ -185,7 +185,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -185,7 +185,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
185 } 185 }
186 186
187 // Warmping up engine with wp: warm_up count and max-batch-size 187 // Warmping up engine with wp: warm_up count and max-batch-size
188 - void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const { 188 + void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const override {
189 auto max_batch_size = mbs; 189 auto max_batch_size = mbs;
190 if (warmup <= 0 || warmup > 100) { 190 if (warmup <= 0 || warmup > 100) {
191 return; 191 return;
@@ -210,8 +210,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -210,8 +210,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
210 for (int32_t i = 0; i != warmup; ++i) { 210 for (int32_t i = 0; i != warmup; ++i) {
211 auto states = model_->StackStates(states_vec); 211 auto states = model_->StackStates(states_vec);
212 Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(), 212 Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
213 - features_vec.size(), x_shape.data(),  
214 - x_shape.size()); 213 + features_vec.size(),
  214 + x_shape.data(), x_shape.size());
215 auto x_copy = Clone(model_->Allocator(), &x); 215 auto x_copy = Clone(model_->Allocator(), &x);
216 auto pair = model_->RunEncoder(std::move(x), std::move(states), 216 auto pair = model_->RunEncoder(std::move(x), std::move(states),
217 std::move(x_copy)); 217 std::move(x_copy));
@@ -168,7 +168,7 @@ class OnlineRecognizer { @@ -168,7 +168,7 @@ class OnlineRecognizer {
168 * 168 *
169 * @param warmup Number of warmups. 169 * @param warmup Number of warmups.
170 * @param mbs : max-batch-size Max batch size for the models 170 * @param mbs : max-batch-size Max batch size for the models
171 - */ 171 + */
172 void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const; 172 void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
173 173
174 /** Decode multiple streams in parallel 174 /** Decode multiple streams in parallel
@@ -12,8 +12,8 @@ @@ -12,8 +12,8 @@
12 #include "onnxruntime_cxx_api.h" // NOLINT 12 #include "onnxruntime_cxx_api.h" // NOLINT
13 #include "sherpa-onnx/csrc/macros.h" 13 #include "sherpa-onnx/csrc/macros.h"
14 #include "sherpa-onnx/csrc/onnx-utils.h" 14 #include "sherpa-onnx/csrc/onnx-utils.h"
15 -#include "sherpa-onnx/csrc/text-utils.h"  
16 #include "sherpa-onnx/csrc/session.h" 15 #include "sherpa-onnx/csrc/session.h"
  16 +#include "sherpa-onnx/csrc/text-utils.h"
17 17
18 namespace sherpa_onnx { 18 namespace sherpa_onnx {
19 19
@@ -42,10 +42,9 @@ class OnlineRnnLM::Impl { @@ -42,10 +42,9 @@ class OnlineRnnLM::Impl {
42 // nn_lm_scores 42 // nn_lm_scores
43 std::array<int64_t, 2> x_shape{1, 1}; 43 std::array<int64_t, 2> x_shape{1, 1};
44 Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), 44 Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
45 - x_shape.size()); 45 + x_shape.size());
46 *x.GetTensorMutableData<int64_t>() = hyp->ys.back(); 46 *x.GetTensorMutableData<int64_t>() = hyp->ys.back();
47 - auto lm_out =  
48 - ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); 47 + auto lm_out = ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
49 hyp->nn_lm_scores.value = std::move(lm_out.first); 48 hyp->nn_lm_scores.value = std::move(lm_out.first);
50 hyp->nn_lm_states = Convert(std::move(lm_out.second)); 49 hyp->nn_lm_states = Convert(std::move(lm_out.second));
51 } 50 }
@@ -71,11 +71,9 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( @@ -71,11 +71,9 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
71 r->tokens = std::vector<int64_t>(start, end); 71 r->tokens = std::vector<int64_t>(start, end);
72 } 72 }
73 73
74 -  
75 void OnlineTransducerGreedySearchDecoder::Decode( 74 void OnlineTransducerGreedySearchDecoder::Decode(
76 Ort::Value encoder_out, 75 Ort::Value encoder_out,
77 std::vector<OnlineTransducerDecoderResult> *result) { 76 std::vector<OnlineTransducerDecoderResult> *result) {
78 -  
79 std::vector<int64_t> encoder_out_shape = 77 std::vector<int64_t> encoder_out_shape =
80 encoder_out.GetTensorTypeAndShapeInfo().GetShape(); 78 encoder_out.GetTensorTypeAndShapeInfo().GetShape();
81 79
@@ -106,7 +104,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -106,7 +104,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
106 r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); 104 r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
107 decoder_out_shape[0] = batch_size; 105 decoder_out_shape[0] = batch_size;
108 decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), 106 decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
109 - decoder_out_shape.data(), decoder_out_shape.size()); 107 + decoder_out_shape.data(),
  108 + decoder_out_shape.size());
110 UseCachedDecoderOut(*result, &decoder_out); 109 UseCachedDecoderOut(*result, &decoder_out);
111 } else { 110 } else {
112 Ort::Value decoder_input = model_->BuildDecoderInput(*result); 111 Ort::Value decoder_input = model_->BuildDecoderInput(*result);
@@ -116,8 +115,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -116,8 +115,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
116 for (int32_t t = 0; t != num_frames; ++t) { 115 for (int32_t t = 0; t != num_frames; ++t) {
117 Ort::Value cur_encoder_out = 116 Ort::Value cur_encoder_out =
118 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); 117 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
119 - Ort::Value logit = model_->RunJoiner(  
120 - std::move(cur_encoder_out), View(&decoder_out)); 118 + Ort::Value logit =
  119 + model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
121 120
122 float *p_logit = logit.GetTensorMutableData<float>(); 121 float *p_logit = logit.GetTensorMutableData<float>();
123 122
@@ -145,9 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -145,9 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
145 144
146 // export the per-token log scores 145 // export the per-token log scores
147 if (y != 0 && y != unk_id_) { 146 if (y != 0 && y != unk_id_) {
148 - LogSoftmax(p_logit, vocab_size); // renormalize probabilities,  
149 - // save time by doing it only for  
150 - // emitted symbols 147 + LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
  148 + // save time by doing it only for
  149 + // emitted symbols
151 const float *p_logprob = p_logit; // rename p_logit as p_logprob, 150 const float *p_logprob = p_logit; // rename p_logit as p_logprob,
152 // now it contains normalized 151 // now it contains normalized
153 // probability 152 // probability
@@ -15,8 +15,7 @@ namespace sherpa_onnx { @@ -15,8 +15,7 @@ namespace sherpa_onnx {
15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { 15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
16 public: 16 public:
17 OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, 17 OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
18 - int32_t unk_id,  
19 - float blank_penalty) 18 + int32_t unk_id, float blank_penalty)
20 : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} 19 : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
21 20
22 OnlineTransducerDecoderResult GetEmptyResult() const override; 21 OnlineTransducerDecoderResult GetEmptyResult() const override;
@@ -69,7 +69,7 @@ class OnlineTransducerModel { @@ -69,7 +69,7 @@ class OnlineTransducerModel {
69 * This has to be called before GetEncoderInitStates(), so the `encoder_embed` 69 * This has to be called before GetEncoderInitStates(), so the `encoder_embed`
70 * init state has the correct `embed_dim` of its output. 70 * init state has the correct `embed_dim` of its output.
71 */ 71 */
72 - virtual void SetFeatureDim(int32_t feature_dim) { } 72 + virtual void SetFeatureDim(int32_t feature_dim) {}
73 73
74 /** Run the encoder. 74 /** Run the encoder.
75 * 75 *
@@ -188,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -188,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
188 // score of the transducer 188 // score of the transducer
189 // export the per-token log scores 189 // export the per-token log scores
190 if (new_token != 0 && new_token != unk_id_) { 190 if (new_token != 0 && new_token != unk_id_) {
191 - const Hypothesis& prev_i = prev[hyp_index]; 191 + const Hypothesis &prev_i = prev[hyp_index];
192 // subtract 'prev[i]' path scores, which were added before 192 // subtract 'prev[i]' path scores, which were added before
193 // getting topk tokens 193 // getting topk tokens
194 float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; 194 float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
@@ -16,10 +16,10 @@ TEST(Stack, Test1DTensors) { @@ -16,10 +16,10 @@ TEST(Stack, Test1DTensors) {
16 std::array<int64_t, 1> b_shape{3}; 16 std::array<int64_t, 1> b_shape{3};
17 17
18 Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(), 18 Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
19 - a_shape.size()); 19 + a_shape.size());
20 20
21 Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(), 21 Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
22 - b_shape.size()); 22 + b_shape.size());
23 float *pa = a.GetTensorMutableData<float>(); 23 float *pa = a.GetTensorMutableData<float>();
24 float *pb = b.GetTensorMutableData<float>(); 24 float *pb = b.GetTensorMutableData<float>();
25 for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) { 25 for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
@@ -51,11 +51,11 @@ TEST(Stack, Test2DTensorsDim0) { @@ -51,11 +51,11 @@ TEST(Stack, Test2DTensorsDim0) {
51 std::array<int64_t, 2> a_shape{2, 3}; 51 std::array<int64_t, 2> a_shape{2, 3};
52 std::array<int64_t, 2> b_shape{2, 3}; 52 std::array<int64_t, 2> b_shape{2, 3};
53 53
54 - Ort::Value a = Ort::Value::CreateTensor<float>(  
55 - allocator, a_shape.data(), a_shape.size()); 54 + Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
  55 + a_shape.size());
56 56
57 - Ort::Value b = Ort::Value::CreateTensor<float>(  
58 - allocator, b_shape.data(), b_shape.size()); 57 + Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
  58 + b_shape.size());
59 59
60 float *pa = a.GetTensorMutableData<float>(); 60 float *pa = a.GetTensorMutableData<float>();
61 float *pb = b.GetTensorMutableData<float>(); 61 float *pb = b.GetTensorMutableData<float>();
@@ -12,10 +12,8 @@ static void PybindFeatureExtractorConfig(py::module *m) { @@ -12,10 +12,8 @@ static void PybindFeatureExtractorConfig(py::module *m) {
12 using PyClass = FeatureExtractorConfig; 12 using PyClass = FeatureExtractorConfig;
13 py::class_<PyClass>(*m, "FeatureExtractorConfig") 13 py::class_<PyClass>(*m, "FeatureExtractorConfig")
14 .def(py::init<int32_t, int32_t, float, float, float>(), 14 .def(py::init<int32_t, int32_t, float, float, float>(),
15 - py::arg("sampling_rate") = 16000,  
16 - py::arg("feature_dim") = 80,  
17 - py::arg("low_freq") = 20.0f,  
18 - py::arg("high_freq") = -400.0f, 15 + py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
  16 + py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f,
19 py::arg("dither") = 0.0f) 17 py::arg("dither") = 0.0f)
20 .def_readwrite("sampling_rate", &PyClass::sampling_rate) 18 .def_readwrite("sampling_rate", &PyClass::sampling_rate)
21 .def_readwrite("feature_dim", &PyClass::feature_dim) 19 .def_readwrite("feature_dim", &PyClass::feature_dim)
@@ -23,8 +23,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -23,8 +23,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), 23 py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
24 py::arg("decoding_method") = "greedy_search", 24 py::arg("decoding_method") = "greedy_search",
25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
26 - py::arg("hotwords_score") = 1.5,  
27 - py::arg("blank_penalty") = 0.0) 26 + py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0)
28 .def_readwrite("feat_config", &PyClass::feat_config) 27 .def_readwrite("feat_config", &PyClass::feat_config)
29 .def_readwrite("model_config", &PyClass::model_config) 28 .def_readwrite("model_config", &PyClass::model_config)
30 .def_readwrite("lm_config", &PyClass::lm_config) 29 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -4,7 +4,6 @@ @@ -4,7 +4,6 @@
4 4
5 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 5 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
6 6
7 -  
8 #include <string> 7 #include <string>
9 #include <vector> 8 #include <vector>
10 9
@@ -16,7 +15,7 @@ void PybindOfflineTransducerModelConfig(py::module *m) { @@ -16,7 +15,7 @@ void PybindOfflineTransducerModelConfig(py::module *m) {
16 using PyClass = OfflineTransducerModelConfig; 15 using PyClass = OfflineTransducerModelConfig;
17 py::class_<PyClass>(*m, "OfflineTransducerModelConfig") 16 py::class_<PyClass>(*m, "OfflineTransducerModelConfig")
18 .def(py::init<const std::string &, const std::string &, 17 .def(py::init<const std::string &, const std::string &,
19 - const std::string &>(), 18 + const std::string &>(),
20 py::arg("encoder_filename"), py::arg("decoder_filename"), 19 py::arg("encoder_filename"), py::arg("decoder_filename"),
21 py::arg("joiner_filename")) 20 py::arg("joiner_filename"))
22 .def_readwrite("encoder_filename", &PyClass::encoder_filename) 21 .def_readwrite("encoder_filename", &PyClass::encoder_filename)
@@ -27,9 +27,9 @@ void PybindOnlineModelConfig(py::module *m) { @@ -27,9 +27,9 @@ void PybindOnlineModelConfig(py::module *m) {
27 .def(py::init<const OnlineTransducerModelConfig &, 27 .def(py::init<const OnlineTransducerModelConfig &,
28 const OnlineParaformerModelConfig &, 28 const OnlineParaformerModelConfig &,
29 const OnlineWenetCtcModelConfig &, 29 const OnlineWenetCtcModelConfig &,
30 - const OnlineZipformer2CtcModelConfig &,  
31 - const std::string &, int32_t, int32_t,  
32 - bool, const std::string &, const std::string &>(), 30 + const OnlineZipformer2CtcModelConfig &, const std::string &,
  31 + int32_t, int32_t, bool, const std::string &,
  32 + const std::string &>(),
33 py::arg("transducer") = OnlineTransducerModelConfig(), 33 py::arg("transducer") = OnlineTransducerModelConfig(),
34 py::arg("paraformer") = OnlineParaformerModelConfig(), 34 py::arg("paraformer") = OnlineParaformerModelConfig(),
35 py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), 35 py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),