Fangjun Kuang
Committed by GitHub

Refactor online recognizer (#250)

* Refactor online recognizer.

Make it easier to support other streaming models.

Note that it is a breaking change for the Python API.
`sherpa_onnx.OnlineRecognizer()` used before should be
replaced by `sherpa_onnx.OnlineRecognizer.from_transducer()`.
Showing 40 changed files with 670 additions and 480 deletions.
@@ -205,7 +205,7 @@ def main(): @@ -205,7 +205,7 @@ def main():
205 assert_file_exists(args.joiner) 205 assert_file_exists(args.joiner)
206 assert_file_exists(args.tokens) 206 assert_file_exists(args.tokens)
207 207
208 - recognizer = sherpa_onnx.OnlineRecognizer( 208 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
209 tokens=args.tokens, 209 tokens=args.tokens,
210 encoder=args.encoder, 210 encoder=args.encoder,
211 decoder=args.decoder, 211 decoder=args.decoder,
@@ -91,7 +91,7 @@ def create_recognizer(): @@ -91,7 +91,7 @@ def create_recognizer():
91 # Please replace the model files if needed. 91 # Please replace the model files if needed.
92 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 92 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
93 # for download links. 93 # for download links.
94 - recognizer = sherpa_onnx.OnlineRecognizer( 94 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
95 tokens=args.tokens, 95 tokens=args.tokens,
96 encoder=args.encoder, 96 encoder=args.encoder,
97 decoder=args.decoder, 97 decoder=args.decoder,
@@ -145,7 +145,7 @@ def create_recognizer(): @@ -145,7 +145,7 @@ def create_recognizer():
145 # Please replace the model files if needed. 145 # Please replace the model files if needed.
146 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 146 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
147 # for download links. 147 # for download links.
148 - recognizer = sherpa_onnx.OnlineRecognizer( 148 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
149 tokens=args.tokens, 149 tokens=args.tokens,
150 encoder=args.encoder, 150 encoder=args.encoder,
151 decoder=args.decoder, 151 decoder=args.decoder,
@@ -94,7 +94,7 @@ def create_recognizer(args): @@ -94,7 +94,7 @@ def create_recognizer(args):
94 # Please replace the model files if needed. 94 # Please replace the model files if needed.
95 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 95 # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
96 # for download links. 96 # for download links.
97 - recognizer = sherpa_onnx.OnlineRecognizer( 97 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
98 tokens=args.tokens, 98 tokens=args.tokens,
99 encoder=args.encoder, 99 encoder=args.encoder,
100 decoder=args.decoder, 100 decoder=args.decoder,
@@ -294,7 +294,7 @@ def get_args(): @@ -294,7 +294,7 @@ def get_args():
294 294
295 295
296 def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: 296 def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
297 - recognizer = sherpa_onnx.OnlineRecognizer( 297 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
298 tokens=args.tokens, 298 tokens=args.tokens,
299 encoder=args.encoder_model, 299 encoder=args.encoder_model,
300 decoder=args.decoder_model, 300 decoder=args.decoder_model,
@@ -38,11 +38,11 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( @@ -38,11 +38,11 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
38 recognizer_config.feat_config.feature_dim = 38 recognizer_config.feat_config.feature_dim =
39 SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); 39 SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
40 40
41 - recognizer_config.model_config.encoder_filename = 41 + recognizer_config.model_config.transducer.encoder =
42 SHERPA_ONNX_OR(config->model_config.encoder, ""); 42 SHERPA_ONNX_OR(config->model_config.encoder, "");
43 - recognizer_config.model_config.decoder_filename = 43 + recognizer_config.model_config.transducer.decoder =
44 SHERPA_ONNX_OR(config->model_config.decoder, ""); 44 SHERPA_ONNX_OR(config->model_config.decoder, "");
45 - recognizer_config.model_config.joiner_filename = 45 + recognizer_config.model_config.transducer.joiner =
46 SHERPA_ONNX_OR(config->model_config.joiner, ""); 46 SHERPA_ONNX_OR(config->model_config.joiner, "");
47 recognizer_config.model_config.tokens = 47 recognizer_config.model_config.tokens =
48 SHERPA_ONNX_OR(config->model_config.tokens, ""); 48 SHERPA_ONNX_OR(config->model_config.tokens, "");
@@ -143,7 +143,7 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( @@ -143,7 +143,7 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
143 auto count = result.tokens.size(); 143 auto count = result.tokens.size();
144 if (count > 0) { 144 if (count > 0) {
145 size_t total_length = 0; 145 size_t total_length = 0;
146 - for (const auto& token : result.tokens) { 146 + for (const auto &token : result.tokens) {
147 // +1 for the null character at the end of each token 147 // +1 for the null character at the end of each token
148 total_length += token.size() + 1; 148 total_length += token.size() + 1;
149 } 149 }
@@ -154,10 +154,10 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult( @@ -154,10 +154,10 @@ SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
154 memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0, 154 memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
155 total_length); 155 total_length);
156 r->timestamps = new float[r->count]; 156 r->timestamps = new float[r->count];
157 - char **tokens_temp = new char*[r->count]; 157 + char **tokens_temp = new char *[r->count];
158 int32_t pos = 0; 158 int32_t pos = 0;
159 for (int32_t i = 0; i < r->count; ++i) { 159 for (int32_t i = 0; i < r->count; ++i) {
160 - tokens_temp[i] = const_cast<char*>(r->tokens) + pos; 160 + tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
161 memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)), 161 memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
162 result.tokens[i].c_str(), result.tokens[i].size()); 162 result.tokens[i].c_str(), result.tokens[i].size());
163 // +1 to move past the null character 163 // +1 to move past the null character
@@ -43,6 +43,8 @@ set(sources @@ -43,6 +43,8 @@ set(sources
43 online-lm-config.cc 43 online-lm-config.cc
44 online-lm.cc 44 online-lm.cc
45 online-lstm-transducer-model.cc 45 online-lstm-transducer-model.cc
  46 + online-model-config.cc
  47 + online-recognizer-impl.cc
46 online-recognizer.cc 48 online-recognizer.cc
47 online-rnn-lm.cc 49 online-rnn-lm.cc
48 online-stream.cc 50 online-stream.cc
@@ -30,46 +30,46 @@ @@ -30,46 +30,46 @@
30 namespace sherpa_onnx { 30 namespace sherpa_onnx {
31 31
32 OnlineConformerTransducerModel::OnlineConformerTransducerModel( 32 OnlineConformerTransducerModel::OnlineConformerTransducerModel(
33 - const OnlineTransducerModelConfig &config) 33 + const OnlineModelConfig &config)
34 : env_(ORT_LOGGING_LEVEL_WARNING), 34 : env_(ORT_LOGGING_LEVEL_WARNING),
35 config_(config), 35 config_(config),
36 sess_opts_(GetSessionOptions(config)), 36 sess_opts_(GetSessionOptions(config)),
37 allocator_{} { 37 allocator_{} {
38 { 38 {
39 - auto buf = ReadFile(config.encoder_filename); 39 + auto buf = ReadFile(config.transducer.encoder);
40 InitEncoder(buf.data(), buf.size()); 40 InitEncoder(buf.data(), buf.size());
41 } 41 }
42 42
43 { 43 {
44 - auto buf = ReadFile(config.decoder_filename); 44 + auto buf = ReadFile(config.transducer.decoder);
45 InitDecoder(buf.data(), buf.size()); 45 InitDecoder(buf.data(), buf.size());
46 } 46 }
47 47
48 { 48 {
49 - auto buf = ReadFile(config.joiner_filename); 49 + auto buf = ReadFile(config.transducer.joiner);
50 InitJoiner(buf.data(), buf.size()); 50 InitJoiner(buf.data(), buf.size());
51 } 51 }
52 } 52 }
53 53
54 #if __ANDROID_API__ >= 9 54 #if __ANDROID_API__ >= 9
55 OnlineConformerTransducerModel::OnlineConformerTransducerModel( 55 OnlineConformerTransducerModel::OnlineConformerTransducerModel(
56 - AAssetManager *mgr, const OnlineTransducerModelConfig &config) 56 + AAssetManager *mgr, const OnlineModelConfig &config)
57 : env_(ORT_LOGGING_LEVEL_WARNING), 57 : env_(ORT_LOGGING_LEVEL_WARNING),
58 config_(config), 58 config_(config),
59 sess_opts_(GetSessionOptions(config)), 59 sess_opts_(GetSessionOptions(config)),
60 allocator_{} { 60 allocator_{} {
61 { 61 {
62 - auto buf = ReadFile(mgr, config.encoder_filename); 62 + auto buf = ReadFile(mgr, config.transducer.encoder);
63 InitEncoder(buf.data(), buf.size()); 63 InitEncoder(buf.data(), buf.size());
64 } 64 }
65 65
66 { 66 {
67 - auto buf = ReadFile(mgr, config.decoder_filename); 67 + auto buf = ReadFile(mgr, config.transducer.decoder);
68 InitDecoder(buf.data(), buf.size()); 68 InitDecoder(buf.data(), buf.size());
69 } 69 }
70 70
71 { 71 {
72 - auto buf = ReadFile(mgr, config.joiner_filename); 72 + auto buf = ReadFile(mgr, config.transducer.joiner);
73 InitJoiner(buf.data(), buf.size()); 73 InitJoiner(buf.data(), buf.size());
74 } 74 }
75 } 75 }
@@ -16,19 +16,18 @@ @@ -16,19 +16,18 @@
16 #endif 16 #endif
17 17
18 #include "onnxruntime_cxx_api.h" // NOLINT 18 #include "onnxruntime_cxx_api.h" // NOLINT
19 -#include "sherpa-onnx/csrc/online-transducer-model-config.h" 19 +#include "sherpa-onnx/csrc/online-model-config.h"
20 #include "sherpa-onnx/csrc/online-transducer-model.h" 20 #include "sherpa-onnx/csrc/online-transducer-model.h"
21 21
22 namespace sherpa_onnx { 22 namespace sherpa_onnx {
23 23
24 class OnlineConformerTransducerModel : public OnlineTransducerModel { 24 class OnlineConformerTransducerModel : public OnlineTransducerModel {
25 public: 25 public:
26 - explicit OnlineConformerTransducerModel(  
27 - const OnlineTransducerModelConfig &config); 26 + explicit OnlineConformerTransducerModel(const OnlineModelConfig &config);
28 27
29 #if __ANDROID_API__ >= 9 28 #if __ANDROID_API__ >= 9
30 OnlineConformerTransducerModel(AAssetManager *mgr, 29 OnlineConformerTransducerModel(AAssetManager *mgr,
31 - const OnlineTransducerModelConfig &config); 30 + const OnlineModelConfig &config);
32 #endif 31 #endif
33 32
34 std::vector<Ort::Value> StackStates( 33 std::vector<Ort::Value> StackStates(
@@ -88,7 +87,7 @@ class OnlineConformerTransducerModel : public OnlineTransducerModel { @@ -88,7 +87,7 @@ class OnlineConformerTransducerModel : public OnlineTransducerModel {
88 std::vector<std::string> joiner_output_names_; 87 std::vector<std::string> joiner_output_names_;
89 std::vector<const char *> joiner_output_names_ptr_; 88 std::vector<const char *> joiner_output_names_ptr_;
90 89
91 - OnlineTransducerModelConfig config_; 90 + OnlineModelConfig config_;
92 91
93 int32_t num_encoder_layers_ = 0; 92 int32_t num_encoder_layers_ = 0;
94 int32_t T_ = 0; 93 int32_t T_ = 0;
@@ -28,46 +28,46 @@ @@ -28,46 +28,46 @@
28 namespace sherpa_onnx { 28 namespace sherpa_onnx {
29 29
30 OnlineLstmTransducerModel::OnlineLstmTransducerModel( 30 OnlineLstmTransducerModel::OnlineLstmTransducerModel(
31 - const OnlineTransducerModelConfig &config) 31 + const OnlineModelConfig &config)
32 : env_(ORT_LOGGING_LEVEL_WARNING), 32 : env_(ORT_LOGGING_LEVEL_WARNING),
33 config_(config), 33 config_(config),
34 sess_opts_(GetSessionOptions(config)), 34 sess_opts_(GetSessionOptions(config)),
35 allocator_{} { 35 allocator_{} {
36 { 36 {
37 - auto buf = ReadFile(config.encoder_filename); 37 + auto buf = ReadFile(config.transducer.encoder);
38 InitEncoder(buf.data(), buf.size()); 38 InitEncoder(buf.data(), buf.size());
39 } 39 }
40 40
41 { 41 {
42 - auto buf = ReadFile(config.decoder_filename); 42 + auto buf = ReadFile(config.transducer.decoder);
43 InitDecoder(buf.data(), buf.size()); 43 InitDecoder(buf.data(), buf.size());
44 } 44 }
45 45
46 { 46 {
47 - auto buf = ReadFile(config.joiner_filename); 47 + auto buf = ReadFile(config.transducer.joiner);
48 InitJoiner(buf.data(), buf.size()); 48 InitJoiner(buf.data(), buf.size());
49 } 49 }
50 } 50 }
51 51
52 #if __ANDROID_API__ >= 9 52 #if __ANDROID_API__ >= 9
53 OnlineLstmTransducerModel::OnlineLstmTransducerModel( 53 OnlineLstmTransducerModel::OnlineLstmTransducerModel(
54 - AAssetManager *mgr, const OnlineTransducerModelConfig &config) 54 + AAssetManager *mgr, const OnlineModelConfig &config)
55 : env_(ORT_LOGGING_LEVEL_WARNING), 55 : env_(ORT_LOGGING_LEVEL_WARNING),
56 config_(config), 56 config_(config),
57 sess_opts_(GetSessionOptions(config)), 57 sess_opts_(GetSessionOptions(config)),
58 allocator_{} { 58 allocator_{} {
59 { 59 {
60 - auto buf = ReadFile(mgr, config.encoder_filename); 60 + auto buf = ReadFile(mgr, config.transducer.encoder);
61 InitEncoder(buf.data(), buf.size()); 61 InitEncoder(buf.data(), buf.size());
62 } 62 }
63 63
64 { 64 {
65 - auto buf = ReadFile(mgr, config.decoder_filename); 65 + auto buf = ReadFile(mgr, config.transducer.decoder);
66 InitDecoder(buf.data(), buf.size()); 66 InitDecoder(buf.data(), buf.size());
67 } 67 }
68 68
69 { 69 {
70 - auto buf = ReadFile(mgr, config.joiner_filename); 70 + auto buf = ReadFile(mgr, config.transducer.joiner);
71 InitJoiner(buf.data(), buf.size()); 71 InitJoiner(buf.data(), buf.size());
72 } 72 }
73 } 73 }
@@ -15,18 +15,18 @@ @@ -15,18 +15,18 @@
15 #endif 15 #endif
16 16
17 #include "onnxruntime_cxx_api.h" // NOLINT 17 #include "onnxruntime_cxx_api.h" // NOLINT
18 -#include "sherpa-onnx/csrc/online-transducer-model-config.h" 18 +#include "sherpa-onnx/csrc/online-model-config.h"
19 #include "sherpa-onnx/csrc/online-transducer-model.h" 19 #include "sherpa-onnx/csrc/online-transducer-model.h"
20 20
21 namespace sherpa_onnx { 21 namespace sherpa_onnx {
22 22
23 class OnlineLstmTransducerModel : public OnlineTransducerModel { 23 class OnlineLstmTransducerModel : public OnlineTransducerModel {
24 public: 24 public:
25 - explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config); 25 + explicit OnlineLstmTransducerModel(const OnlineModelConfig &config);
26 26
27 #if __ANDROID_API__ >= 9 27 #if __ANDROID_API__ >= 9
28 OnlineLstmTransducerModel(AAssetManager *mgr, 28 OnlineLstmTransducerModel(AAssetManager *mgr,
29 - const OnlineTransducerModelConfig &config); 29 + const OnlineModelConfig &config);
30 #endif 30 #endif
31 31
32 std::vector<Ort::Value> StackStates( 32 std::vector<Ort::Value> StackStates(
@@ -86,7 +86,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel { @@ -86,7 +86,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
86 std::vector<std::string> joiner_output_names_; 86 std::vector<std::string> joiner_output_names_;
87 std::vector<const char *> joiner_output_names_ptr_; 87 std::vector<const char *> joiner_output_names_ptr_;
88 88
89 - OnlineTransducerModelConfig config_; 89 + OnlineModelConfig config_;
90 90
91 int32_t num_encoder_layers_ = 0; 91 int32_t num_encoder_layers_ = 0;
92 int32_t T_ = 0; 92 int32_t T_ = 0;
  1 +// sherpa-onnx/csrc/online-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-model-config.h"

#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void OnlineModelConfig::Register(ParseOptions *po) {
  14 + transducer.Register(po);
  15 +
  16 + po->Register("tokens", &tokens, "Path to tokens.txt");
  17 +
  18 + po->Register("num-threads", &num_threads,
  19 + "Number of threads to run the neural network");
  20 +
  21 + po->Register("debug", &debug,
  22 + "true to print model information while loading it.");
  23 +
  24 + po->Register("provider", &provider,
  25 + "Specify a provider to use: cpu, cuda, coreml");
  26 +
  27 + po->Register("model-type", &model_type,
  28 + "Specify it to reduce model initialization time. "
  29 + "Valid values are: conformer, lstm, zipformer, zipformer2."
  30 + "All other values lead to loading the model twice.");
  31 +}
  32 +
  33 +bool OnlineModelConfig::Validate() const {
  34 + if (num_threads < 1) {
  35 + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
  36 + return false;
  37 + }
  38 +
  39 + if (!FileExists(tokens)) {
  40 + SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
  41 + return false;
  42 + }
  43 +
  44 + return transducer.Validate();
  45 +}
  46 +
  47 +std::string OnlineModelConfig::ToString() const {
  48 + std::ostringstream os;
  49 +
  50 + os << "OnlineModelConfig(";
  51 + os << "transducer=" << transducer.ToString() << ", ";
  52 + os << "tokens=\"" << tokens << "\", ";
  53 + os << "num_threads=" << num_threads << ", ";
  54 + os << "debug=" << (debug ? "True" : "False") << ", ";
  55 + os << "provider=\"" << provider << "\", ";
  56 + os << "model_type=\"" << model_type << "\")";
  57 +
  58 + return os.str();
  59 +}
  60 +
  61 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
// Top-level model configuration shared by all online (streaming) models.
// NOTE(review): this header uses ParseOptions but does not include its
// header directly; presumably it is pulled in via
// online-transducer-model-config.h — confirm.
struct OnlineModelConfig {
  // Paths for the transducer encoder/decoder/joiner ONNX models.
  OnlineTransducerModelConfig transducer;
  // Path to tokens.txt (token-id -> symbol table).
  std::string tokens;
  // Number of threads for onnxruntime intra-op execution.
  int32_t num_threads = 1;
  // If true, print model metadata while loading.
  bool debug = false;
  // onnxruntime execution provider, e.g. "cpu", "cuda", "coreml".
  std::string provider = "cpu";

  // Valid values:
  //  - conformer, conformer transducer from icefall
  //  - lstm, lstm transducer from icefall
  //  - zipformer, zipformer transducer from icefall
  //  - zipformer2, zipformer2 transducer from icefall
  //
  // All other values are invalid and lead to loading the model twice.
  std::string model_type;

  OnlineModelConfig() = default;
  OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
                    const std::string &tokens, int32_t num_threads, bool debug,
                    const std::string &provider, const std::string &model_type)
      : transducer(transducer),
        tokens(tokens),
        num_threads(num_threads),
        debug(debug),
        provider(provider),
        model_type(model_type) {}

  // Registers all options on the given parser (see the .cc for details).
  void Register(ParseOptions *po);
  // Returns true if the configuration is usable (files exist, values sane).
  bool Validate() const;

  // Human-readable rendering for logging.
  std::string ToString() const;
};
  45 +
  46 +} // namespace sherpa_onnx
  47 +
  48 +#endif // SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/online-recognizer-impl.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-recognizer-impl.h"
  6 +
  7 +#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
  12 + const OnlineRecognizerConfig &config) {
  13 + if (!config.model_config.transducer.encoder.empty()) {
  14 + return std::make_unique<OnlineRecognizerTransducerImpl>(config);
  15 + }
  16 +
  17 + SHERPA_ONNX_LOGE("Please specify a model");
  18 + exit(-1);
  19 +}
  20 +
#if __ANDROID_API__ >= 9
// Android variant of the factory: models are loaded through the asset
// manager instead of the filesystem. Same dispatch rule as above.
std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
    AAssetManager *mgr, const OnlineRecognizerConfig &config) {
  if (config.model_config.transducer.encoder.empty()) {
    SHERPA_ONNX_LOGE("Please specify a model");
    exit(-1);
  }

  return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
}
#endif
  32 +
  33 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-recognizer-impl.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
  7 +
  8 +#include <memory>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +#include "sherpa-onnx/csrc/online-recognizer.h"
  13 +#include "sherpa-onnx/csrc/online-stream.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
// Abstract interface for streaming recognizer implementations.
// Concrete subclasses (e.g. the transducer implementation) are selected by
// the static Create() factories based on the given config.
class OnlineRecognizerImpl {
 public:
  // Creates a concrete implementation from the config. Exits the process
  // if no supported model is specified (see the .cc).
  static std::unique_ptr<OnlineRecognizerImpl> Create(
      const OnlineRecognizerConfig &config);

#if __ANDROID_API__ >= 9
  // Android variant: model files are read via the asset manager.
  static std::unique_ptr<OnlineRecognizerImpl> Create(
      AAssetManager *mgr, const OnlineRecognizerConfig &config);
#endif

  virtual ~OnlineRecognizerImpl() = default;

  // Prepares a freshly constructed stream (decoding result, initial
  // encoder states) for use with this recognizer.
  virtual void InitOnlineStream(OnlineStream *stream) const = 0;

  // Creates a new stream ready for decoding.
  virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;

  // Creates a stream with contextual-biasing phrases (token-id lists).
  // Default implementation aborts: only transducer models support this.
  virtual std::unique_ptr<OnlineStream> CreateStream(
      const std::vector<std::vector<int32_t>> &contexts) const {
    SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
    exit(-1);
  }

  // Returns true if the stream has enough unprocessed frames for a
  // decoding step.
  virtual bool IsReady(OnlineStream *s) const = 0;

  // Runs one batched decoding step over n streams.
  virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;

  // Returns the current recognition result of the stream.
  virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;

  // Returns true if an endpoint (end of utterance) has been detected.
  virtual bool IsEndpoint(OnlineStream *s) const = 0;

  // Resets the stream's decoding state after an endpoint.
  virtual void Reset(OnlineStream *s) const = 0;
};
  49 +
  50 +} // namespace sherpa_onnx
  51 +
  52 +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_
  1 +// sherpa-onnx/csrc/online-recognizer-transducer-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <memory>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "sherpa-onnx/csrc/file-utils.h"
  14 +#include "sherpa-onnx/csrc/macros.h"
  15 +#include "sherpa-onnx/csrc/online-lm.h"
  16 +#include "sherpa-onnx/csrc/online-recognizer-impl.h"
  17 +#include "sherpa-onnx/csrc/online-recognizer.h"
  18 +#include "sherpa-onnx/csrc/online-transducer-decoder.h"
  19 +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
  20 +#include "sherpa-onnx/csrc/online-transducer-model.h"
  21 +#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
  22 +#include "sherpa-onnx/csrc/symbol-table.h"
  23 +
  24 +namespace sherpa_onnx {
  25 +
  26 +static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
  27 + const SymbolTable &sym_table,
  28 + int32_t frame_shift_ms,
  29 + int32_t subsampling_factor) {
  30 + OnlineRecognizerResult r;
  31 + r.tokens.reserve(src.tokens.size());
  32 + r.timestamps.reserve(src.tokens.size());
  33 +
  34 + for (auto i : src.tokens) {
  35 + auto sym = sym_table[i];
  36 +
  37 + r.text.append(sym);
  38 + r.tokens.push_back(std::move(sym));
  39 + }
  40 +
  41 + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  42 + for (auto t : src.timestamps) {
  43 + float time = frame_shift_s * t;
  44 + r.timestamps.push_back(time);
  45 + }
  46 +
  47 + return r;
  48 +}
  49 +
// Streaming recognizer backed by a transducer (encoder/decoder/joiner)
// model. Supports greedy search and modified beam search (optionally with
// an online LM), endpointing, and contextual biasing via context graphs.
class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
 public:
  explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config)
      : config_(config),
        model_(OnlineTransducerModel::Create(config.model_config)),
        sym_(config.model_config.tokens),
        endpoint_(config_.endpoint_config) {
    if (config.decoding_method == "modified_beam_search") {
      // The LM is optional; it is only created when a model path is given.
      if (!config_.lm_config.model.empty()) {
        lm_ = OnlineLM::Create(config.lm_config);
      }

      decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
          model_.get(), lm_.get(), config_.max_active_paths,
          config_.lm_config.scale);
    } else if (config.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }
  }

#if __ANDROID_API__ >= 9
  // Android variant: model and tokens are read through the asset manager.
  // NOTE(review): unlike the ctor above, lm_ is never created here, so the
  // beam-search decoder always receives a null LM — confirm this is
  // intentional for Android builds.
  explicit OnlineRecognizerTransducerImpl(AAssetManager *mgr,
                                          const OnlineRecognizerConfig &config)
      : config_(config),
        model_(OnlineTransducerModel::Create(mgr, config.model_config)),
        sym_(mgr, config.model_config.tokens),
        endpoint_(config_.endpoint_config) {
    if (config.decoding_method == "modified_beam_search") {
      decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
          model_.get(), lm_.get(), config_.max_active_paths,
          config_.lm_config.scale);
    } else if (config.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config.decoding_method.c_str());
      exit(-1);
    }
  }
#endif

  // Seeds a stream with an empty decoding result and the model's initial
  // encoder states; wires the stream's context graph into the hypotheses
  // when contextual biasing is active.
  void InitOnlineStream(OnlineStream *stream) const override {
    auto r = decoder_->GetEmptyResult();

    if (config_.decoding_method == "modified_beam_search" &&
        nullptr != stream->GetContextGraph()) {
      // r.hyps has only one element.
      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
        it->second.context_state = stream->GetContextGraph()->Root();
      }
    }

    stream->SetResult(r);
    stream->SetStates(model_->GetEncoderInitStates());
  }

  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream = std::make_unique<OnlineStream>(config_.feat_config);
    InitOnlineStream(stream.get());
    return stream;
  }

  // Creates a stream with contextual-biasing phrases (lists of token ids).
  std::unique_ptr<OnlineStream> CreateStream(
      const std::vector<std::vector<int32_t>> &contexts) const override {
    // We create context_graph at this level, because we might have default
    // context_graph(will be added later if needed) that belongs to the whole
    // model rather than each stream.
    auto context_graph =
        std::make_shared<ContextGraph>(contexts, config_.context_score);
    auto stream =
        std::make_unique<OnlineStream>(config_.feat_config, context_graph);
    InitOnlineStream(stream.get());
    return stream;
  }

  // A stream is ready when it has at least one full model chunk of
  // unprocessed frames.
  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() + model_->ChunkSize() <
           s->NumFramesReady();
  }

  // Runs one batched decoding step over n streams: gathers one chunk of
  // features and the per-stream states, runs the encoder once for the whole
  // batch, decodes, then scatters results and next states back.
  // NOTE: relies on exact move ordering of Ort::Value states — results and
  // states are moved out of the streams and must be restored at the end.
  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
    int32_t chunk_size = model_->ChunkSize();
    int32_t chunk_shift = model_->ChunkShift();

    int32_t feature_dim = ss[0]->FeatureDim();

    std::vector<OnlineTransducerDecoderResult> results(n);
    std::vector<float> features_vec(n * chunk_size * feature_dim);
    std::vector<std::vector<Ort::Value>> states_vec(n);
    std::vector<int64_t> all_processed_frames(n);
    bool has_context_graph = false;

    for (int32_t i = 0; i != n; ++i) {
      if (!has_context_graph && ss[i]->GetContextGraph())
        has_context_graph = true;

      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
      std::vector<float> features =
          ss[i]->GetFrames(num_processed_frames, chunk_size);

      // Question: should num_processed_frames include chunk_shift?
      ss[i]->GetNumProcessedFrames() += chunk_shift;

      std::copy(features.begin(), features.end(),
                features_vec.data() + i * chunk_size * feature_dim);

      results[i] = std::move(ss[i]->GetResult());
      states_vec[i] = std::move(ss[i]->GetStates());
      all_processed_frames[i] = num_processed_frames;
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // Batched feature tensor of shape (n, chunk_size, feature_dim); the
    // tensor borrows features_vec's buffer, so features_vec must outlive it.
    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};

    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
                                            features_vec.size(), x_shape.data(),
                                            x_shape.size());

    std::array<int64_t, 1> processed_frames_shape{
        static_cast<int64_t>(all_processed_frames.size())};

    Ort::Value processed_frames = Ort::Value::CreateTensor(
        memory_info, all_processed_frames.data(), all_processed_frames.size(),
        processed_frames_shape.data(), processed_frames_shape.size());

    auto states = model_->StackStates(states_vec);

    auto pair = model_->RunEncoder(std::move(x), std::move(states),
                                   std::move(processed_frames));

    // The context-graph-aware overload needs access to the streams.
    if (has_context_graph) {
      decoder_->Decode(std::move(pair.first), ss, &results);
    } else {
      decoder_->Decode(std::move(pair.first), &results);
    }

    std::vector<std::vector<Ort::Value>> next_states =
        model_->UnStackStates(pair.second);

    for (int32_t i = 0; i != n; ++i) {
      ss[i]->SetResult(results[i]);
      ss[i]->SetStates(std::move(next_states[i]));
    }
  }

  // Converts the stream's decoder result (minus leading blanks) into the
  // public result type.
  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
    // Work on a copy so the stream's own result is left untouched.
    OnlineTransducerDecoderResult decoder_result = s->GetResult();
    decoder_->StripLeadingBlanks(&decoder_result);

    // TODO(fangjun): Remember to change these constants if needed
    int32_t frame_shift_ms = 10;
    int32_t subsampling_factor = 4;
    return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
  }

  bool IsEndpoint(OnlineStream *s) const override {
    if (!config_.enable_endpoint) return false;
    int32_t num_processed_frames = s->GetNumProcessedFrames();

    // frame shift is 10 milliseconds
    float frame_shift_in_seconds = 0.01;

    // subsampling factor is 4
    int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4;

    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
                                frame_shift_in_seconds);
  }

  // Resets decoding state after an endpoint while preserving decoder_out,
  // so decoding can continue from the current acoustic context.
  void Reset(OnlineStream *s) const override {
    // we keep the decoder_out
    decoder_->UpdateDecoderOut(&s->GetResult());
    Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
    s->SetResult(decoder_->GetEmptyResult());
    s->GetResult().decoder_out = std::move(decoder_out);

    // Note: We only update counters. The underlying audio samples
    // are not discarded.
    s->Reset();
  }

 private:
  OnlineRecognizerConfig config_;
  std::unique_ptr<OnlineTransducerModel> model_;
  // Optional online LM; null unless modified_beam_search with an LM model.
  std::unique_ptr<OnlineLM> lm_;
  std::unique_ptr<OnlineTransducerDecoder> decoder_;
  SymbolTable sym_;
  Endpoint endpoint_;
};
  247 +
  248 +} // namespace sherpa_onnx
  249 +
  250 +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
@@ -15,14 +15,7 @@ @@ -15,14 +15,7 @@
15 #include <vector> 15 #include <vector>
16 16
17 #include "nlohmann/json.hpp" 17 #include "nlohmann/json.hpp"
18 -#include "sherpa-onnx/csrc/file-utils.h"  
19 -#include "sherpa-onnx/csrc/macros.h"  
20 -#include "sherpa-onnx/csrc/online-lm.h"  
21 -#include "sherpa-onnx/csrc/online-transducer-decoder.h"  
22 -#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"  
23 -#include "sherpa-onnx/csrc/online-transducer-model.h"  
24 -#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"  
25 -#include "sherpa-onnx/csrc/symbol-table.h" 18 +#include "sherpa-onnx/csrc/online-recognizer-impl.h"
26 19
27 namespace sherpa_onnx { 20 namespace sherpa_onnx {
28 21
@@ -54,30 +47,6 @@ std::string OnlineRecognizerResult::AsJsonString() const { @@ -54,30 +47,6 @@ std::string OnlineRecognizerResult::AsJsonString() const {
54 return j.dump(); 47 return j.dump();
55 } 48 }
56 49
57 -static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,  
58 - const SymbolTable &sym_table,  
59 - int32_t frame_shift_ms,  
60 - int32_t subsampling_factor) {  
61 - OnlineRecognizerResult r;  
62 - r.tokens.reserve(src.tokens.size());  
63 - r.timestamps.reserve(src.tokens.size());  
64 -  
65 - for (auto i : src.tokens) {  
66 - auto sym = sym_table[i];  
67 -  
68 - r.text.append(sym);  
69 - r.tokens.push_back(std::move(sym));  
70 - }  
71 -  
72 - float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;  
73 - for (auto t : src.timestamps) {  
74 - float time = frame_shift_s * t;  
75 - r.timestamps.push_back(time);  
76 - }  
77 -  
78 - return r;  
79 -}  
80 -  
81 void OnlineRecognizerConfig::Register(ParseOptions *po) { 50 void OnlineRecognizerConfig::Register(ParseOptions *po) {
82 feat_config.Register(po); 51 feat_config.Register(po);
83 model_config.Register(po); 52 model_config.Register(po);
@@ -124,210 +93,13 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -124,210 +93,13 @@ std::string OnlineRecognizerConfig::ToString() const {
124 return os.str(); 93 return os.str();
125 } 94 }
126 95
127 -class OnlineRecognizer::Impl {  
128 - public:  
129 - explicit Impl(const OnlineRecognizerConfig &config)  
130 - : config_(config),  
131 - model_(OnlineTransducerModel::Create(config.model_config)),  
132 - sym_(config.model_config.tokens),  
133 - endpoint_(config_.endpoint_config) {  
134 - if (config.decoding_method == "modified_beam_search") {  
135 - if (!config_.lm_config.model.empty()) {  
136 - lm_ = OnlineLM::Create(config.lm_config);  
137 - }  
138 -  
139 - decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(  
140 - model_.get(), lm_.get(), config_.max_active_paths,  
141 - config_.lm_config.scale);  
142 - } else if (config.decoding_method == "greedy_search") {  
143 - decoder_ =  
144 - std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());  
145 - } else {  
146 - SHERPA_ONNX_LOGE("Unsupported decoding method: %s",  
147 - config.decoding_method.c_str());  
148 - exit(-1);  
149 - }  
150 - }  
151 -  
152 -#if __ANDROID_API__ >= 9  
153 - explicit Impl(AAssetManager *mgr, const OnlineRecognizerConfig &config)  
154 - : config_(config),  
155 - model_(OnlineTransducerModel::Create(mgr, config.model_config)),  
156 - sym_(mgr, config.model_config.tokens),  
157 - endpoint_(config_.endpoint_config) {  
158 - if (config.decoding_method == "modified_beam_search") {  
159 - decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(  
160 - model_.get(), lm_.get(), config_.max_active_paths,  
161 - config_.lm_config.scale);  
162 - } else if (config.decoding_method == "greedy_search") {  
163 - decoder_ =  
164 - std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());  
165 - } else {  
166 - SHERPA_ONNX_LOGE("Unsupported decoding method: %s",  
167 - config.decoding_method.c_str());  
168 - exit(-1);  
169 - }  
170 - }  
171 -#endif  
172 -  
173 - void InitOnlineStream(OnlineStream *stream) const {  
174 - auto r = decoder_->GetEmptyResult();  
175 -  
176 - if (config_.decoding_method == "modified_beam_search" &&  
177 - nullptr != stream->GetContextGraph()) {  
178 - // r.hyps has only one element.  
179 - for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {  
180 - it->second.context_state = stream->GetContextGraph()->Root();  
181 - }  
182 - }  
183 -  
184 - stream->SetResult(r);  
185 - stream->SetStates(model_->GetEncoderInitStates());  
186 - }  
187 -  
188 - std::unique_ptr<OnlineStream> CreateStream() const {  
189 - auto stream = std::make_unique<OnlineStream>(config_.feat_config);  
190 - InitOnlineStream(stream.get());  
191 - return stream;  
192 - }  
193 -  
194 - std::unique_ptr<OnlineStream> CreateStream(  
195 - const std::vector<std::vector<int32_t>> &contexts) const {  
196 - // We create context_graph at this level, because we might have default  
197 - // context_graph(will be added later if needed) that belongs to the whole  
198 - // model rather than each stream.  
199 - auto context_graph =  
200 - std::make_shared<ContextGraph>(contexts, config_.context_score);  
201 - auto stream =  
202 - std::make_unique<OnlineStream>(config_.feat_config, context_graph);  
203 - InitOnlineStream(stream.get());  
204 - return stream;  
205 - }  
206 -  
207 - bool IsReady(OnlineStream *s) const {  
208 - return s->GetNumProcessedFrames() + model_->ChunkSize() <  
209 - s->NumFramesReady();  
210 - }  
211 -  
212 - void DecodeStreams(OnlineStream **ss, int32_t n) const {  
213 - int32_t chunk_size = model_->ChunkSize();  
214 - int32_t chunk_shift = model_->ChunkShift();  
215 -  
216 - int32_t feature_dim = ss[0]->FeatureDim();  
217 -  
218 - std::vector<OnlineTransducerDecoderResult> results(n);  
219 - std::vector<float> features_vec(n * chunk_size * feature_dim);  
220 - std::vector<std::vector<Ort::Value>> states_vec(n);  
221 - std::vector<int64_t> all_processed_frames(n);  
222 - bool has_context_graph = false;  
223 -  
224 - for (int32_t i = 0; i != n; ++i) {  
225 - if (!has_context_graph && ss[i]->GetContextGraph())  
226 - has_context_graph = true;  
227 -  
228 - const auto num_processed_frames = ss[i]->GetNumProcessedFrames();  
229 - std::vector<float> features =  
230 - ss[i]->GetFrames(num_processed_frames, chunk_size);  
231 -  
232 - // Question: should num_processed_frames include chunk_shift?  
233 - ss[i]->GetNumProcessedFrames() += chunk_shift;  
234 -  
235 - std::copy(features.begin(), features.end(),  
236 - features_vec.data() + i * chunk_size * feature_dim);  
237 -  
238 - results[i] = std::move(ss[i]->GetResult());  
239 - states_vec[i] = std::move(ss[i]->GetStates());  
240 - all_processed_frames[i] = num_processed_frames;  
241 - }  
242 -  
243 - auto memory_info =  
244 - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);  
245 -  
246 - std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};  
247 -  
248 - Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),  
249 - features_vec.size(), x_shape.data(),  
250 - x_shape.size());  
251 -  
252 - std::array<int64_t, 1> processed_frames_shape{  
253 - static_cast<int64_t>(all_processed_frames.size())};  
254 -  
255 - Ort::Value processed_frames = Ort::Value::CreateTensor(  
256 - memory_info, all_processed_frames.data(), all_processed_frames.size(),  
257 - processed_frames_shape.data(), processed_frames_shape.size());  
258 -  
259 - auto states = model_->StackStates(states_vec);  
260 -  
261 - auto pair = model_->RunEncoder(std::move(x), std::move(states),  
262 - std::move(processed_frames));  
263 -  
264 - if (has_context_graph) {  
265 - decoder_->Decode(std::move(pair.first), ss, &results);  
266 - } else {  
267 - decoder_->Decode(std::move(pair.first), &results);  
268 - }  
269 -  
270 - std::vector<std::vector<Ort::Value>> next_states =  
271 - model_->UnStackStates(pair.second);  
272 -  
273 - for (int32_t i = 0; i != n; ++i) {  
274 - ss[i]->SetResult(results[i]);  
275 - ss[i]->SetStates(std::move(next_states[i]));  
276 - }  
277 - }  
278 -  
279 - OnlineRecognizerResult GetResult(OnlineStream *s) const {  
280 - OnlineTransducerDecoderResult decoder_result = s->GetResult();  
281 - decoder_->StripLeadingBlanks(&decoder_result);  
282 -  
283 - // TODO(fangjun): Remember to change these constants if needed  
284 - int32_t frame_shift_ms = 10;  
285 - int32_t subsampling_factor = 4;  
286 - return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);  
287 - }  
288 -  
289 - bool IsEndpoint(OnlineStream *s) const {  
290 - if (!config_.enable_endpoint) return false;  
291 - int32_t num_processed_frames = s->GetNumProcessedFrames();  
292 -  
293 - // frame shift is 10 milliseconds  
294 - float frame_shift_in_seconds = 0.01;  
295 -  
296 - // subsampling factor is 4  
297 - int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4;  
298 -  
299 - return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,  
300 - frame_shift_in_seconds);  
301 - }  
302 -  
303 - void Reset(OnlineStream *s) const {  
304 - // we keep the decoder_out  
305 - decoder_->UpdateDecoderOut(&s->GetResult());  
306 - Ort::Value decoder_out = std::move(s->GetResult().decoder_out);  
307 - s->SetResult(decoder_->GetEmptyResult());  
308 - s->GetResult().decoder_out = std::move(decoder_out);  
309 -  
310 - // Note: We only update counters. The underlying audio samples  
311 - // are not discarded.  
312 - s->Reset();  
313 - }  
314 -  
315 - private:  
316 - OnlineRecognizerConfig config_;  
317 - std::unique_ptr<OnlineTransducerModel> model_;  
318 - std::unique_ptr<OnlineLM> lm_;  
319 - std::unique_ptr<OnlineTransducerDecoder> decoder_;  
320 - SymbolTable sym_;  
321 - Endpoint endpoint_;  
322 -};  
323 -  
324 OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) 96 OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
325 - : impl_(std::make_unique<Impl>(config)) {} 97 + : impl_(OnlineRecognizerImpl::Create(config)) {}
326 98
327 #if __ANDROID_API__ >= 9 99 #if __ANDROID_API__ >= 9
328 OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr, 100 OnlineRecognizer::OnlineRecognizer(AAssetManager *mgr,
329 const OnlineRecognizerConfig &config) 101 const OnlineRecognizerConfig &config)
330 - : impl_(std::make_unique<Impl>(mgr, config)) {} 102 + : impl_(OnlineRecognizerImpl::Create(mgr, config)) {}
331 #endif 103 #endif
332 104
333 OnlineRecognizer::~OnlineRecognizer() = default; 105 OnlineRecognizer::~OnlineRecognizer() = default;
@@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
17 #include "sherpa-onnx/csrc/endpoint.h" 17 #include "sherpa-onnx/csrc/endpoint.h"
18 #include "sherpa-onnx/csrc/features.h" 18 #include "sherpa-onnx/csrc/features.h"
19 #include "sherpa-onnx/csrc/online-lm-config.h" 19 #include "sherpa-onnx/csrc/online-lm-config.h"
  20 +#include "sherpa-onnx/csrc/online-model-config.h"
20 #include "sherpa-onnx/csrc/online-stream.h" 21 #include "sherpa-onnx/csrc/online-stream.h"
21 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 22 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
22 #include "sherpa-onnx/csrc/parse-options.h" 23 #include "sherpa-onnx/csrc/parse-options.h"
@@ -67,7 +68,7 @@ struct OnlineRecognizerResult { @@ -67,7 +68,7 @@ struct OnlineRecognizerResult {
67 68
68 struct OnlineRecognizerConfig { 69 struct OnlineRecognizerConfig {
69 FeatureExtractorConfig feat_config; 70 FeatureExtractorConfig feat_config;
70 - OnlineTransducerModelConfig model_config; 71 + OnlineModelConfig model_config;
71 OnlineLMConfig lm_config; 72 OnlineLMConfig lm_config;
72 EndpointConfig endpoint_config; 73 EndpointConfig endpoint_config;
73 bool enable_endpoint = true; 74 bool enable_endpoint = true;
@@ -83,7 +84,7 @@ struct OnlineRecognizerConfig { @@ -83,7 +84,7 @@ struct OnlineRecognizerConfig {
83 OnlineRecognizerConfig() = default; 84 OnlineRecognizerConfig() = default;
84 85
85 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, 86 OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
86 - const OnlineTransducerModelConfig &model_config, 87 + const OnlineModelConfig &model_config,
87 const OnlineLMConfig &lm_config, 88 const OnlineLMConfig &lm_config,
88 const EndpointConfig &endpoint_config, 89 const EndpointConfig &endpoint_config,
89 bool enable_endpoint, 90 bool enable_endpoint,
@@ -103,6 +104,8 @@ struct OnlineRecognizerConfig { @@ -103,6 +104,8 @@ struct OnlineRecognizerConfig {
103 std::string ToString() const; 104 std::string ToString() const;
104 }; 105 };
105 106
  107 +class OnlineRecognizerImpl;
  108 +
106 class OnlineRecognizer { 109 class OnlineRecognizer {
107 public: 110 public:
108 explicit OnlineRecognizer(const OnlineRecognizerConfig &config); 111 explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
@@ -151,8 +154,7 @@ class OnlineRecognizer { @@ -151,8 +154,7 @@ class OnlineRecognizer {
151 void Reset(OnlineStream *s) const; 154 void Reset(OnlineStream *s) const;
152 155
153 private: 156 private:
154 - class Impl;  
155 - std::unique_ptr<Impl> impl_; 157 + std::unique_ptr<OnlineRecognizerImpl> impl_;
156 }; 158 };
157 159
158 } // namespace sherpa_onnx 160 } // namespace sherpa_onnx
@@ -11,46 +11,24 @@ @@ -11,46 +11,24 @@
11 namespace sherpa_onnx { 11 namespace sherpa_onnx {
12 12
13 void OnlineTransducerModelConfig::Register(ParseOptions *po) { 13 void OnlineTransducerModelConfig::Register(ParseOptions *po) {
14 - po->Register("encoder", &encoder_filename, "Path to encoder.onnx");  
15 - po->Register("decoder", &decoder_filename, "Path to decoder.onnx");  
16 - po->Register("joiner", &joiner_filename, "Path to joiner.onnx");  
17 - po->Register("tokens", &tokens, "Path to tokens.txt");  
18 - po->Register("num_threads", &num_threads,  
19 - "Number of threads to run the neural network");  
20 - po->Register("provider", &provider,  
21 - "Specify a provider to use: cpu, cuda, coreml");  
22 -  
23 - po->Register("debug", &debug,  
24 - "true to print model information while loading it.");  
25 - po->Register("model-type", &model_type,  
26 - "Specify it to reduce model initialization time. "  
27 - "Valid values are: conformer, lstm, zipformer, zipformer2. "  
28 - "All other values lead to loading the model twice."); 14 + po->Register("encoder", &encoder, "Path to encoder.onnx");
  15 + po->Register("decoder", &decoder, "Path to decoder.onnx");
  16 + po->Register("joiner", &joiner, "Path to joiner.onnx");
29 } 17 }
30 18
31 bool OnlineTransducerModelConfig::Validate() const { 19 bool OnlineTransducerModelConfig::Validate() const {
32 - if (!FileExists(tokens)) {  
33 - SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());  
34 - return false;  
35 - }  
36 -  
37 - if (!FileExists(encoder_filename)) {  
38 - SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());  
39 - return false;  
40 - }  
41 -  
42 - if (!FileExists(decoder_filename)) {  
43 - SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); 20 + if (!FileExists(encoder)) {
  21 + SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", encoder.c_str());
44 return false; 22 return false;
45 } 23 }
46 24
47 - if (!FileExists(joiner_filename)) {  
48 - SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); 25 + if (!FileExists(decoder)) {
  26 + SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", decoder.c_str());
49 return false; 27 return false;
50 } 28 }
51 29
52 - if (num_threads < 1) {  
53 - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); 30 + if (!FileExists(joiner)) {
  31 + SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner.c_str());
54 return false; 32 return false;
55 } 33 }
56 34
@@ -61,14 +39,9 @@ std::string OnlineTransducerModelConfig::ToString() const { @@ -61,14 +39,9 @@ std::string OnlineTransducerModelConfig::ToString() const {
61 std::ostringstream os; 39 std::ostringstream os;
62 40
63 os << "OnlineTransducerModelConfig("; 41 os << "OnlineTransducerModelConfig(";
64 - os << "encoder_filename=\"" << encoder_filename << "\", ";  
65 - os << "decoder_filename=\"" << decoder_filename << "\", ";  
66 - os << "joiner_filename=\"" << joiner_filename << "\", ";  
67 - os << "tokens=\"" << tokens << "\", ";  
68 - os << "num_threads=" << num_threads << ", ";  
69 - os << "provider=\"" << provider << "\", ";  
70 - os << "model_type=\"" << model_type << "\", ";  
71 - os << "debug=" << (debug ? "True" : "False") << ")"; 42 + os << "encoder=\"" << encoder << "\", ";
  43 + os << "decoder=\"" << decoder << "\", ";
  44 + os << "joiner=\"" << joiner << "\")";
72 45
73 return os.str(); 46 return os.str();
74 } 47 }
@@ -11,41 +11,15 @@ @@ -11,41 +11,15 @@
11 namespace sherpa_onnx { 11 namespace sherpa_onnx {
12 12
13 struct OnlineTransducerModelConfig { 13 struct OnlineTransducerModelConfig {
14 - std::string encoder_filename;  
15 - std::string decoder_filename;  
16 - std::string joiner_filename;  
17 - std::string tokens;  
18 - int32_t num_threads = 2;  
19 - bool debug = false;  
20 - std::string provider = "cpu";  
21 -  
22 - // With the help of this field, we only need to load the model once  
23 - // instead of twice; and therefore it reduces initialization time.  
24 - //  
25 - // Valid values:  
26 - // - conformer  
27 - // - lstm  
28 - // - zipformer  
29 - // - zipformer2  
30 - //  
31 - // All other values are invalid and lead to loading the model twice.  
32 - std::string model_type; 14 + std::string encoder;
  15 + std::string decoder;
  16 + std::string joiner;
33 17
34 OnlineTransducerModelConfig() = default; 18 OnlineTransducerModelConfig() = default;
35 - OnlineTransducerModelConfig(const std::string &encoder_filename,  
36 - const std::string &decoder_filename,  
37 - const std::string &joiner_filename,  
38 - const std::string &tokens, int32_t num_threads,  
39 - bool debug, const std::string &provider,  
40 - const std::string &model_type)  
41 - : encoder_filename(encoder_filename),  
42 - decoder_filename(decoder_filename),  
43 - joiner_filename(joiner_filename),  
44 - tokens(tokens),  
45 - num_threads(num_threads),  
46 - debug(debug),  
47 - provider(provider),  
48 - model_type(model_type) {} 19 + OnlineTransducerModelConfig(const std::string &encoder,
  20 + const std::string &decoder,
  21 + const std::string &joiner)
  22 + : encoder(encoder), decoder(decoder), joiner(joiner) {}
49 23
50 void Register(ParseOptions *po); 24 void Register(ParseOptions *po);
51 bool Validate() const; 25 bool Validate() const;
@@ -76,7 +76,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, @@ -76,7 +76,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
76 } 76 }
77 77
78 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( 78 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
79 - const OnlineTransducerModelConfig &config) { 79 + const OnlineModelConfig &config) {
80 if (!config.model_type.empty()) { 80 if (!config.model_type.empty()) {
81 const auto &model_type = config.model_type; 81 const auto &model_type = config.model_type;
82 if (model_type == "conformer") { 82 if (model_type == "conformer") {
@@ -96,7 +96,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -96,7 +96,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
96 ModelType model_type = ModelType::kUnkown; 96 ModelType model_type = ModelType::kUnkown;
97 97
98 { 98 {
99 - auto buffer = ReadFile(config.encoder_filename); 99 + auto buffer = ReadFile(config.transducer.encoder);
100 100
101 model_type = GetModelType(buffer.data(), buffer.size(), config.debug); 101 model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
102 } 102 }
@@ -155,7 +155,7 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput( @@ -155,7 +155,7 @@ Ort::Value OnlineTransducerModel::BuildDecoderInput(
155 155
156 #if __ANDROID_API__ >= 9 156 #if __ANDROID_API__ >= 9
157 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( 157 std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
158 - AAssetManager *mgr, const OnlineTransducerModelConfig &config) { 158 + AAssetManager *mgr, const OnlineModelConfig &config) {
159 if (!config.model_type.empty()) { 159 if (!config.model_type.empty()) {
160 const auto &model_type = config.model_type; 160 const auto &model_type = config.model_type;
161 if (model_type == "conformer") { 161 if (model_type == "conformer") {
@@ -173,7 +173,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create( @@ -173,7 +173,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
173 } 173 }
174 } 174 }
175 175
176 - auto buffer = ReadFile(mgr, config.encoder_filename); 176 + auto buffer = ReadFile(mgr, config.transducer.encoder);
177 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); 177 auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
178 178
179 switch (model_type) { 179 switch (model_type) {
@@ -15,6 +15,7 @@ @@ -15,6 +15,7 @@
15 15
16 #include "onnxruntime_cxx_api.h" // NOLINT 16 #include "onnxruntime_cxx_api.h" // NOLINT
17 #include "sherpa-onnx/csrc/hypothesis.h" 17 #include "sherpa-onnx/csrc/hypothesis.h"
  18 +#include "sherpa-onnx/csrc/online-model-config.h"
18 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 19 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
19 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 20 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
20 21
@@ -27,11 +28,11 @@ class OnlineTransducerModel { @@ -27,11 +28,11 @@ class OnlineTransducerModel {
27 virtual ~OnlineTransducerModel() = default; 28 virtual ~OnlineTransducerModel() = default;
28 29
29 static std::unique_ptr<OnlineTransducerModel> Create( 30 static std::unique_ptr<OnlineTransducerModel> Create(
30 - const OnlineTransducerModelConfig &config); 31 + const OnlineModelConfig &config);
31 32
32 #if __ANDROID_API__ >= 9 33 #if __ANDROID_API__ >= 9
33 static std::unique_ptr<OnlineTransducerModel> Create( 34 static std::unique_ptr<OnlineTransducerModel> Create(
34 - AAssetManager *mgr, const OnlineTransducerModelConfig &config); 35 + AAssetManager *mgr, const OnlineModelConfig &config);
35 #endif 36 #endif
36 37
37 /** Stack a list of individual states into a batch. 38 /** Stack a list of individual states into a batch.
@@ -64,15 +65,15 @@ class OnlineTransducerModel { @@ -64,15 +65,15 @@ class OnlineTransducerModel {
64 * 65 *
65 * @param features A tensor of shape (N, T, C). It is changed in-place. 66 * @param features A tensor of shape (N, T, C). It is changed in-place.
66 * @param states Encoder state of the previous chunk. It is changed in-place. 67 * @param states Encoder state of the previous chunk. It is changed in-place.
67 - * @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t. 68 + * @param processed_frames Processed frames before subsampling. It is a 1-D
  69 + * tensor with data type int64_t.
68 * 70 *
69 * @return Return a tuple containing: 71 * @return Return a tuple containing:
70 * - encoder_out, a tensor of shape (N, T', encoder_out_dim) 72 * - encoder_out, a tensor of shape (N, T', encoder_out_dim)
71 * - next_states Encoder state for the next chunk. 73 * - next_states Encoder state for the next chunk.
72 */ 74 */
73 virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder( 75 virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
74 - Ort::Value features,  
75 - std::vector<Ort::Value> states, 76 + Ort::Value features, std::vector<Ort::Value> states,
76 Ort::Value processed_frames) = 0; // NOLINT 77 Ort::Value processed_frames) = 0; // NOLINT
77 78
78 /** Run the decoder network. 79 /** Run the decoder network.
@@ -30,46 +30,46 @@ @@ -30,46 +30,46 @@
30 namespace sherpa_onnx { 30 namespace sherpa_onnx {
31 31
32 OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( 32 OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
33 - const OnlineTransducerModelConfig &config) 33 + const OnlineModelConfig &config)
34 : env_(ORT_LOGGING_LEVEL_WARNING), 34 : env_(ORT_LOGGING_LEVEL_WARNING),
35 config_(config), 35 config_(config),
36 sess_opts_(GetSessionOptions(config)), 36 sess_opts_(GetSessionOptions(config)),
37 allocator_{} { 37 allocator_{} {
38 { 38 {
39 - auto buf = ReadFile(config.encoder_filename); 39 + auto buf = ReadFile(config.transducer.encoder);
40 InitEncoder(buf.data(), buf.size()); 40 InitEncoder(buf.data(), buf.size());
41 } 41 }
42 42
43 { 43 {
44 - auto buf = ReadFile(config.decoder_filename); 44 + auto buf = ReadFile(config.transducer.decoder);
45 InitDecoder(buf.data(), buf.size()); 45 InitDecoder(buf.data(), buf.size());
46 } 46 }
47 47
48 { 48 {
49 - auto buf = ReadFile(config.joiner_filename); 49 + auto buf = ReadFile(config.transducer.joiner);
50 InitJoiner(buf.data(), buf.size()); 50 InitJoiner(buf.data(), buf.size());
51 } 51 }
52 } 52 }
53 53
54 #if __ANDROID_API__ >= 9 54 #if __ANDROID_API__ >= 9
55 OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( 55 OnlineZipformerTransducerModel::OnlineZipformerTransducerModel(
56 - AAssetManager *mgr, const OnlineTransducerModelConfig &config) 56 + AAssetManager *mgr, const OnlineModelConfig &config)
57 : env_(ORT_LOGGING_LEVEL_WARNING), 57 : env_(ORT_LOGGING_LEVEL_WARNING),
58 config_(config), 58 config_(config),
59 sess_opts_(GetSessionOptions(config)), 59 sess_opts_(GetSessionOptions(config)),
60 allocator_{} { 60 allocator_{} {
61 { 61 {
62 - auto buf = ReadFile(mgr, config.encoder_filename); 62 + auto buf = ReadFile(mgr, config.transducer.encoder);
63 InitEncoder(buf.data(), buf.size()); 63 InitEncoder(buf.data(), buf.size());
64 } 64 }
65 65
66 { 66 {
67 - auto buf = ReadFile(mgr, config.decoder_filename); 67 + auto buf = ReadFile(mgr, config.transducer.decoder);
68 InitDecoder(buf.data(), buf.size()); 68 InitDecoder(buf.data(), buf.size());
69 } 69 }
70 70
71 { 71 {
72 - auto buf = ReadFile(mgr, config.joiner_filename); 72 + auto buf = ReadFile(mgr, config.transducer.joiner);
73 InitJoiner(buf.data(), buf.size()); 73 InitJoiner(buf.data(), buf.size());
74 } 74 }
75 } 75 }
@@ -15,19 +15,18 @@ @@ -15,19 +15,18 @@
15 #endif 15 #endif
16 16
17 #include "onnxruntime_cxx_api.h" // NOLINT 17 #include "onnxruntime_cxx_api.h" // NOLINT
18 -#include "sherpa-onnx/csrc/online-transducer-model-config.h" 18 +#include "sherpa-onnx/csrc/online-model-config.h"
19 #include "sherpa-onnx/csrc/online-transducer-model.h" 19 #include "sherpa-onnx/csrc/online-transducer-model.h"
20 20
21 namespace sherpa_onnx { 21 namespace sherpa_onnx {
22 22
23 class OnlineZipformerTransducerModel : public OnlineTransducerModel { 23 class OnlineZipformerTransducerModel : public OnlineTransducerModel {
24 public: 24 public:
25 - explicit OnlineZipformerTransducerModel(  
26 - const OnlineTransducerModelConfig &config); 25 + explicit OnlineZipformerTransducerModel(const OnlineModelConfig &config);
27 26
28 #if __ANDROID_API__ >= 9 27 #if __ANDROID_API__ >= 9
29 OnlineZipformerTransducerModel(AAssetManager *mgr, 28 OnlineZipformerTransducerModel(AAssetManager *mgr,
30 - const OnlineTransducerModelConfig &config); 29 + const OnlineModelConfig &config);
31 #endif 30 #endif
32 31
33 std::vector<Ort::Value> StackStates( 32 std::vector<Ort::Value> StackStates(
@@ -87,7 +86,7 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { @@ -87,7 +86,7 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
87 std::vector<std::string> joiner_output_names_; 86 std::vector<std::string> joiner_output_names_;
88 std::vector<const char *> joiner_output_names_ptr_; 87 std::vector<const char *> joiner_output_names_ptr_;
89 88
90 - OnlineTransducerModelConfig config_; 89 + OnlineModelConfig config_;
91 90
92 std::vector<int32_t> encoder_dims_; 91 std::vector<int32_t> encoder_dims_;
93 std::vector<int32_t> attention_dims_; 92 std::vector<int32_t> attention_dims_;
@@ -32,46 +32,46 @@ @@ -32,46 +32,46 @@
32 namespace sherpa_onnx { 32 namespace sherpa_onnx {
33 33
34 OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( 34 OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
35 - const OnlineTransducerModelConfig &config) 35 + const OnlineModelConfig &config)
36 : env_(ORT_LOGGING_LEVEL_WARNING), 36 : env_(ORT_LOGGING_LEVEL_WARNING),
37 config_(config), 37 config_(config),
38 sess_opts_(GetSessionOptions(config)), 38 sess_opts_(GetSessionOptions(config)),
39 allocator_{} { 39 allocator_{} {
40 { 40 {
41 - auto buf = ReadFile(config.encoder_filename); 41 + auto buf = ReadFile(config.transducer.encoder);
42 InitEncoder(buf.data(), buf.size()); 42 InitEncoder(buf.data(), buf.size());
43 } 43 }
44 44
45 { 45 {
46 - auto buf = ReadFile(config.decoder_filename); 46 + auto buf = ReadFile(config.transducer.decoder);
47 InitDecoder(buf.data(), buf.size()); 47 InitDecoder(buf.data(), buf.size());
48 } 48 }
49 49
50 { 50 {
51 - auto buf = ReadFile(config.joiner_filename); 51 + auto buf = ReadFile(config.transducer.joiner);
52 InitJoiner(buf.data(), buf.size()); 52 InitJoiner(buf.data(), buf.size());
53 } 53 }
54 } 54 }
55 55
56 #if __ANDROID_API__ >= 9 56 #if __ANDROID_API__ >= 9
57 OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( 57 OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
58 - AAssetManager *mgr, const OnlineTransducerModelConfig &config) 58 + AAssetManager *mgr, const OnlineModelConfig &config)
59 : env_(ORT_LOGGING_LEVEL_WARNING), 59 : env_(ORT_LOGGING_LEVEL_WARNING),
60 config_(config), 60 config_(config),
61 sess_opts_(GetSessionOptions(config)), 61 sess_opts_(GetSessionOptions(config)),
62 allocator_{} { 62 allocator_{} {
63 { 63 {
64 - auto buf = ReadFile(mgr, config.encoder_filename); 64 + auto buf = ReadFile(mgr, config.transducer.encoder);
65 InitEncoder(buf.data(), buf.size()); 65 InitEncoder(buf.data(), buf.size());
66 } 66 }
67 67
68 { 68 {
69 - auto buf = ReadFile(mgr, config.decoder_filename); 69 + auto buf = ReadFile(mgr, config.transducer.decoder);
70 InitDecoder(buf.data(), buf.size()); 70 InitDecoder(buf.data(), buf.size());
71 } 71 }
72 72
73 { 73 {
74 - auto buf = ReadFile(mgr, config.joiner_filename); 74 + auto buf = ReadFile(mgr, config.transducer.joiner);
75 InitJoiner(buf.data(), buf.size()); 75 InitJoiner(buf.data(), buf.size());
76 } 76 }
77 } 77 }
@@ -15,19 +15,18 @@ @@ -15,19 +15,18 @@
15 #endif 15 #endif
16 16
17 #include "onnxruntime_cxx_api.h" // NOLINT 17 #include "onnxruntime_cxx_api.h" // NOLINT
18 -#include "sherpa-onnx/csrc/online-transducer-model-config.h" 18 +#include "sherpa-onnx/csrc/online-model-config.h"
19 #include "sherpa-onnx/csrc/online-transducer-model.h" 19 #include "sherpa-onnx/csrc/online-transducer-model.h"
20 20
21 namespace sherpa_onnx { 21 namespace sherpa_onnx {
22 22
23 class OnlineZipformer2TransducerModel : public OnlineTransducerModel { 23 class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
24 public: 24 public:
25 - explicit OnlineZipformer2TransducerModel(  
26 - const OnlineTransducerModelConfig &config); 25 + explicit OnlineZipformer2TransducerModel(const OnlineModelConfig &config);
27 26
28 #if __ANDROID_API__ >= 9 27 #if __ANDROID_API__ >= 9
29 OnlineZipformer2TransducerModel(AAssetManager *mgr, 28 OnlineZipformer2TransducerModel(AAssetManager *mgr,
30 - const OnlineTransducerModelConfig &config); 29 + const OnlineModelConfig &config);
31 #endif 30 #endif
32 31
33 std::vector<Ort::Value> StackStates( 32 std::vector<Ort::Value> StackStates(
@@ -87,7 +86,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel { @@ -87,7 +86,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
87 std::vector<std::string> joiner_output_names_; 86 std::vector<std::string> joiner_output_names_;
88 std::vector<const char *> joiner_output_names_ptr_; 87 std::vector<const char *> joiner_output_names_ptr_;
89 88
90 - OnlineTransducerModelConfig config_; 89 + OnlineModelConfig config_;
91 90
92 std::vector<int32_t> encoder_dims_; 91 std::vector<int32_t> encoder_dims_;
93 std::vector<int32_t> query_head_dims_; 92 std::vector<int32_t> query_head_dims_;
@@ -60,8 +60,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, @@ -60,8 +60,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
60 return sess_opts; 60 return sess_opts;
61 } 61 }
62 62
63 -Ort::SessionOptions GetSessionOptions(  
64 - const OnlineTransducerModelConfig &config) { 63 +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) {
65 return GetSessionOptionsImpl(config.num_threads, config.provider); 64 return GetSessionOptionsImpl(config.num_threads, config.provider);
66 } 65 }
67 66
@@ -9,12 +9,11 @@ @@ -9,12 +9,11 @@
9 #include "sherpa-onnx/csrc/offline-lm-config.h" 9 #include "sherpa-onnx/csrc/offline-lm-config.h"
10 #include "sherpa-onnx/csrc/offline-model-config.h" 10 #include "sherpa-onnx/csrc/offline-model-config.h"
11 #include "sherpa-onnx/csrc/online-lm-config.h" 11 #include "sherpa-onnx/csrc/online-lm-config.h"
12 -#include "sherpa-onnx/csrc/online-transducer-model-config.h" 12 +#include "sherpa-onnx/csrc/online-model-config.h"
13 13
14 namespace sherpa_onnx { 14 namespace sherpa_onnx {
15 15
16 -Ort::SessionOptions GetSessionOptions(  
17 - const OnlineTransducerModelConfig &config); 16 +Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
18 17
19 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); 18 Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
20 19
@@ -12,6 +12,7 @@ @@ -12,6 +12,7 @@
12 #include "sherpa-onnx/csrc/alsa.h" 12 #include "sherpa-onnx/csrc/alsa.h"
13 #include "sherpa-onnx/csrc/display.h" 13 #include "sherpa-onnx/csrc/display.h"
14 #include "sherpa-onnx/csrc/online-recognizer.h" 14 #include "sherpa-onnx/csrc/online-recognizer.h"
  15 +#include "sherpa-onnx/csrc/parse-options.h"
15 16
16 bool stop = false; 17 bool stop = false;
17 18
@@ -21,19 +22,19 @@ static void Handler(int sig) { @@ -21,19 +22,19 @@ static void Handler(int sig) {
21 } 22 }
22 23
23 int main(int32_t argc, char *argv[]) { 24 int main(int32_t argc, char *argv[]) {
24 - if (argc < 6 || argc > 8) {  
25 - const char *usage = R"usage( 25 + signal(SIGINT, Handler);
  26 +
  27 + const char *kUsageMessage = R"usage(
26 Usage: 28 Usage:
27 ./bin/sherpa-onnx-alsa \ 29 ./bin/sherpa-onnx-alsa \
28 - /path/to/tokens.txt \  
29 - /path/to/encoder.onnx \  
30 - /path/to/decoder.onnx \  
31 - /path/to/joiner.onnx \ 30 + --tokens=/path/to/tokens.txt \
  31 + --encoder=/path/to/encoder.onnx \
  32 + --decoder=/path/to/decoder.onnx \
  33 + --joiner=/path/to/joiner.onnx \
  34 + --provider=cpu \
  35 + --num-threads=2 \
  36 + --decoding-method=greedy_search \
32 device_name 37
33 - [num_threads [decoding_method]]  
34 -  
35 -Default value for num_threads is 2.  
36 -Valid values for decoding_method: greedy_search (default), modified_beam_search.  
37 38
38 Please refer to 39 Please refer to
39 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html 40 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -55,44 +56,24 @@ and if you want to select card 3 and the device 0 on that card, please use: @@ -55,44 +56,24 @@ and if you want to select card 3 and the device 0 on that card, please use:
55 56
56 hw:3,0 57 hw:3,0
57 58
58 -as the device_name.  
59 -)usage";  
60 -  
61 - fprintf(stderr, "%s\n", usage);  
62 - fprintf(stderr, "argc, %d\n", argc);  
63 -  
64 - return 0;  
65 - } 59 +or
66 60
67 - signal(SIGINT, Handler); 61 + plughw:3,0
68 62
  63 +as the device_name.
  64 +)usage";
  65 + sherpa_onnx::ParseOptions po(kUsageMessage);
69 sherpa_onnx::OnlineRecognizerConfig config; 66 sherpa_onnx::OnlineRecognizerConfig config;
70 67
71 - config.model_config.tokens = argv[1];  
72 -  
73 - config.model_config.debug = false;  
74 - config.model_config.encoder_filename = argv[2];  
75 - config.model_config.decoder_filename = argv[3];  
76 - config.model_config.joiner_filename = argv[4];  
77 -  
78 - const char *device_name = argv[5]; 68 + config.Register(&po);
79 69
80 - config.model_config.num_threads = 2;  
81 - if (argc == 7 && atoi(argv[6]) > 0) {  
82 - config.model_config.num_threads = atoi(argv[6]); 70 + po.Read(argc, argv);
  71 + if (po.NumArgs() != 1) {
  72 + fprintf(stderr, "Please provide only 1 argument: the device name\n");
  73 + po.PrintUsage();
  74 + exit(EXIT_FAILURE);
83 } 75 }
84 76
85 - if (argc == 8) {  
86 - config.decoding_method = argv[7];  
87 - }  
88 - config.max_active_paths = 4;  
89 -  
90 - config.enable_endpoint = true;  
91 -  
92 - config.endpoint_config.rule1.min_trailing_silence = 2.4;  
93 - config.endpoint_config.rule2.min_trailing_silence = 1.2;  
94 - config.endpoint_config.rule3.min_utterance_length = 300;  
95 -  
96 fprintf(stderr, "%s\n", config.ToString().c_str()); 77 fprintf(stderr, "%s\n", config.ToString().c_str());
97 78
98 if (!config.Validate()) { 79 if (!config.Validate()) {
@@ -103,8 +84,9 @@ as the device_name. @@ -103,8 +84,9 @@ as the device_name.
103 84
104 int32_t expected_sample_rate = config.feat_config.sampling_rate; 85 int32_t expected_sample_rate = config.feat_config.sampling_rate;
105 86
106 - sherpa_onnx::Alsa alsa(device_name);  
107 - fprintf(stderr, "Use recording device: %s\n", device_name); 87 + std::string device_name = po.GetArg(1);
  88 + sherpa_onnx::Alsa alsa(device_name.c_str());
  89 + fprintf(stderr, "Use recording device: %s\n", device_name.c_str());
108 90
109 if (alsa.GetExpectedSampleRate() != expected_sample_rate) { 91 if (alsa.GetExpectedSampleRate() != expected_sample_rate) {
110 fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), 92 fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(),
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 // Copyright 2023 Xiaomi Corporation 4 // Copyright 2023 Xiaomi Corporation
5 #ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 5 #ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_
6 #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 6 #define SHERPA_ONNX_CSRC_TEXT_UTILS_H_
  7 +#include <errno.h>
7 #include <stdlib.h> 8 #include <stdlib.h>
8 9
9 #include <limits> 10 #include <limits>
@@ -159,47 +159,47 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { @@ -159,47 +159,47 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
159 //---------- model config ---------- 159 //---------- model config ----------
160 fid = env->GetFieldID(cls, "modelConfig", 160 fid = env->GetFieldID(cls, "modelConfig",
161 "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); 161 "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
162 - jobject model_config = env->GetObjectField(config, fid);  
163 - jclass model_config_cls = env->GetObjectClass(model_config); 162 + jobject transducer_config = env->GetObjectField(config, fid);
  163 + jclass model_config_cls = env->GetObjectClass(transducer_config);
164 164
165 fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); 165 fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
166 - s = (jstring)env->GetObjectField(model_config, fid); 166 + s = (jstring)env->GetObjectField(transducer_config, fid);
167 p = env->GetStringUTFChars(s, nullptr); 167 p = env->GetStringUTFChars(s, nullptr);
168 - ans.model_config.encoder_filename = p; 168 + ans.model_config.transducer.encoder = p;
169 env->ReleaseStringUTFChars(s, p); 169 env->ReleaseStringUTFChars(s, p);
170 170
171 fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;"); 171 fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
172 - s = (jstring)env->GetObjectField(model_config, fid); 172 + s = (jstring)env->GetObjectField(transducer_config, fid);
173 p = env->GetStringUTFChars(s, nullptr); 173 p = env->GetStringUTFChars(s, nullptr);
174 - ans.model_config.decoder_filename = p; 174 + ans.model_config.transducer.decoder = p;
175 env->ReleaseStringUTFChars(s, p); 175 env->ReleaseStringUTFChars(s, p);
176 176
177 fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;"); 177 fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
178 - s = (jstring)env->GetObjectField(model_config, fid); 178 + s = (jstring)env->GetObjectField(transducer_config, fid);
179 p = env->GetStringUTFChars(s, nullptr); 179 p = env->GetStringUTFChars(s, nullptr);
180 - ans.model_config.joiner_filename = p; 180 + ans.model_config.transducer.joiner = p;
181 env->ReleaseStringUTFChars(s, p); 181 env->ReleaseStringUTFChars(s, p);
182 182
183 fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); 183 fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
184 - s = (jstring)env->GetObjectField(model_config, fid); 184 + s = (jstring)env->GetObjectField(transducer_config, fid);
185 p = env->GetStringUTFChars(s, nullptr); 185 p = env->GetStringUTFChars(s, nullptr);
186 ans.model_config.tokens = p; 186 ans.model_config.tokens = p;
187 env->ReleaseStringUTFChars(s, p); 187 env->ReleaseStringUTFChars(s, p);
188 188
189 fid = env->GetFieldID(model_config_cls, "numThreads", "I"); 189 fid = env->GetFieldID(model_config_cls, "numThreads", "I");
190 - ans.model_config.num_threads = env->GetIntField(model_config, fid); 190 + ans.model_config.num_threads = env->GetIntField(transducer_config, fid);
191 191
192 fid = env->GetFieldID(model_config_cls, "debug", "Z"); 192 fid = env->GetFieldID(model_config_cls, "debug", "Z");
193 - ans.model_config.debug = env->GetBooleanField(model_config, fid); 193 + ans.model_config.debug = env->GetBooleanField(transducer_config, fid);
194 194
195 fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); 195 fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
196 - s = (jstring)env->GetObjectField(model_config, fid); 196 + s = (jstring)env->GetObjectField(transducer_config, fid);
197 p = env->GetStringUTFChars(s, nullptr); 197 p = env->GetStringUTFChars(s, nullptr);
198 ans.model_config.provider = p; 198 ans.model_config.provider = p;
199 env->ReleaseStringUTFChars(s, p); 199 env->ReleaseStringUTFChars(s, p);
200 200
201 fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); 201 fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
202 - s = (jstring)env->GetObjectField(model_config, fid); 202 + s = (jstring)env->GetObjectField(transducer_config, fid);
203 p = env->GetStringUTFChars(s, nullptr); 203 p = env->GetStringUTFChars(s, nullptr);
204 ans.model_config.model_type = p; 204 ans.model_config.model_type = p;
205 env->ReleaseStringUTFChars(s, p); 205 env->ReleaseStringUTFChars(s, p);
@@ -328,7 +328,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens( @@ -328,7 +328,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
328 jobjectArray result = env->NewObjectArray(size, stringClass, NULL); 328 jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
329 for (int i = 0; i < size; i++) { 329 for (int i = 0; i < size; i++) {
330 // Convert the C++ string to a C string 330 // Convert the C++ string to a C string
331 - const char* cstr = tokens[i].c_str(); 331 + const char *cstr = tokens[i].c_str();
332 332
333 // Convert the C string to a jstring 333 // Convert the C string to a jstring
334 jstring jstr = env->NewStringUTF(cstr); 334 jstring jstr = env->NewStringUTF(cstr);
@@ -13,6 +13,7 @@ pybind11_add_module(_sherpa_onnx @@ -13,6 +13,7 @@ pybind11_add_module(_sherpa_onnx
13 offline-transducer-model-config.cc 13 offline-transducer-model-config.cc
14 offline-whisper-model-config.cc 14 offline-whisper-model-config.cc
15 online-lm-config.cc 15 online-lm-config.cc
  16 + online-model-config.cc
16 online-recognizer.cc 17 online-recognizer.cc
17 online-stream.cc 18 online-stream.cc
18 online-transducer-model-config.cc 19 online-transducer-model-config.cc
  1 +// sherpa-onnx/python/csrc/online-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-model-config.h"
  11 +#include "sherpa-onnx/csrc/online-transducer-model-config.h"
  12 +#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +void PybindOnlineModelConfig(py::module *m) {
  17 + PybindOnlineTransducerModelConfig(m);
  18 +
  19 + using PyClass = OnlineModelConfig;
  20 + py::class_<PyClass>(*m, "OnlineModelConfig")
  21 + .def(py::init<const OnlineTransducerModelConfig &, const std::string &,
  22 + int32_t, bool, const std::string &, const std::string &>(),
  23 + py::arg("transducer") = OnlineTransducerModelConfig(),
  24 + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
  25 + py::arg("provider") = "cpu", py::arg("model_type") = "")
  26 + .def_readwrite("transducer", &PyClass::transducer)
  27 + .def_readwrite("tokens", &PyClass::tokens)
  28 + .def_readwrite("num_threads", &PyClass::num_threads)
  29 + .def_readwrite("debug", &PyClass::debug)
  30 + .def_readwrite("provider", &PyClass::provider)
  31 + .def_readwrite("model_type", &PyClass::model_type)
  32 + .def("__str__", &PyClass::ToString);
  33 +}
  34 +
  35 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-model-config.h
  2 +//
  3 +// Copyright (c) 2023 by manyeyes
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineModelConfig(py::module *m);
  13 +
  14 +}  // namespace sherpa_onnx
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_
@@ -27,10 +27,9 @@ static void PybindOnlineRecognizerResult(py::module *m) { @@ -27,10 +27,9 @@ static void PybindOnlineRecognizerResult(py::module *m) {
27 static void PybindOnlineRecognizerConfig(py::module *m) { 27 static void PybindOnlineRecognizerConfig(py::module *m) {
28 using PyClass = OnlineRecognizerConfig; 28 using PyClass = OnlineRecognizerConfig;
29 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 29 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
30 - .def(py::init<const FeatureExtractorConfig &,  
31 - const OnlineTransducerModelConfig &, const OnlineLMConfig &,  
32 - const EndpointConfig &, bool, const std::string &, int32_t,  
33 - float>(), 30 + .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
  31 + const OnlineLMConfig &, const EndpointConfig &, bool,
  32 + const std::string &, int32_t, float>(),
34 py::arg("feat_config"), py::arg("model_config"), 33 py::arg("feat_config"), py::arg("model_config"),
35 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), 34 py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
36 py::arg("enable_endpoint"), py::arg("decoding_method"), 35 py::arg("enable_endpoint"), py::arg("decoding_method"),
@@ -14,20 +14,11 @@ void PybindOnlineTransducerModelConfig(py::module *m) { @@ -14,20 +14,11 @@ void PybindOnlineTransducerModelConfig(py::module *m) {
14 using PyClass = OnlineTransducerModelConfig; 14 using PyClass = OnlineTransducerModelConfig;
15 py::class_<PyClass>(*m, "OnlineTransducerModelConfig") 15 py::class_<PyClass>(*m, "OnlineTransducerModelConfig")
16 .def(py::init<const std::string &, const std::string &, 16 .def(py::init<const std::string &, const std::string &,
17 - const std::string &, const std::string &, int32_t, bool,  
18 - const std::string &, const std::string &>(),  
19 - py::arg("encoder_filename"), py::arg("decoder_filename"),  
20 - py::arg("joiner_filename"), py::arg("tokens"),  
21 - py::arg("num_threads"), py::arg("debug") = false,  
22 - py::arg("provider") = "cpu", py::arg("model_type") = "")  
23 - .def_readwrite("encoder_filename", &PyClass::encoder_filename)  
24 - .def_readwrite("decoder_filename", &PyClass::decoder_filename)  
25 - .def_readwrite("joiner_filename", &PyClass::joiner_filename)  
26 - .def_readwrite("tokens", &PyClass::tokens)  
27 - .def_readwrite("num_threads", &PyClass::num_threads)  
28 - .def_readwrite("debug", &PyClass::debug)  
29 - .def_readwrite("provider", &PyClass::provider)  
30 - .def_readwrite("model_type", &PyClass::model_type) 17 + const std::string &>(),
  18 + py::arg("encoder"), py::arg("decoder"), py::arg("joiner"))
  19 + .def_readwrite("encoder", &PyClass::encoder)
  20 + .def_readwrite("decoder", &PyClass::decoder)
  21 + .def_readwrite("joiner", &PyClass::joiner)
31 .def("__str__", &PyClass::ToString); 22 .def("__str__", &PyClass::ToString);
32 } 23 }
33 24
@@ -12,9 +12,9 @@ @@ -12,9 +12,9 @@
12 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 12 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
13 #include "sherpa-onnx/python/csrc/offline-stream.h" 13 #include "sherpa-onnx/python/csrc/offline-stream.h"
14 #include "sherpa-onnx/python/csrc/online-lm-config.h" 14 #include "sherpa-onnx/python/csrc/online-lm-config.h"
  15 +#include "sherpa-onnx/python/csrc/online-model-config.h"
15 #include "sherpa-onnx/python/csrc/online-recognizer.h" 16 #include "sherpa-onnx/python/csrc/online-recognizer.h"
16 #include "sherpa-onnx/python/csrc/online-stream.h" 17 #include "sherpa-onnx/python/csrc/online-stream.h"
17 -#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"  
18 18
19 namespace sherpa_onnx { 19 namespace sherpa_onnx {
20 20
@@ -22,7 +22,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -22,7 +22,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
22 m.doc() = "pybind11 binding of sherpa-onnx"; 22 m.doc() = "pybind11 binding of sherpa-onnx";
23 23
24 PybindFeatures(&m); 24 PybindFeatures(&m);
25 - PybindOnlineTransducerModelConfig(&m); 25 + PybindOnlineModelConfig(&m);
26 PybindOnlineLMConfig(&m); 26 PybindOnlineLMConfig(&m);
27 PybindOnlineStream(&m); 27 PybindOnlineStream(&m);
28 PybindEndpoint(&m); 28 PybindEndpoint(&m);
@@ -5,6 +5,7 @@ from typing import List, Optional @@ -5,6 +5,7 @@ from typing import List, Optional
5 from _sherpa_onnx import ( 5 from _sherpa_onnx import (
6 EndpointConfig, 6 EndpointConfig,
7 FeatureExtractorConfig, 7 FeatureExtractorConfig,
  8 + OnlineModelConfig,
8 OnlineRecognizer as _Recognizer, 9 OnlineRecognizer as _Recognizer,
9 OnlineRecognizerConfig, 10 OnlineRecognizerConfig,
10 OnlineStream, 11 OnlineStream,
@@ -24,8 +25,9 @@ class OnlineRecognizer(object): @@ -24,8 +25,9 @@ class OnlineRecognizer(object):
24 - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py 25 - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py
25 """ 26 """
26 27
27 - def __init__(  
28 - self, 28 + @classmethod
  29 + def from_transducer(
  30 + cls,
29 tokens: str, 31 tokens: str,
30 encoder: str, 32 encoder: str,
31 decoder: str, 33 decoder: str,
@@ -95,6 +97,7 @@ class OnlineRecognizer(object): @@ -95,6 +97,7 @@ class OnlineRecognizer(object):
95 Online transducer model type. Valid values are: conformer, lstm, 97 Online transducer model type. Valid values are: conformer, lstm,
96 zipformer, zipformer2. All other values lead to loading the model twice. 98 zipformer, zipformer2. All other values lead to loading the model twice.
97 """ 99 """
  100 + self = cls.__new__(cls)
98 _assert_file_exists(tokens) 101 _assert_file_exists(tokens)
99 _assert_file_exists(encoder) 102 _assert_file_exists(encoder)
100 _assert_file_exists(decoder) 103 _assert_file_exists(decoder)
@@ -102,10 +105,14 @@ class OnlineRecognizer(object): @@ -102,10 +105,14 @@ class OnlineRecognizer(object):
102 105
103 assert num_threads > 0, num_threads 106 assert num_threads > 0, num_threads
104 107
105 - model_config = OnlineTransducerModelConfig(  
106 - encoder_filename=encoder,  
107 - decoder_filename=decoder,  
108 - joiner_filename=joiner, 108 + transducer_config = OnlineTransducerModelConfig(
  109 + encoder=encoder,
  110 + decoder=decoder,
  111 + joiner=joiner,
  112 + )
  113 +
  114 + model_config = OnlineModelConfig(
  115 + transducer=transducer_config,
109 tokens=tokens, 116 tokens=tokens,
110 num_threads=num_threads, 117 num_threads=num_threads,
111 provider=provider, 118 provider=provider,
@@ -135,6 +142,7 @@ class OnlineRecognizer(object): @@ -135,6 +142,7 @@ class OnlineRecognizer(object):
135 142
136 self.recognizer = _Recognizer(recognizer_config) 143 self.recognizer = _Recognizer(recognizer_config)
137 self.config = recognizer_config 144 self.config = recognizer_config
  145 + return self
138 146
139 def create_stream(self, contexts_list: Optional[List[List[int]]] = None): 147 def create_stream(self, contexts_list: Optional[List[List[int]]] = None):
140 if contexts_list is None: 148 if contexts_list is None:
@@ -65,7 +65,7 @@ class TestOnlineRecognizer(unittest.TestCase): @@ -65,7 +65,7 @@ class TestOnlineRecognizer(unittest.TestCase):
65 return 65 return
66 66
67 for decoding_method in ["greedy_search", "modified_beam_search"]: 67 for decoding_method in ["greedy_search", "modified_beam_search"]:
68 - recognizer = sherpa_onnx.OnlineRecognizer( 68 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
69 encoder=encoder, 69 encoder=encoder,
70 decoder=decoder, 70 decoder=decoder,
71 joiner=joiner, 71 joiner=joiner,
@@ -109,7 +109,7 @@ class TestOnlineRecognizer(unittest.TestCase): @@ -109,7 +109,7 @@ class TestOnlineRecognizer(unittest.TestCase):
109 return 109 return
110 110
111 for decoding_method in ["greedy_search", "modified_beam_search"]: 111 for decoding_method in ["greedy_search", "modified_beam_search"]:
112 - recognizer = sherpa_onnx.OnlineRecognizer( 112 + recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
113 encoder=encoder, 113 encoder=encoder,
114 decoder=decoder, 114 decoder=decoder,
115 joiner=joiner, 115 joiner=joiner,
@@ -14,19 +14,13 @@ import _sherpa_onnx @@ -14,19 +14,13 @@ import _sherpa_onnx
14 class TestOnlineTransducerModelConfig(unittest.TestCase): 14 class TestOnlineTransducerModelConfig(unittest.TestCase):
15 def test_constructor(self): 15 def test_constructor(self):
16 config = _sherpa_onnx.OnlineTransducerModelConfig( 16 config = _sherpa_onnx.OnlineTransducerModelConfig(
17 - encoder_filename="encoder.onnx",  
18 - decoder_filename="decoder.onnx",  
19 - joiner_filename="joiner.onnx",  
20 - tokens="tokens.txt",  
21 - num_threads=8,  
22 - debug=True, 17 + encoder="encoder.onnx",
  18 + decoder="decoder.onnx",
  19 + joiner="joiner.onnx",
23 ) 20 )
24 - assert config.encoder_filename == "encoder.onnx", config.encoder_filename  
25 - assert config.decoder_filename == "decoder.onnx", config.decoder_filename  
26 - assert config.joiner_filename == "joiner.onnx", config.joiner_filename  
27 - assert config.tokens == "tokens.txt", config.tokens  
28 - assert config.num_threads == 8, config.num_threads  
29 - assert config.debug is True, config.debug 21 + assert config.encoder == "encoder.onnx", config.encoder
  22 + assert config.decoder == "decoder.onnx", config.decoder
  23 + assert config.joiner == "joiner.onnx", config.joiner
30 print(config) 24 print(config)
31 25
32 26