Fangjun Kuang
Committed by GitHub

Reduce model initialization time for offline speech recognition (#213)

... ... @@ -387,6 +387,7 @@ void CNonStreamingSpeechRecognitionDlg::InitParaformer() {
config_.model_config.tokens = tokens.c_str();
config_.model_config.num_threads = 1;
config_.model_config.debug = 1;
config_.model_config.model_type = "paraformer";
config_.decoding_method = "greedy_search";
config_.max_active_paths = 4;
... ... @@ -447,6 +448,7 @@ void CNonStreamingSpeechRecognitionDlg::InitRecognizer() {
config_.model_config.tokens = tokens.c_str();
config_.model_config.num_threads = 1;
config_.model_config.debug = 0;
config_.model_config.model_type = "transducer";
config_.decoding_method = "greedy_search";
config_.max_active_paths = 4;
... ...
... ... @@ -76,6 +76,8 @@ namespace SherpaOnnx
Tokens = "";
NumThreads = 1;
Debug = 0;
Provider = "cpu";
ModelType = "";
}
public OfflineTransducerModelConfig Transducer;
public OfflineParaformerModelConfig Paraformer;
... ... @@ -87,6 +89,12 @@ namespace SherpaOnnx
public int NumThreads;
public int Debug;
[MarshalAs(UnmanagedType.LPStr)]
public string Provider;
[MarshalAs(UnmanagedType.LPStr)]
public string ModelType;
}
[StructLayout(LayoutKind.Sequential)]
... ...
... ... @@ -33,23 +33,33 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
const SherpaOnnxOnlineRecognizerConfig *config) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config;
recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.feat_config.sampling_rate =
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
recognizer_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.model_config.encoder_filename =
SHERPA_ONNX_OR(config->model_config.encoder, "");
recognizer_config.model_config.decoder_filename =
SHERPA_ONNX_OR(config->model_config.decoder, "");
recognizer_config.model_config.joiner_filename = SHERPA_ONNX_OR(config->model_config.joiner, "");
recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
recognizer_config.enable_endpoint = SHERPA_ONNX_OR(config->enable_endpoint, 0);
recognizer_config.model_config.joiner_filename =
SHERPA_ONNX_OR(config->model_config.joiner, "");
recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.decoding_method =
SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
recognizer_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);
recognizer_config.enable_endpoint =
SHERPA_ONNX_OR(config->enable_endpoint, 0);
recognizer_config.endpoint_config.rule1.min_trailing_silence =
SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4);
... ... @@ -173,9 +183,11 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
const SherpaOnnxOfflineRecognizerConfig *config) {
sherpa_onnx::OfflineRecognizerConfig recognizer_config;
recognizer_config.feat_config.sampling_rate = SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
recognizer_config.feat_config.sampling_rate =
SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
recognizer_config.feat_config.feature_dim = SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.feat_config.feature_dim =
SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
recognizer_config.model_config.transducer.encoder_filename =
SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
... ... @@ -184,7 +196,7 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
recognizer_config.model_config.transducer.joiner_filename =
SHERPA_ONNX_OR(config->model_config.transducer.joiner,"");
SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
recognizer_config.model_config.paraformer.model =
SHERPA_ONNX_OR(config->model_config.paraformer.model, "");
... ... @@ -192,15 +204,26 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.model_config.nemo_ctc.model =
SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, "");
recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.debug = SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.lm_config.model = SHERPA_ONNX_OR(config->lm_config.model, "");
recognizer_config.lm_config.scale = SHERPA_ONNX_OR(config->lm_config.scale, 1.0);
recognizer_config.decoding_method = SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
recognizer_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);
recognizer_config.model_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
recognizer_config.lm_config.model =
SHERPA_ONNX_OR(config->lm_config.model, "");
recognizer_config.lm_config.scale =
SHERPA_ONNX_OR(config->lm_config.scale, 1.0);
recognizer_config.decoding_method =
SHERPA_ONNX_OR(config->decoding_method, "greedy_search");
recognizer_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);
if (config->model_config.debug) {
fprintf(stderr, "%s\n", recognizer_config.ToString().c_str());
... ...
... ... @@ -272,6 +272,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineModelConfig {
const char *tokens;
int32_t num_threads;
int32_t debug;
const char *provider;
const char *model_type;
} SherpaOnnxOfflineModelConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {
... ...
... ... @@ -25,6 +25,11 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc. "
"All other values lead to loading the model twice.");
}
bool OfflineModelConfig::Validate() const {
... ... @@ -34,7 +39,7 @@ bool OfflineModelConfig::Validate() const {
}
if (!FileExists(tokens)) {
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
SHERPA_ONNX_LOGE("tokens: %s does not exist", tokens.c_str());
return false;
}
... ... @@ -59,7 +64,8 @@ std::string OfflineModelConfig::ToString() const {
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
return os.str();
}
... ...
... ... @@ -22,19 +22,31 @@ struct OfflineModelConfig {
bool debug = false;
std::string provider = "cpu";
// With the help of this field, we only need to load the model once
// instead of twice; and therefore it reduces initialization time.
//
// Valid values:
// - transducer. The given model is from icefall
// - paraformer. It is a paraformer model
// - nemo_ctc. It is a NeMo CTC model.
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
OfflineModelConfig() = default;
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer,
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider)
const std::string &provider, const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
nemo_ctc(nemo_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
provider(provider) {}
provider(provider),
model_type(model_type) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -18,6 +18,21 @@ namespace sherpa_onnx {
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const OfflineRecognizerConfig &config) {
if (!config.model_config.model_type.empty()) {
const auto &model_type = config.model_config.model_type;
if (model_type == "transducer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else {
SHERPA_ONNX_LOGE(
"Invalid model_type: %s. Trying to load the model to get its type",
model_type.c_str());
}
}
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
Ort::SessionOptions sess_opts;
... ...
... ... @@ -18,17 +18,17 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
bool OfflineTransducerModelConfig::Validate() const {
if (!FileExists(encoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str());
return false;
}
if (!FileExists(decoder_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str());
return false;
}
if (!FileExists(joiner_filename)) {
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str());
return false;
}
... ...
... ... @@ -21,15 +21,16 @@ void PybindOfflineModelConfig(py::module *m) {
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const std::string &, int32_t, bool, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu")
.def(
py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
... ... @@ -37,6 +38,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def_readwrite("model_type", &PyClass::model_type)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -86,6 +86,7 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="transducer",
)
feat_config = OfflineFeatureExtractorConfig(
... ... @@ -149,6 +150,7 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="paraformer",
)
feat_config = OfflineFeatureExtractorConfig(
... ... @@ -211,6 +213,7 @@ class OfflineRecognizer(object):
num_threads=num_threads,
debug=debug,
provider=provider,
model_type="nemo_ctc",
)
feat_config = OfflineFeatureExtractorConfig(
... ...