Fangjun Kuang (committed via GitHub)

Support CoreML for macOS (#151)

@@ -35,8 +35,8 @@ set(sources
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
   online-conformer-transducer-model.cc
-  online-lm.cc
   online-lm-config.cc
+  online-lm.cc
   online-lstm-transducer-model.cc
   online-recognizer.cc
   online-rnn-lm.cc
@@ -48,9 +48,11 @@ set(sources
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
   onnx-utils.cc
+  session.cc
   packed-sequence.cc
   pad-sequence.cc
   parse-options.cc
+  provider.cc
   resample.cc
   slice.cc
   stack.cc
@@ -22,6 +22,9 @@ void OfflineModelConfig::Register(ParseOptions *po) {

   po->Register("debug", &debug,
                "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
 }

 bool OfflineModelConfig::Validate() const {
@@ -55,7 +58,8 @@ std::string OfflineModelConfig::ToString() const {
   os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
-  os << "debug=" << (debug ? "True" : "False") << ")";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "provider=\"" << provider << "\")";

   return os.str();
 }
@@ -20,18 +20,21 @@ struct OfflineModelConfig {
   std::string tokens;
   int32_t num_threads = 2;
   bool debug = false;
+  std::string provider = "cpu";

   OfflineModelConfig() = default;
   OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
                      const OfflineParaformerModelConfig &paraformer,
                      const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
-                     const std::string &tokens, int32_t num_threads, bool debug)
+                     const std::string &tokens, int32_t num_threads, bool debug,
+                     const std::string &provider)
       : transducer(transducer),
         paraformer(paraformer),
         nemo_ctc(nemo_ctc),
         tokens(tokens),
         num_threads(num_threads),
-        debug(debug) {}
+        debug(debug),
+        provider(provider) {}

   void Register(ParseOptions *po);
   bool Validate() const;
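
Not part of the commit: a minimal sketch of how the new flag flows through OfflineModelConfig. It assumes the kaldi-style ParseOptions API (parse-options.cc is listed in the sources above, and Register(ParseOptions *) appears in this commit); the driver program itself is hypothetical.

// hypothetical-driver.cc -- illustration only, not in the commit
#include <cstdio>

#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, const char *argv[]) {
  sherpa_onnx::OfflineModelConfig config;  // provider defaults to "cpu"

  sherpa_onnx::ParseOptions po("");
  config.Register(&po);  // registers --provider alongside the other flags
  po.Read(argc, argv);   // e.g. ./hypothetical-driver --provider=coreml

  // After this commit, ToString() reports the provider as well.
  fprintf(stderr, "%s\n", config.ToString().c_str());
  return 0;
}
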
@@ -6,6 +6,7 @@

 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/transpose.h"

@@ -16,11 +17,8 @@ class OfflineNemoEncDecCtcModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
         allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
-    sess_opts_.SetInterOpNumThreads(config_.num_threads);
-
     Init();
   }

@@ -9,6 +9,7 @@

 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"

 namespace sherpa_onnx {

@@ -18,11 +19,8 @@ class OfflineParaformerModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_ERROR),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
         allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
-    sess_opts_.SetInterOpNumThreads(config_.num_threads);
-
     Init();
   }

@@ -11,6 +11,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"

 namespace sherpa_onnx {

@@ -19,10 +20,8 @@ class OfflineTransducerModel::Impl {
   explicit Impl(const OfflineModelConfig &config)
       : config_(config),
         env_(ORT_LOGGING_LEVEL_WARNING),
-        sess_opts_{},
+        sess_opts_(GetSessionOptions(config)),
         allocator_{} {
-    sess_opts_.SetIntraOpNumThreads(config.num_threads);
-    sess_opts_.SetInterOpNumThreads(config.num_threads);
     {
       auto buf = ReadFile(config.transducer.encoder_filename);
       InitEncoder(buf.data(), buf.size());
@@ -9,7 +9,6 @@
 #include <algorithm>
 #include <memory>
 #include <sstream>
-#include <iostream>
 #include <string>
 #include <utility>
 #include <vector>
@@ -24,6 +23,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/unbind.h"

@@ -33,11 +33,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -59,11 +56,8 @@ OnlineConformerTransducerModel::OnlineConformerTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -209,8 +203,8 @@ std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
   // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
   // for details
   constexpr int32_t kBatchSize = 1;
-  std::array<int64_t, 4> h_shape{
-      num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
+  std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_, kBatchSize,
+                                 encoder_dim_};
   Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
                                                  h_shape.size());

@@ -238,9 +232,7 @@ OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
                                            std::vector<Ort::Value> states,
                                            Ort::Value processed_frames) {
   std::array<Ort::Value, 4> encoder_inputs = {
-      std::move(features),
-      std::move(states[0]),
-      std::move(states[1]),
+      std::move(features), std::move(states[0]), std::move(states[1]),
       std::move(processed_frames)};

   auto encoder_out = encoder_sess_->Run(
@@ -22,6 +22,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/unbind.h"

@@ -30,11 +31,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -56,11 +54,8 @@ OnlineLstmTransducerModel::OnlineLstmTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -9,7 +9,6 @@

 #include <algorithm>
 #include <iomanip>
-#include <iostream>
 #include <memory>
 #include <sstream>
 #include <utility>
@@ -140,7 +139,7 @@ class OnlineRecognizer::Impl {
       decoder_ =
           std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
     } else {
-      fprintf(stderr, "Unsupported decoding method: %s\n",
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
              config.decoding_method.c_str());
      exit(-1);
    }
@@ -160,7 +159,7 @@ class OnlineRecognizer::Impl {
      decoder_ =
          std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
    } else {
-      fprintf(stderr, "Unsupported decoding method: %s\n",
+      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
              config.decoding_method.c_str());
      exit(-1);
    }
@@ -219,16 +218,13 @@ class OnlineRecognizer::Impl {
        static_cast<int64_t>(all_processed_frames.size())};

    Ort::Value processed_frames = Ort::Value::CreateTensor(
-        memory_info,
-        all_processed_frames.data(),
-        all_processed_frames.size(),
-        processed_frames_shape.data(),
-        processed_frames_shape.size());
+        memory_info, all_processed_frames.data(), all_processed_frames.size(),
+        processed_frames_shape.data(), processed_frames_shape.size());

    auto states = model_->StackStates(states_vec);

-    auto pair = model_->RunEncoder(
-        std::move(x), std::move(states), std::move(processed_frames));
+    auto pair = model_->RunEncoder(std::move(x), std::move(states),
+                                   std::move(processed_frames));

    decoder_->Decode(std::move(pair.first), &results);

@@ -17,19 +17,21 @@ struct OnlineTransducerModelConfig {
   std::string tokens;
   int32_t num_threads = 2;
   bool debug = false;
+  std::string provider = "cpu";

   OnlineTransducerModelConfig() = default;
   OnlineTransducerModelConfig(const std::string &encoder_filename,
                               const std::string &decoder_filename,
                               const std::string &joiner_filename,
                               const std::string &tokens, int32_t num_threads,
-                              bool debug)
+                              bool debug, const std::string &provider)
       : encoder_filename(encoder_filename),
         decoder_filename(decoder_filename),
         joiner_filename(joiner_filename),
         tokens(tokens),
         num_threads(num_threads),
-        debug(debug) {}
+        debug(debug),
+        provider(provider) {}

   void Register(ParseOptions *po);
   bool Validate() const;
@@ -10,7 +10,6 @@
 #endif

 #include <algorithm>
-#include <iostream>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -23,6 +23,7 @@
 #include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 #include "sherpa-onnx/csrc/unbind.h"

@@ -32,11 +33,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
     const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -58,11 +56,8 @@ OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
     AAssetManager *mgr, const OnlineTransducerModelConfig &config)
     : env_(ORT_LOGGING_LEVEL_WARNING),
       config_(config),
-      sess_opts_{},
+      sess_opts_(GetSessionOptions(config)),
       allocator_{} {
-  sess_opts_.SetIntraOpNumThreads(config.num_threads);
-  sess_opts_.SetInterOpNumThreads(config.num_threads);
-
   {
     auto buf = ReadFile(mgr, config.encoder_filename);
     InitEncoder(buf.data(), buf.size());
@@ -0,0 +1,29 @@
+// sherpa-onnx/csrc/provider.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/provider.h"
+
+#include <algorithm>
+#include <cctype>
+
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+Provider StringToProvider(std::string s) {
+  std::transform(s.cbegin(), s.cend(), s.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+  if (s == "cpu") {
+    return Provider::kCPU;
+  } else if (s == "cuda") {
+    return Provider::kCUDA;
+  } else if (s == "coreml") {
+    return Provider::kCoreML;
+  } else {
+    SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
+    return Provider::kCPU;
+  }
+}
+
+}  // namespace sherpa_onnx
@@ -0,0 +1,31 @@
+// sherpa-onnx/csrc/provider.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_PROVIDER_H_
+#define SHERPA_ONNX_CSRC_PROVIDER_H_
+
+#include <string>
+
+namespace sherpa_onnx {
+
+// Please refer to
+// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
+// for a list of available providers
+enum class Provider {
+  kCPU = 0,     // CPUExecutionProvider
+  kCUDA = 1,    // CUDAExecutionProvider
+  kCoreML = 2,  // CoreMLExecutionProvider
+};
+
+/**
+ * Convert a string to an enum.
+ *
+ * @param s We will convert it to lowercase before comparing.
+ * @return Return an instance of Provider.
+ */
+Provider StringToProvider(std::string s);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_PROVIDER_H_
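
Not part of the commit: a quick sketch of the behavior StringToProvider guarantees, based only on provider.h and provider.cc above. The main() wrapper is illustrative.

#include <cassert>

#include "sherpa-onnx/csrc/provider.h"

int main() {
  using sherpa_onnx::Provider;
  using sherpa_onnx::StringToProvider;

  // The input is lowercased before comparison, so matching is
  // case-insensitive.
  assert(StringToProvider("CoreML") == Provider::kCoreML);
  assert(StringToProvider("CUDA") == Provider::kCUDA);

  // Unknown names log a warning and fall back to the CPU provider
  // instead of failing.
  assert(StringToProvider("tensorrt") == Provider::kCPU);
  return 0;
}
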
@@ -0,0 +1,60 @@
+// sherpa-onnx/csrc/session.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/session.h"
+
+#include <string>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/provider.h"
+#if defined(__APPLE__)
+#include "coreml_provider_factory.h"  // NOLINT
+#endif
+
+namespace sherpa_onnx {
+
+static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
+                                                 std::string provider_str) {
+  Provider p = StringToProvider(std::move(provider_str));
+
+  Ort::SessionOptions sess_opts;
+  sess_opts.SetIntraOpNumThreads(num_threads);
+  sess_opts.SetInterOpNumThreads(num_threads);
+
+  switch (p) {
+    case Provider::kCPU:
+      break;  // nothing to do for the CPU provider
+    case Provider::kCUDA: {
+      OrtCUDAProviderOptions options;
+      options.device_id = 0;
+      // set more options on need
+      sess_opts.AppendExecutionProvider_CUDA(options);
+      break;
+    }
+    case Provider::kCoreML: {
+#if defined(__APPLE__)
+      uint32_t coreml_flags = 0;
+      (void)OrtSessionOptionsAppendExecutionProvider_CoreML(sess_opts,
+                                                            coreml_flags);
+#else
+      SHERPA_ONNX_LOGE("CoreML is for Apple only. Fallback to cpu!");
+#endif
+      break;
+    }
+  }
+
+  return sess_opts;
+}
+
+Ort::SessionOptions GetSessionOptions(
+    const OnlineTransducerModelConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
+Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
+  return GetSessionOptionsImpl(config.num_threads, config.provider);
+}
+
+}  // namespace sherpa_onnx
@@ -0,0 +1,21 @@
+// sherpa-onnx/csrc/session.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_CSRC_SESSION_H_
+#define SHERPA_ONNX_CSRC_SESSION_H_
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-model-config.h"
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+
+namespace sherpa_onnx {
+
+Ort::SessionOptions GetSessionOptions(
+    const OnlineTransducerModelConfig &config);
+
+Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_SESSION_H_
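
Not part of the commit: together, session.{h,cc} and provider.{h,cc} replace the per-model SetIntraOpNumThreads/SetInterOpNumThreads boilerplate removed in the hunks above. A sketch of the resulting flow, assuming only the config fields shown in this commit; the helper function and model buffer are placeholders.

#include "onnxruntime_cxx_api.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/session.h"

// Hypothetical helper: build an ONNX Runtime session from an in-memory
// model, the way the model classes above now do.
Ort::Session CreateSession(Ort::Env &env, const void *model_data,
                           size_t model_data_length) {
  sherpa_onnx::OnlineTransducerModelConfig config;
  config.num_threads = 2;
  config.provider = "coreml";  // "cpu" and "cuda" are the other choices

  // Sets intra-/inter-op thread counts and appends the requested
  // execution provider; on non-Apple builds "coreml" logs a warning
  // and falls back to the CPU provider.
  Ort::SessionOptions sess_opts = sherpa_onnx::GetSessionOptions(config);
  return Ort::Session(env, model_data, model_data_length, sess_opts);
}
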
@@ -6,7 +6,6 @@

 #include <algorithm>
 #include <functional>
-#include <iostream>
 #include <numeric>
 #include <utility>

@@ -58,21 +57,17 @@ Ort::Value Stack(OrtAllocator *allocator,
   ans_shape.reserve(v0_shape.size() + 1);
   ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
   ans_shape.push_back(values.size());
-  ans_shape.insert(
-      ans_shape.end(),
-      v0_shape.data() + dim,
+  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim,
                    v0_shape.data() + v0_shape.size());

   auto leading_size = static_cast<int32_t>(std::accumulate(
       v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));

-  auto trailing_size = static_cast<int32_t>(
-      std::accumulate(v0_shape.begin() + dim,
-                      v0_shape.end(), 1,
-                      std::multiplies<int64_t>()));
+  auto trailing_size = static_cast<int32_t>(std::accumulate(
+      v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies<int64_t>()));

-  Ort::Value ans = Ort::Value::CreateTensor<T>(
-      allocator, ans_shape.data(), ans_shape.size());
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
+                                               ans_shape.size());
   T *dst = ans.GetTensorMutableData<T>();

   for (int32_t i = 0; i != leading_size; ++i) {
@@ -88,14 +83,12 @@ Ort::Value Stack(OrtAllocator *allocator,
   return ans;
 }

-template Ort::Value Stack<float>(
-    OrtAllocator *allocator,
+template Ort::Value Stack<float>(OrtAllocator *allocator,
     const std::vector<const Ort::Value *> &values,
     int32_t dim);

 template Ort::Value Stack<int64_t>(
-    OrtAllocator *allocator,
-    const std::vector<const Ort::Value *> &values,
+    OrtAllocator *allocator, const std::vector<const Ort::Value *> &values,
     int32_t dim);

 }  // namespace sherpa_onnx
@@ -24,17 +24,19 @@ void PybindOfflineModelConfig(py::module *m) {
       .def(py::init<const OfflineTransducerModelConfig &,
                     const OfflineParaformerModelConfig &,
                     const OfflineNemoEncDecCtcModelConfig &,
-                    const std::string &, int32_t, bool>(),
+                    const std::string &, int32_t, bool, const std::string &>(),
           py::arg("transducer") = OfflineTransducerModelConfig(),
           py::arg("paraformer") = OfflineParaformerModelConfig(),
           py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
+          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
+          py::arg("provider") = "cpu")
      .def_readwrite("transducer", &PyClass::transducer)
      .def_readwrite("paraformer", &PyClass::paraformer)
      .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
      .def_readwrite("tokens", &PyClass::tokens)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
      .def("__str__", &PyClass::ToString);
 }

@@ -14,16 +14,19 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
   using PyClass = OnlineTransducerModelConfig;
   py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
       .def(py::init<const std::string &, const std::string &,
-                    const std::string &, const std::string &, int32_t, bool>(),
+                    const std::string &, const std::string &, int32_t, bool,
+                    const std::string &>(),
           py::arg("encoder_filename"), py::arg("decoder_filename"),
           py::arg("joiner_filename"), py::arg("tokens"),
-          py::arg("num_threads"), py::arg("debug") = false)
+          py::arg("num_threads"), py::arg("debug") = false,
+          py::arg("provider") = "cpu")
      .def_readwrite("encoder_filename", &PyClass::encoder_filename)
      .def_readwrite("decoder_filename", &PyClass::decoder_filename)
      .def_readwrite("joiner_filename", &PyClass::joiner_filename)
      .def_readwrite("tokens", &PyClass::tokens)
      .def_readwrite("num_threads", &PyClass::num_threads)
      .def_readwrite("debug", &PyClass::debug)
+      .def_readwrite("provider", &PyClass::provider)
      .def("__str__", &PyClass::ToString);
 }

@@ -40,6 +40,7 @@ class OfflineRecognizer(object):
         feature_dim: int = 80,
         decoding_method: str = "greedy_search",
         debug: bool = False,
+        provider: str = "cpu",
     ):
         """
         Please refer to
@@ -70,6 +71,8 @@ class OfflineRecognizer(object):
             Support only greedy_search for now.
           debug:
             True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         """
         self = cls.__new__(cls)
         model_config = OfflineModelConfig(
@@ -81,6 +84,7 @@ class OfflineRecognizer(object):
             tokens=tokens,
             num_threads=num_threads,
             debug=debug,
+            provider=provider,
         )

         feat_config = OfflineFeatureExtractorConfig(
@@ -39,6 +39,7 @@ class OnlineRecognizer(object):
         rule3_min_utterance_length: float = 20.0,
         decoding_method: str = "greedy_search",
         max_active_paths: int = 4,
+        provider: str = "cpu",
     ):
         """
         Please refer to
@@ -86,6 +87,8 @@ class OnlineRecognizer(object):
           max_active_paths:
             Use only when decoding_method is modified_beam_search. It specifies
             the maximum number of active paths during beam search.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder)
@@ -100,6 +103,7 @@ class OnlineRecognizer(object):
             joiner_filename=joiner,
             tokens=tokens,
             num_threads=num_threads,
+            provider=provider,
         )

         feat_config = FeatureExtractorConfig(