正在显示
22 个修改的文件
包含
206 行增加
和
77 行删除
sherpa-onnx/csharp-api/CPPLINT.cfg
已删除
100644 → 0
| 1 | -exclude_files=.* |
| @@ -35,8 +35,8 @@ set(sources | @@ -35,8 +35,8 @@ set(sources | ||
| 35 | offline-transducer-model.cc | 35 | offline-transducer-model.cc |
| 36 | offline-transducer-modified-beam-search-decoder.cc | 36 | offline-transducer-modified-beam-search-decoder.cc |
| 37 | online-conformer-transducer-model.cc | 37 | online-conformer-transducer-model.cc |
| 38 | - online-lm.cc | ||
| 39 | online-lm-config.cc | 38 | online-lm-config.cc |
| 39 | + online-lm.cc | ||
| 40 | online-lstm-transducer-model.cc | 40 | online-lstm-transducer-model.cc |
| 41 | online-recognizer.cc | 41 | online-recognizer.cc |
| 42 | online-rnn-lm.cc | 42 | online-rnn-lm.cc |
| @@ -48,9 +48,11 @@ set(sources | @@ -48,9 +48,11 @@ set(sources | ||
| 48 | online-transducer-modified-beam-search-decoder.cc | 48 | online-transducer-modified-beam-search-decoder.cc |
| 49 | online-zipformer-transducer-model.cc | 49 | online-zipformer-transducer-model.cc |
| 50 | onnx-utils.cc | 50 | onnx-utils.cc |
| 51 | + session.cc | ||
| 51 | packed-sequence.cc | 52 | packed-sequence.cc |
| 52 | pad-sequence.cc | 53 | pad-sequence.cc |
| 53 | parse-options.cc | 54 | parse-options.cc |
| 55 | + provider.cc | ||
| 54 | resample.cc | 56 | resample.cc |
| 55 | slice.cc | 57 | slice.cc |
| 56 | stack.cc | 58 | stack.cc |
| @@ -22,6 +22,9 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -22,6 +22,9 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 22 | 22 | ||
| 23 | po->Register("debug", &debug, | 23 | po->Register("debug", &debug, |
| 24 | "true to print model information while loading it."); | 24 | "true to print model information while loading it."); |
| 25 | + | ||
| 26 | + po->Register("provider", &provider, | ||
| 27 | + "Specify a provider to use: cpu, cuda, coreml"); | ||
| 25 | } | 28 | } |
| 26 | 29 | ||
| 27 | bool OfflineModelConfig::Validate() const { | 30 | bool OfflineModelConfig::Validate() const { |
| @@ -55,7 +58,8 @@ std::string OfflineModelConfig::ToString() const { | @@ -55,7 +58,8 @@ std::string OfflineModelConfig::ToString() const { | ||
| 55 | os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; | 58 | os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; |
| 56 | os << "tokens=\"" << tokens << "\", "; | 59 | os << "tokens=\"" << tokens << "\", "; |
| 57 | os << "num_threads=" << num_threads << ", "; | 60 | os << "num_threads=" << num_threads << ", "; |
| 58 | - os << "debug=" << (debug ? "True" : "False") << ")"; | 61 | + os << "debug=" << (debug ? "True" : "False") << ", "; |
| 62 | + os << "provider=\"" << provider << "\")"; | ||
| 59 | 63 | ||
| 60 | return os.str(); | 64 | return os.str(); |
| 61 | } | 65 | } |
| @@ -20,18 +20,21 @@ struct OfflineModelConfig { | @@ -20,18 +20,21 @@ struct OfflineModelConfig { | ||
| 20 | std::string tokens; | 20 | std::string tokens; |
| 21 | int32_t num_threads = 2; | 21 | int32_t num_threads = 2; |
| 22 | bool debug = false; | 22 | bool debug = false; |
| 23 | + std::string provider = "cpu"; | ||
| 23 | 24 | ||
| 24 | OfflineModelConfig() = default; | 25 | OfflineModelConfig() = default; |
| 25 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, | 26 | OfflineModelConfig(const OfflineTransducerModelConfig &transducer, |
| 26 | const OfflineParaformerModelConfig ¶former, | 27 | const OfflineParaformerModelConfig ¶former, |
| 27 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, | 28 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, |
| 28 | - const std::string &tokens, int32_t num_threads, bool debug) | 29 | + const std::string &tokens, int32_t num_threads, bool debug, |
| 30 | + const std::string &provider) | ||
| 29 | : transducer(transducer), | 31 | : transducer(transducer), |
| 30 | paraformer(paraformer), | 32 | paraformer(paraformer), |
| 31 | nemo_ctc(nemo_ctc), | 33 | nemo_ctc(nemo_ctc), |
| 32 | tokens(tokens), | 34 | tokens(tokens), |
| 33 | num_threads(num_threads), | 35 | num_threads(num_threads), |
| 34 | - debug(debug) {} | 36 | + debug(debug), |
| 37 | + provider(provider) {} | ||
| 35 | 38 | ||
| 36 | void Register(ParseOptions *po); | 39 | void Register(ParseOptions *po); |
| 37 | bool Validate() const; | 40 | bool Validate() const; |
| @@ -6,6 +6,7 @@ | @@ -6,6 +6,7 @@ | ||
| 6 | 6 | ||
| 7 | #include "sherpa-onnx/csrc/macros.h" | 7 | #include "sherpa-onnx/csrc/macros.h" |
| 8 | #include "sherpa-onnx/csrc/onnx-utils.h" | 8 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 9 | +#include "sherpa-onnx/csrc/session.h" | ||
| 9 | #include "sherpa-onnx/csrc/text-utils.h" | 10 | #include "sherpa-onnx/csrc/text-utils.h" |
| 10 | #include "sherpa-onnx/csrc/transpose.h" | 11 | #include "sherpa-onnx/csrc/transpose.h" |
| 11 | 12 | ||
| @@ -16,11 +17,8 @@ class OfflineNemoEncDecCtcModel::Impl { | @@ -16,11 +17,8 @@ class OfflineNemoEncDecCtcModel::Impl { | ||
| 16 | explicit Impl(const OfflineModelConfig &config) | 17 | explicit Impl(const OfflineModelConfig &config) |
| 17 | : config_(config), | 18 | : config_(config), |
| 18 | env_(ORT_LOGGING_LEVEL_ERROR), | 19 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 19 | - sess_opts_{}, | 20 | + sess_opts_(GetSessionOptions(config)), |
| 20 | allocator_{} { | 21 | allocator_{} { |
| 21 | - sess_opts_.SetIntraOpNumThreads(config_.num_threads); | ||
| 22 | - sess_opts_.SetInterOpNumThreads(config_.num_threads); | ||
| 23 | - | ||
| 24 | Init(); | 22 | Init(); |
| 25 | } | 23 | } |
| 26 | 24 |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/macros.h" | 10 | #include "sherpa-onnx/csrc/macros.h" |
| 11 | #include "sherpa-onnx/csrc/onnx-utils.h" | 11 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 12 | +#include "sherpa-onnx/csrc/session.h" | ||
| 12 | #include "sherpa-onnx/csrc/text-utils.h" | 13 | #include "sherpa-onnx/csrc/text-utils.h" |
| 13 | 14 | ||
| 14 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| @@ -18,11 +19,8 @@ class OfflineParaformerModel::Impl { | @@ -18,11 +19,8 @@ class OfflineParaformerModel::Impl { | ||
| 18 | explicit Impl(const OfflineModelConfig &config) | 19 | explicit Impl(const OfflineModelConfig &config) |
| 19 | : config_(config), | 20 | : config_(config), |
| 20 | env_(ORT_LOGGING_LEVEL_ERROR), | 21 | env_(ORT_LOGGING_LEVEL_ERROR), |
| 21 | - sess_opts_{}, | 22 | + sess_opts_(GetSessionOptions(config)), |
| 22 | allocator_{} { | 23 | allocator_{} { |
| 23 | - sess_opts_.SetIntraOpNumThreads(config_.num_threads); | ||
| 24 | - sess_opts_.SetInterOpNumThreads(config_.num_threads); | ||
| 25 | - | ||
| 26 | Init(); | 24 | Init(); |
| 27 | } | 25 | } |
| 28 | 26 |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/csrc/macros.h" | 11 | #include "sherpa-onnx/csrc/macros.h" |
| 12 | #include "sherpa-onnx/csrc/offline-transducer-decoder.h" | 12 | #include "sherpa-onnx/csrc/offline-transducer-decoder.h" |
| 13 | #include "sherpa-onnx/csrc/onnx-utils.h" | 13 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 14 | +#include "sherpa-onnx/csrc/session.h" | ||
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| @@ -19,10 +20,8 @@ class OfflineTransducerModel::Impl { | @@ -19,10 +20,8 @@ class OfflineTransducerModel::Impl { | ||
| 19 | explicit Impl(const OfflineModelConfig &config) | 20 | explicit Impl(const OfflineModelConfig &config) |
| 20 | : config_(config), | 21 | : config_(config), |
| 21 | env_(ORT_LOGGING_LEVEL_WARNING), | 22 | env_(ORT_LOGGING_LEVEL_WARNING), |
| 22 | - sess_opts_{}, | 23 | + sess_opts_(GetSessionOptions(config)), |
| 23 | allocator_{} { | 24 | allocator_{} { |
| 24 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 25 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 26 | { | 25 | { |
| 27 | auto buf = ReadFile(config.transducer.encoder_filename); | 26 | auto buf = ReadFile(config.transducer.encoder_filename); |
| 28 | InitEncoder(buf.data(), buf.size()); | 27 | InitEncoder(buf.data(), buf.size()); |
| @@ -9,7 +9,6 @@ | @@ -9,7 +9,6 @@ | ||
| 9 | #include <algorithm> | 9 | #include <algorithm> |
| 10 | #include <memory> | 10 | #include <memory> |
| 11 | #include <sstream> | 11 | #include <sstream> |
| 12 | -#include <iostream> | ||
| 13 | #include <string> | 12 | #include <string> |
| 14 | #include <utility> | 13 | #include <utility> |
| 15 | #include <vector> | 14 | #include <vector> |
| @@ -24,6 +23,7 @@ | @@ -24,6 +23,7 @@ | ||
| 24 | #include "sherpa-onnx/csrc/macros.h" | 23 | #include "sherpa-onnx/csrc/macros.h" |
| 25 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 24 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 26 | #include "sherpa-onnx/csrc/onnx-utils.h" | 25 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 26 | +#include "sherpa-onnx/csrc/session.h" | ||
| 27 | #include "sherpa-onnx/csrc/text-utils.h" | 27 | #include "sherpa-onnx/csrc/text-utils.h" |
| 28 | #include "sherpa-onnx/csrc/unbind.h" | 28 | #include "sherpa-onnx/csrc/unbind.h" |
| 29 | 29 | ||
| @@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( | @@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( | ||
| 33 | const OnlineTransducerModelConfig &config) | 33 | const OnlineTransducerModelConfig &config) |
| 34 | : env_(ORT_LOGGING_LEVEL_WARNING), | 34 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 35 | config_(config), | 35 | config_(config), |
| 36 | - sess_opts_{}, | 36 | + sess_opts_(GetSessionOptions(config)), |
| 37 | allocator_{} { | 37 | allocator_{} { |
| 38 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 39 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 40 | - | ||
| 41 | { | 38 | { |
| 42 | auto buf = ReadFile(config.encoder_filename); | 39 | auto buf = ReadFile(config.encoder_filename); |
| 43 | InitEncoder(buf.data(), buf.size()); | 40 | InitEncoder(buf.data(), buf.size()); |
| @@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( | @@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel( | ||
| 59 | AAssetManager *mgr, const OnlineTransducerModelConfig &config) | 56 | AAssetManager *mgr, const OnlineTransducerModelConfig &config) |
| 60 | : env_(ORT_LOGGING_LEVEL_WARNING), | 57 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 61 | config_(config), | 58 | config_(config), |
| 62 | - sess_opts_{}, | 59 | + sess_opts_(GetSessionOptions(config)), |
| 63 | allocator_{} { | 60 | allocator_{} { |
| 64 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 65 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 66 | - | ||
| 67 | { | 61 | { |
| 68 | auto buf = ReadFile(mgr, config.encoder_filename); | 62 | auto buf = ReadFile(mgr, config.encoder_filename); |
| 69 | InitEncoder(buf.data(), buf.size()); | 63 | InitEncoder(buf.data(), buf.size()); |
| @@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() { | @@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() { | ||
| 209 | // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203 | 203 | // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203 |
| 210 | // for details | 204 | // for details |
| 211 | constexpr int32_t kBatchSize = 1; | 205 | constexpr int32_t kBatchSize = 1; |
| 212 | - std::array<int64_t, 4> h_shape{ | ||
| 213 | - num_encoder_layers_, left_context_, kBatchSize, encoder_dim_}; | 206 | + std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_, kBatchSize, |
| 207 | + encoder_dim_}; | ||
| 214 | Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(), | 208 | Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(), |
| 215 | h_shape.size()); | 209 | h_shape.size()); |
| 216 | 210 | ||
| @@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features, | @@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features, | ||
| 238 | std::vector<Ort::Value> states, | 232 | std::vector<Ort::Value> states, |
| 239 | Ort::Value processed_frames) { | 233 | Ort::Value processed_frames) { |
| 240 | std::array<Ort::Value, 4> encoder_inputs = { | 234 | std::array<Ort::Value, 4> encoder_inputs = { |
| 241 | - std::move(features), | ||
| 242 | - std::move(states[0]), | ||
| 243 | - std::move(states[1]), | 235 | + std::move(features), std::move(states[0]), std::move(states[1]), |
| 244 | std::move(processed_frames)}; | 236 | std::move(processed_frames)}; |
| 245 | 237 | ||
| 246 | auto encoder_out = encoder_sess_->Run( | 238 | auto encoder_out = encoder_sess_->Run( |
| @@ -22,6 +22,7 @@ | @@ -22,6 +22,7 @@ | ||
| 22 | #include "sherpa-onnx/csrc/macros.h" | 22 | #include "sherpa-onnx/csrc/macros.h" |
| 23 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 23 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 24 | #include "sherpa-onnx/csrc/onnx-utils.h" | 24 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 25 | +#include "sherpa-onnx/csrc/session.h" | ||
| 25 | #include "sherpa-onnx/csrc/unbind.h" | 26 | #include "sherpa-onnx/csrc/unbind.h" |
| 26 | 27 | ||
| 27 | namespace sherpa_onnx { | 28 | namespace sherpa_onnx { |
| @@ -30,11 +31,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( | @@ -30,11 +31,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( | ||
| 30 | const OnlineTransducerModelConfig &config) | 31 | const OnlineTransducerModelConfig &config) |
| 31 | : env_(ORT_LOGGING_LEVEL_WARNING), | 32 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 32 | config_(config), | 33 | config_(config), |
| 33 | - sess_opts_{}, | 34 | + sess_opts_(GetSessionOptions(config)), |
| 34 | allocator_{} { | 35 | allocator_{} { |
| 35 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 36 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 37 | - | ||
| 38 | { | 36 | { |
| 39 | auto buf = ReadFile(config.encoder_filename); | 37 | auto buf = ReadFile(config.encoder_filename); |
| 40 | InitEncoder(buf.data(), buf.size()); | 38 | InitEncoder(buf.data(), buf.size()); |
| @@ -56,11 +54,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( | @@ -56,11 +54,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel( | ||
| 56 | AAssetManager *mgr, const OnlineTransducerModelConfig &config) | 54 | AAssetManager *mgr, const OnlineTransducerModelConfig &config) |
| 57 | : env_(ORT_LOGGING_LEVEL_WARNING), | 55 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 58 | config_(config), | 56 | config_(config), |
| 59 | - sess_opts_{}, | 57 | + sess_opts_(GetSessionOptions(config)), |
| 60 | allocator_{} { | 58 | allocator_{} { |
| 61 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 62 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 63 | - | ||
| 64 | { | 59 | { |
| 65 | auto buf = ReadFile(mgr, config.encoder_filename); | 60 | auto buf = ReadFile(mgr, config.encoder_filename); |
| 66 | InitEncoder(buf.data(), buf.size()); | 61 | InitEncoder(buf.data(), buf.size()); |
| @@ -9,7 +9,6 @@ | @@ -9,7 +9,6 @@ | ||
| 9 | 9 | ||
| 10 | #include <algorithm> | 10 | #include <algorithm> |
| 11 | #include <iomanip> | 11 | #include <iomanip> |
| 12 | -#include <iostream> | ||
| 13 | #include <memory> | 12 | #include <memory> |
| 14 | #include <sstream> | 13 | #include <sstream> |
| 15 | #include <utility> | 14 | #include <utility> |
| @@ -140,7 +139,7 @@ class OnlineRecognizer::Impl { | @@ -140,7 +139,7 @@ class OnlineRecognizer::Impl { | ||
| 140 | decoder_ = | 139 | decoder_ = |
| 141 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 140 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); |
| 142 | } else { | 141 | } else { |
| 143 | - fprintf(stderr, "Unsupported decoding method: %s\n", | 142 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 144 | config.decoding_method.c_str()); | 143 | config.decoding_method.c_str()); |
| 145 | exit(-1); | 144 | exit(-1); |
| 146 | } | 145 | } |
| @@ -160,7 +159,7 @@ class OnlineRecognizer::Impl { | @@ -160,7 +159,7 @@ class OnlineRecognizer::Impl { | ||
| 160 | decoder_ = | 159 | decoder_ = |
| 161 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); | 160 | std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get()); |
| 162 | } else { | 161 | } else { |
| 163 | - fprintf(stderr, "Unsupported decoding method: %s\n", | 162 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 164 | config.decoding_method.c_str()); | 163 | config.decoding_method.c_str()); |
| 165 | exit(-1); | 164 | exit(-1); |
| 166 | } | 165 | } |
| @@ -219,16 +218,13 @@ class OnlineRecognizer::Impl { | @@ -219,16 +218,13 @@ class OnlineRecognizer::Impl { | ||
| 219 | static_cast<int64_t>(all_processed_frames.size())}; | 218 | static_cast<int64_t>(all_processed_frames.size())}; |
| 220 | 219 | ||
| 221 | Ort::Value processed_frames = Ort::Value::CreateTensor( | 220 | Ort::Value processed_frames = Ort::Value::CreateTensor( |
| 222 | - memory_info, | ||
| 223 | - all_processed_frames.data(), | ||
| 224 | - all_processed_frames.size(), | ||
| 225 | - processed_frames_shape.data(), | ||
| 226 | - processed_frames_shape.size()); | 221 | + memory_info, all_processed_frames.data(), all_processed_frames.size(), |
| 222 | + processed_frames_shape.data(), processed_frames_shape.size()); | ||
| 227 | 223 | ||
| 228 | auto states = model_->StackStates(states_vec); | 224 | auto states = model_->StackStates(states_vec); |
| 229 | 225 | ||
| 230 | - auto pair = model_->RunEncoder( | ||
| 231 | - std::move(x), std::move(states), std::move(processed_frames)); | 226 | + auto pair = model_->RunEncoder(std::move(x), std::move(states), |
| 227 | + std::move(processed_frames)); | ||
| 232 | 228 | ||
| 233 | decoder_->Decode(std::move(pair.first), &results); | 229 | decoder_->Decode(std::move(pair.first), &results); |
| 234 | 230 |
| @@ -17,19 +17,21 @@ struct OnlineTransducerModelConfig { | @@ -17,19 +17,21 @@ struct OnlineTransducerModelConfig { | ||
| 17 | std::string tokens; | 17 | std::string tokens; |
| 18 | int32_t num_threads = 2; | 18 | int32_t num_threads = 2; |
| 19 | bool debug = false; | 19 | bool debug = false; |
| 20 | + std::string provider = "cpu"; | ||
| 20 | 21 | ||
| 21 | OnlineTransducerModelConfig() = default; | 22 | OnlineTransducerModelConfig() = default; |
| 22 | OnlineTransducerModelConfig(const std::string &encoder_filename, | 23 | OnlineTransducerModelConfig(const std::string &encoder_filename, |
| 23 | const std::string &decoder_filename, | 24 | const std::string &decoder_filename, |
| 24 | const std::string &joiner_filename, | 25 | const std::string &joiner_filename, |
| 25 | const std::string &tokens, int32_t num_threads, | 26 | const std::string &tokens, int32_t num_threads, |
| 26 | - bool debug) | 27 | + bool debug, const std::string &provider) |
| 27 | : encoder_filename(encoder_filename), | 28 | : encoder_filename(encoder_filename), |
| 28 | decoder_filename(decoder_filename), | 29 | decoder_filename(decoder_filename), |
| 29 | joiner_filename(joiner_filename), | 30 | joiner_filename(joiner_filename), |
| 30 | tokens(tokens), | 31 | tokens(tokens), |
| 31 | num_threads(num_threads), | 32 | num_threads(num_threads), |
| 32 | - debug(debug) {} | 33 | + debug(debug), |
| 34 | + provider(provider) {} | ||
| 33 | 35 | ||
| 34 | void Register(ParseOptions *po); | 36 | void Register(ParseOptions *po); |
| 35 | bool Validate() const; | 37 | bool Validate() const; |
| @@ -23,6 +23,7 @@ | @@ -23,6 +23,7 @@ | ||
| 23 | #include "sherpa-onnx/csrc/macros.h" | 23 | #include "sherpa-onnx/csrc/macros.h" |
| 24 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 24 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 25 | #include "sherpa-onnx/csrc/onnx-utils.h" | 25 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 26 | +#include "sherpa-onnx/csrc/session.h" | ||
| 26 | #include "sherpa-onnx/csrc/text-utils.h" | 27 | #include "sherpa-onnx/csrc/text-utils.h" |
| 27 | #include "sherpa-onnx/csrc/unbind.h" | 28 | #include "sherpa-onnx/csrc/unbind.h" |
| 28 | 29 | ||
| @@ -32,11 +33,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | @@ -32,11 +33,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | ||
| 32 | const OnlineTransducerModelConfig &config) | 33 | const OnlineTransducerModelConfig &config) |
| 33 | : env_(ORT_LOGGING_LEVEL_WARNING), | 34 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 34 | config_(config), | 35 | config_(config), |
| 35 | - sess_opts_{}, | 36 | + sess_opts_(GetSessionOptions(config)), |
| 36 | allocator_{} { | 37 | allocator_{} { |
| 37 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 38 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 39 | - | ||
| 40 | { | 38 | { |
| 41 | auto buf = ReadFile(config.encoder_filename); | 39 | auto buf = ReadFile(config.encoder_filename); |
| 42 | InitEncoder(buf.data(), buf.size()); | 40 | InitEncoder(buf.data(), buf.size()); |
| @@ -58,11 +56,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | @@ -58,11 +56,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( | ||
| 58 | AAssetManager *mgr, const OnlineTransducerModelConfig &config) | 56 | AAssetManager *mgr, const OnlineTransducerModelConfig &config) |
| 59 | : env_(ORT_LOGGING_LEVEL_WARNING), | 57 | : env_(ORT_LOGGING_LEVEL_WARNING), |
| 60 | config_(config), | 58 | config_(config), |
| 61 | - sess_opts_{}, | 59 | + sess_opts_(GetSessionOptions(config)), |
| 62 | allocator_{} { | 60 | allocator_{} { |
| 63 | - sess_opts_.SetIntraOpNumThreads(config.num_threads); | ||
| 64 | - sess_opts_.SetInterOpNumThreads(config.num_threads); | ||
| 65 | - | ||
| 66 | { | 61 | { |
| 67 | auto buf = ReadFile(mgr, config.encoder_filename); | 62 | auto buf = ReadFile(mgr, config.encoder_filename); |
| 68 | InitEncoder(buf.data(), buf.size()); | 63 | InitEncoder(buf.data(), buf.size()); |
sherpa-onnx/csrc/provider.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/provider.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/provider.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <cctype> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +Provider StringToProvider(std::string s) { | ||
| 15 | + std::transform(s.cbegin(), s.cend(), s.begin(), | ||
| 16 | + [](unsigned char c) { return std::tolower(c); }); | ||
| 17 | + if (s == "cpu") { | ||
| 18 | + return Provider::kCPU; | ||
| 19 | + } else if (s == "cuda") { | ||
| 20 | + return Provider::kCUDA; | ||
| 21 | + } else if (s == "coreml") { | ||
| 22 | + return Provider::kCoreML; | ||
| 23 | + } else { | ||
| 24 | + SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); | ||
| 25 | + return Provider::kCPU; | ||
| 26 | + } | ||
| 27 | +} | ||
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/provider.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/provider.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_PROVIDER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_PROVIDER_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +// Please refer to | ||
| 13 | +// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java | ||
| 14 | +// for a list of available providers | ||
| 15 | +enum class Provider { | ||
| 16 | + kCPU = 0, // CPUExecutionProvider | ||
| 17 | + kCUDA = 1, // CUDAExecutionProvider | ||
| 18 | + kCoreML = 2, // CoreMLExecutionProvider | ||
| 19 | +}; | ||
| 20 | + | ||
| 21 | +/** | ||
| 22 | + * Convert a string to an enum. | ||
| 23 | + * | ||
| 24 | + * @param s We will convert it to lowercase before comparing. | ||
| 25 | + * @return Return an instance of Provider. | ||
| 26 | + */ | ||
| 27 | +Provider StringToProvider(std::string s); | ||
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_PROVIDER_H_ |
sherpa-onnx/csrc/session.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/session.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/session.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/provider.h" | ||
| 12 | +#if defined(__APPLE__) | ||
| 13 | +#include "coreml_provider_factory.h" // NOLINT | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, | ||
| 19 | + std::string provider_str) { | ||
| 20 | + Provider p = StringToProvider(std::move(provider_str)); | ||
| 21 | + | ||
| 22 | + Ort::SessionOptions sess_opts; | ||
| 23 | + sess_opts.SetIntraOpNumThreads(num_threads); | ||
| 24 | + sess_opts.SetInterOpNumThreads(num_threads); | ||
| 25 | + | ||
| 26 | + switch (p) { | ||
| 27 | + case Provider::kCPU: | ||
| 28 | + break; // nothing to do for the CPU provider | ||
| 29 | + case Provider::kCUDA: { | ||
| 30 | + OrtCUDAProviderOptions options; | ||
| 31 | + options.device_id = 0; | ||
| 32 | + // set more options on need | ||
| 33 | + sess_opts.AppendExecutionProvider_CUDA(options); | ||
| 34 | + break; | ||
| 35 | + } | ||
| 36 | + case Provider::kCoreML: { | ||
| 37 | +#if defined(__APPLE__) | ||
| 38 | + uint32_t coreml_flags = 0; | ||
| 39 | + (void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts, | ||
| 40 | + coreml_flags); | ||
| 41 | +#else | ||
| 42 | + SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!"); | ||
| 43 | +#endif | ||
| 44 | + break; | ||
| 45 | + } | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + return sess_opts; | ||
| 49 | +} | ||
| 50 | + | ||
| 51 | +Ort::SessionOptions GetSessionOptions( | ||
| 52 | + const OnlineTransducerModelConfig &config) { | ||
| 53 | + return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 54 | +} | ||
| 55 | + | ||
| 56 | +Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { | ||
| 57 | + return GetSessionOptionsImpl(config.num_threads, config.provider); | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/session.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/session.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_SESSION_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_SESSION_H_ | ||
| 7 | + | ||
| 8 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 9 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 10 | +#include "sherpa-onnx/csrc/online-transducer-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +Ort::SessionOptions GetSessionOptions( | ||
| 15 | + const OnlineTransducerModelConfig &config); | ||
| 16 | + | ||
| 17 | +Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); | ||
| 18 | + | ||
| 19 | +} // namespace sherpa_onnx | ||
| 20 | + | ||
| 21 | +#endif // SHERPA_ONNX_CSRC_SESSION_H_ |
| @@ -6,7 +6,6 @@ | @@ -6,7 +6,6 @@ | ||
| 6 | 6 | ||
| 7 | #include <algorithm> | 7 | #include <algorithm> |
| 8 | #include <functional> | 8 | #include <functional> |
| 9 | -#include <iostream> | ||
| 10 | #include <numeric> | 9 | #include <numeric> |
| 11 | #include <utility> | 10 | #include <utility> |
| 12 | 11 | ||
| @@ -58,21 +57,17 @@ Ort::Value Stack(OrtAllocator *allocator, | @@ -58,21 +57,17 @@ Ort::Value Stack(OrtAllocator *allocator, | ||
| 58 | ans_shape.reserve(v0_shape.size() + 1); | 57 | ans_shape.reserve(v0_shape.size() + 1); |
| 59 | ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim); | 58 | ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim); |
| 60 | ans_shape.push_back(values.size()); | 59 | ans_shape.push_back(values.size()); |
| 61 | - ans_shape.insert( | ||
| 62 | - ans_shape.end(), | ||
| 63 | - v0_shape.data() + dim, | 60 | + ans_shape.insert(ans_shape.end(), v0_shape.data() + dim, |
| 64 | v0_shape.data() + v0_shape.size()); | 61 | v0_shape.data() + v0_shape.size()); |
| 65 | 62 | ||
| 66 | auto leading_size = static_cast<int32_t>(std::accumulate( | 63 | auto leading_size = static_cast<int32_t>(std::accumulate( |
| 67 | v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>())); | 64 | v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>())); |
| 68 | 65 | ||
| 69 | - auto trailing_size = static_cast<int32_t>( | ||
| 70 | - std::accumulate(v0_shape.begin() + dim, | ||
| 71 | - v0_shape.end(), 1, | ||
| 72 | - std::multiplies<int64_t>())); | 66 | + auto trailing_size = static_cast<int32_t>(std::accumulate( |
| 67 | + v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies<int64_t>())); | ||
| 73 | 68 | ||
| 74 | - Ort::Value ans = Ort::Value::CreateTensor<T>( | ||
| 75 | - allocator, ans_shape.data(), ans_shape.size()); | 69 | + Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(), |
| 70 | + ans_shape.size()); | ||
| 76 | T *dst = ans.GetTensorMutableData<T>(); | 71 | T *dst = ans.GetTensorMutableData<T>(); |
| 77 | 72 | ||
| 78 | for (int32_t i = 0; i != leading_size; ++i) { | 73 | for (int32_t i = 0; i != leading_size; ++i) { |
| @@ -88,14 +83,12 @@ Ort::Value Stack(OrtAllocator *allocator, | @@ -88,14 +83,12 @@ Ort::Value Stack(OrtAllocator *allocator, | ||
| 88 | return ans; | 83 | return ans; |
| 89 | } | 84 | } |
| 90 | 85 | ||
| 91 | -template Ort::Value Stack<float>( | ||
| 92 | - OrtAllocator *allocator, | 86 | +template Ort::Value Stack<float>(OrtAllocator *allocator, |
| 93 | const std::vector<const Ort::Value *> &values, | 87 | const std::vector<const Ort::Value *> &values, |
| 94 | int32_t dim); | 88 | int32_t dim); |
| 95 | 89 | ||
| 96 | template Ort::Value Stack<int64_t>( | 90 | template Ort::Value Stack<int64_t>( |
| 97 | - OrtAllocator *allocator, | ||
| 98 | - const std::vector<const Ort::Value *> &values, | 91 | + OrtAllocator *allocator, const std::vector<const Ort::Value *> &values, |
| 99 | int32_t dim); | 92 | int32_t dim); |
| 100 | 93 | ||
| 101 | } // namespace sherpa_onnx | 94 | } // namespace sherpa_onnx |
| @@ -24,17 +24,19 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -24,17 +24,19 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 24 | .def(py::init<const OfflineTransducerModelConfig &, | 24 | .def(py::init<const OfflineTransducerModelConfig &, |
| 25 | const OfflineParaformerModelConfig &, | 25 | const OfflineParaformerModelConfig &, |
| 26 | const OfflineNemoEncDecCtcModelConfig &, | 26 | const OfflineNemoEncDecCtcModelConfig &, |
| 27 | - const std::string &, int32_t, bool>(), | 27 | + const std::string &, int32_t, bool, const std::string &>(), |
| 28 | py::arg("transducer") = OfflineTransducerModelConfig(), | 28 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| 29 | py::arg("paraformer") = OfflineParaformerModelConfig(), | 29 | py::arg("paraformer") = OfflineParaformerModelConfig(), |
| 30 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | 30 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), |
| 31 | - py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false) | 31 | + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 32 | + py::arg("provider") = "cpu") | ||
| 32 | .def_readwrite("transducer", &PyClass::transducer) | 33 | .def_readwrite("transducer", &PyClass::transducer) |
| 33 | .def_readwrite("paraformer", &PyClass::paraformer) | 34 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 34 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | 35 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) |
| 35 | .def_readwrite("tokens", &PyClass::tokens) | 36 | .def_readwrite("tokens", &PyClass::tokens) |
| 36 | .def_readwrite("num_threads", &PyClass::num_threads) | 37 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 37 | .def_readwrite("debug", &PyClass::debug) | 38 | .def_readwrite("debug", &PyClass::debug) |
| 39 | + .def_readwrite("provider", &PyClass::provider) | ||
| 38 | .def("__str__", &PyClass::ToString); | 40 | .def("__str__", &PyClass::ToString); |
| 39 | } | 41 | } |
| 40 | 42 |
| @@ -14,16 +14,19 @@ void PybindOnlineTransducerModelConfig(py::module *m) { | @@ -14,16 +14,19 @@ void PybindOnlineTransducerModelConfig(py::module *m) { | ||
| 14 | using PyClass = OnlineTransducerModelConfig; | 14 | using PyClass = OnlineTransducerModelConfig; |
| 15 | py::class_<PyClass>(*m, "OnlineTransducerModelConfig") | 15 | py::class_<PyClass>(*m, "OnlineTransducerModelConfig") |
| 16 | .def(py::init<const std::string &, const std::string &, | 16 | .def(py::init<const std::string &, const std::string &, |
| 17 | - const std::string &, const std::string &, int32_t, bool>(), | 17 | + const std::string &, const std::string &, int32_t, bool, |
| 18 | + const std::string &>(), | ||
| 18 | py::arg("encoder_filename"), py::arg("decoder_filename"), | 19 | py::arg("encoder_filename"), py::arg("decoder_filename"), |
| 19 | py::arg("joiner_filename"), py::arg("tokens"), | 20 | py::arg("joiner_filename"), py::arg("tokens"), |
| 20 | - py::arg("num_threads"), py::arg("debug") = false) | 21 | + py::arg("num_threads"), py::arg("debug") = false, |
| 22 | + py::arg("provider") = "cpu") | ||
| 21 | .def_readwrite("encoder_filename", &PyClass::encoder_filename) | 23 | .def_readwrite("encoder_filename", &PyClass::encoder_filename) |
| 22 | .def_readwrite("decoder_filename", &PyClass::decoder_filename) | 24 | .def_readwrite("decoder_filename", &PyClass::decoder_filename) |
| 23 | .def_readwrite("joiner_filename", &PyClass::joiner_filename) | 25 | .def_readwrite("joiner_filename", &PyClass::joiner_filename) |
| 24 | .def_readwrite("tokens", &PyClass::tokens) | 26 | .def_readwrite("tokens", &PyClass::tokens) |
| 25 | .def_readwrite("num_threads", &PyClass::num_threads) | 27 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 26 | .def_readwrite("debug", &PyClass::debug) | 28 | .def_readwrite("debug", &PyClass::debug) |
| 29 | + .def_readwrite("provider", &PyClass::provider) | ||
| 27 | .def("__str__", &PyClass::ToString); | 30 | .def("__str__", &PyClass::ToString); |
| 28 | } | 31 | } |
| 29 | 32 |
| @@ -40,6 +40,7 @@ class OfflineRecognizer(object): | @@ -40,6 +40,7 @@ class OfflineRecognizer(object): | ||
| 40 | feature_dim: int = 80, | 40 | feature_dim: int = 80, |
| 41 | decoding_method: str = "greedy_search", | 41 | decoding_method: str = "greedy_search", |
| 42 | debug: bool = False, | 42 | debug: bool = False, |
| 43 | + provider: str = "cpu", | ||
| 43 | ): | 44 | ): |
| 44 | """ | 45 | """ |
| 45 | Please refer to | 46 | Please refer to |
| @@ -70,6 +71,8 @@ class OfflineRecognizer(object): | @@ -70,6 +71,8 @@ class OfflineRecognizer(object): | ||
| 70 | Support only greedy_search for now. | 71 | Support only greedy_search for now. |
| 71 | debug: | 72 | debug: |
| 72 | True to show debug messages. | 73 | True to show debug messages. |
| 74 | + provider: | ||
| 75 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 73 | """ | 76 | """ |
| 74 | self = cls.__new__(cls) | 77 | self = cls.__new__(cls) |
| 75 | model_config = OfflineModelConfig( | 78 | model_config = OfflineModelConfig( |
| @@ -81,6 +84,7 @@ class OfflineRecognizer(object): | @@ -81,6 +84,7 @@ class OfflineRecognizer(object): | ||
| 81 | tokens=tokens, | 84 | tokens=tokens, |
| 82 | num_threads=num_threads, | 85 | num_threads=num_threads, |
| 83 | debug=debug, | 86 | debug=debug, |
| 87 | + provider=provider, | ||
| 84 | ) | 88 | ) |
| 85 | 89 | ||
| 86 | feat_config = OfflineFeatureExtractorConfig( | 90 | feat_config = OfflineFeatureExtractorConfig( |
| @@ -39,6 +39,7 @@ class OnlineRecognizer(object): | @@ -39,6 +39,7 @@ class OnlineRecognizer(object): | ||
| 39 | rule3_min_utterance_length: float = 20.0, | 39 | rule3_min_utterance_length: float = 20.0, |
| 40 | decoding_method: str = "greedy_search", | 40 | decoding_method: str = "greedy_search", |
| 41 | max_active_paths: int = 4, | 41 | max_active_paths: int = 4, |
| 42 | + provider: str = "cpu", | ||
| 42 | ): | 43 | ): |
| 43 | """ | 44 | """ |
| 44 | Please refer to | 45 | Please refer to |
| @@ -86,6 +87,8 @@ class OnlineRecognizer(object): | @@ -86,6 +87,8 @@ class OnlineRecognizer(object): | ||
| 86 | max_active_paths: | 87 | max_active_paths: |
| 87 | Use only when decoding_method is modified_beam_search. It specifies | 88 | Use only when decoding_method is modified_beam_search. It specifies |
| 88 | the maximum number of active paths during beam search. | 89 | the maximum number of active paths during beam search. |
| 90 | + provider: | ||
| 91 | + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | ||
| 89 | """ | 92 | """ |
| 90 | _assert_file_exists(tokens) | 93 | _assert_file_exists(tokens) |
| 91 | _assert_file_exists(encoder) | 94 | _assert_file_exists(encoder) |
| @@ -100,6 +103,7 @@ class OnlineRecognizer(object): | @@ -100,6 +103,7 @@ class OnlineRecognizer(object): | ||
| 100 | joiner_filename=joiner, | 103 | joiner_filename=joiner, |
| 101 | tokens=tokens, | 104 | tokens=tokens, |
| 102 | num_threads=num_threads, | 105 | num_threads=num_threads, |
| 106 | + provider=provider, | ||
| 103 | ) | 107 | ) |
| 104 | 108 | ||
| 105 | feat_config = FeatureExtractorConfig( | 109 | feat_config = FeatureExtractorConfig( |
-
请 注册 或 登录 后发表评论