Fangjun Kuang
Committed by GitHub

Support non-streaming WeNet CTC models. (#426)

... ... @@ -14,6 +14,47 @@ echo "PATH: $PATH"
which $EXE
log "------------------------------------------------------------"
log "Run Wenet models"
log "------------------------------------------------------------"
wenet_models=(
sherpa-onnx-zh-wenet-aishell
sherpa-onnx-zh-wenet-aishell2
sherpa-onnx-zh-wenet-wenetspeech
sherpa-onnx-zh-wenet-multi-cn
sherpa-onnx-en-wenet-librispeech
sherpa-onnx-en-wenet-gigaspeech
)
for name in ${wenet_models[@]}; do
repo_url=https://huggingface.co/csukuangfj/$name
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd
log "test float32 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
log "test int8 models"
time $EXE \
--tokens=$repo/tokens.txt \
--wenet-ctc-model=$repo/model.int8.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav
rm -rf $repo
done
log "------------------------------------------------------------"
log "Run tdnn yesno (Hebrew)"
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-tdnn-yesno
... ...
name: export-wenet-to-onnx
on:
push:
branches:
- master
paths:
- 'scripts/wenet/**'
- '.github/workflows/export-wenet-to-onnx.yaml'
pull_request:
paths:
- 'scripts/wenet/**'
- '.github/workflows/export-wenet-to-onnx.yaml'
workflow_dispatch:
concurrency:
... ...
... ... @@ -89,6 +89,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline TTS
shell: bash
run: |
... ... @@ -115,14 +123,6 @@ jobs:
.github/scripts/test-offline-whisper.sh
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer
shell: bash
run: |
... ...
... ... @@ -172,7 +172,7 @@ def main():
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
url = os.environ.get("WENET_URL", "")
meta_data = {
"model_type": "wenet-ctc",
"model_type": "wenet_ctc",
"version": "1",
"model_author": "wenet",
"comment": "streaming",
... ... @@ -185,6 +185,7 @@ def main():
"cnn_module_kernel": cnn_module_kernel,
"right_context": right_context,
"subsampling_factor": subsampling_factor,
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
}
add_meta_data(filename=filename, meta_data=meta_data)
... ...
... ... @@ -107,10 +107,12 @@ def main():
# https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
url = os.environ.get("WENET_URL", "")
meta_data = {
"model_type": "wenet-ctc",
"model_type": "wenet_ctc",
"version": "1",
"model_author": "wenet",
"comment": "non-streaming",
"subsampling_factor": torch_model.encoder.embed.subsampling_rate,
"vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
"url": url,
}
add_meta_data(filename=filename, meta_data=meta_data)
... ...
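Both the streaming and non-streaming export scripts hand this dictionary to add_meta_data, which stores every entry as a string in the ONNX file's metadata_props so the C++ runtime can identify the model later (see SHERPA_ONNX_READ_META_DATA further below). For readers who have not seen that helper, here is a minimal sketch of the usual approach, assuming the onnx Python package; the actual helper lives in scripts/wenet/ and may differ in details.

from typing import Dict

import onnx


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add key/value pairs to the metadata of an ONNX model, in place.

    Every value is stored as a string, which is why fields such as
    vocab_size are parsed back on the C++ side when the model is loaded.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    onnx.save(model, filename)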
... ... @@ -41,6 +41,8 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
offline-wenet-ctc-model-config.cc
offline-wenet-ctc-model.cc
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
... ...
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -21,10 +22,11 @@ enum class ModelType {
kEncDecCTCModelBPE,
kTdnn,
kZipformerCtc,
kWenetCtc,
kUnkown,
};
} // namespace
namespace sherpa_onnx {
... ... @@ -52,6 +54,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"If you are using models from NeMo, please refer to\n"
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
"If you are using models from WeNet, please refer to\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
... ... @@ -63,6 +68,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) {
return ModelType::kZipformerCtc;
} else if (model_type.get() == std::string("wenet_ctc")) {
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
... ... @@ -80,6 +87,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -101,6 +110,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ... @@ -122,6 +134,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else if (!config.wenet_ctc.model.empty()) {
filename = config.wenet_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -143,6 +157,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
break;
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ...
... ... @@ -63,6 +63,9 @@ class OfflineCtcModel {
* for the features.
*/
virtual std::string FeatureNormalizationMethod() const { return {}; }
// Return true if the model supports batch size > 1
virtual bool SupportBatchProcessing() const { return true; }
};
} // namespace sherpa_onnx
... ...
... ... @@ -17,6 +17,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
whisper.Register(po);
tdnn.Register(po);
zipformer_ctc.Register(po);
wenet_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -67,6 +68,10 @@ bool OfflineModelConfig::Validate() const {
return zipformer_ctc.Validate();
}
if (!wenet_ctc.model.empty()) {
return wenet_ctc.Validate();
}
return transducer.Validate();
}
... ... @@ -80,6 +85,7 @@ std::string OfflineModelConfig::ToString() const {
os << "whisper=" << whisper.ToString() << ", ";
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -10,6 +10,7 @@
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
... ... @@ -22,6 +23,7 @@ struct OfflineModelConfig {
OfflineWhisperModelConfig whisper;
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
OfflineWenetCtcModelConfig wenet_ctc;
std::string tokens;
int32_t num_threads = 2;
... ... @@ -46,6 +48,7 @@ struct OfflineModelConfig {
const OfflineWhisperModelConfig &whisper,
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const OfflineWenetCtcModelConfig &wenet_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
... ... @@ -54,6 +57,7 @@ struct OfflineModelConfig {
whisper(whisper),
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
wenet_ctc(wenet_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
... ...
... ... @@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif
void Init() {
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
... ... @@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
!symbol_table_.contains("<eps>") &&
!symbol_table_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> and its ID.");
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}
... ... @@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
} else if (symbol_table_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = symbol_table_["<eps>"];
} else if (symbol_table_.contains("<blank>")) {
// for WeNet CTC models
blank_id = symbol_table_["<blank>"];
}
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
... ... @@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
if (!model_->SupportBatchProcessing()) {
// If the model does not support batch processing,
// we process each stream independently.
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
return;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
... ... @@ -165,6 +184,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
void DecodeStream(OfflineStream *s) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = config_.feat_config.feature_dim;
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
int64_t x_length_scalar = num_frames;
std::array<int64_t, 1> x_length_shape = {1};
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
x_length_shape.data(), x_length_shape.size());
auto t = model_->Forward(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
int32_t frame_shift_ms = 10;
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
s->SetResult(r);
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineCtcModel> model_;
... ...
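The greedy-search path above selects a blank ID (<blk>, <eps>, or <blank>) and hands it to OfflineCtcGreedySearchDecoder. For context, CTC greedy search amounts to a per-frame argmax over the vocabulary followed by collapsing repeats and dropping the blank ID; a compact NumPy sketch of the idea (not the actual C++ implementation):

import numpy as np


def ctc_greedy_search(log_probs: np.ndarray, blank_id: int) -> list:
    """log_probs: (T, vocab_size) array of per-frame log-probabilities.

    Returns token IDs after removing repeats and blanks, which is what
    greedy-search CTC decoding produces for a single stream.
    """
    best = log_probs.argmax(axis=1)  # best token per frame, shape (T,)
    tokens = []
    prev = -1
    for t in best:
        if t != prev and t != blank_id:
            tokens.append(int(t))
        prev = t
    return tokens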
... ... @@ -26,7 +26,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
... ... @@ -51,6 +51,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.wenet_ctc.model.empty()) {
model_filename = config.model_config.wenet_ctc.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
... ... @@ -99,6 +101,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"zipformer/export-onnx-ctc.py"
"\n"
"(6) CTC models from WeNet"
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"\n");
exit(-1);
}
... ... @@ -114,7 +120,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
}
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
... ... @@ -130,7 +136,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n",
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
model_type.c_str());
exit(-1);
... ... @@ -146,7 +153,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
... ... @@ -171,6 +178,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.wenet_ctc.model.empty()) {
model_filename = config.model_config.wenet_ctc.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
... ... @@ -219,6 +228,10 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"zipformer/export-onnx-ctc.py"
"\n"
"(6) CTC models from WeNet"
"\n "
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh"
"\n"
"\n");
exit(-1);
}
... ... @@ -234,7 +247,7 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
}
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
model_type == "zipformer2_ctc" || model_type == "wenet_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
... ... @@ -250,7 +263,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n"
" - Zipformer CTC models\n",
" - Zipformer CTC models\n"
" - WeNet CTC models\n",
model_type.c_str());
exit(-1);
... ...
... ... @@ -40,7 +40,8 @@ struct OfflineFeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;
// Set internally by some models, e.g., paraformer sets it to false.
// Set internally by some models, e.g., paraformer and WeNet CTC models set
// it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
// the range [-1, 1].
... ...
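In practice the flag works as follows: when normalize_samples is false, a waveform loaded as float32 in [-1, 1] is scaled back to the int16 range before fbank computation, matching the [-32768, 32767] assumption noted in offline-recognizer-ctc-impl.cc above. A rough NumPy sketch of the idea (not the actual feature-extractor code):

import numpy as np


def prepare_samples(samples: np.ndarray, normalize_samples: bool) -> np.ndarray:
    """samples: float32 waveform in the range [-1, 1]."""
    if normalize_samples:
        # Models such as the icefall ones expect samples in [-1, 1].
        return samples
    # WeNet and Paraformer models expect samples in the int16 range
    # [-32768, 32767] before feature extraction.
    return samples * 32768.0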
// sherpa-onnx/csrc/offline-wenet-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineWenetCtcModelConfig::Register(ParseOptions *po) {
po->Register(
"wenet-ctc-model", &model,
"Path to model.onnx from WeNet. Please see "
"https://github.com/k2-fsa/sherpa-onnx/pull/425 for available models");
}
bool OfflineWenetCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("WeNet model: %s does not exist", model.c_str());
return false;
}
return true;
}
std::string OfflineWenetCtcModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineWenetCtcModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-wenet-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineWenetCtcModelConfig {
std::string model;
OfflineWenetCtcModelConfig() = default;
explicit OfflineWenetCtcModelConfig(const std::string &model)
: model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-wenet-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
class OfflineWenetCtcModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.wenet_ctc.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.wenet_ctc.model);
Init(buf.data(), buf.size());
}
#endif
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
std::move(features_length)};
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
}
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return subsampling_factor_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t vocab_size_ = 0;
int32_t subsampling_factor_ = 0;
};
OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineWenetCtcModel::OfflineWenetCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineWenetCtcModel::~OfflineWenetCtcModel() = default;
std::vector<Ort::Value> OfflineWenetCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OfflineWenetCtcModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
OrtAllocator *OfflineWenetCtcModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-wenet-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
/** This class implements the CTC model from WeNet.
*
* See
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/export-onnx.py
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/test-onnx.py
* https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/run.sh
*
*/
class OfflineWenetCtcModel : public OfflineCtcModel {
public:
explicit OfflineWenetCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineWenetCtcModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
~OfflineWenetCtcModel() override;
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
*
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length: A 1-D tensor of shape (N,). Its dtype is int64_t.
*/
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** Return the subsampling factor of the model.
*
* For the conformer CTC models exported by scripts/wenet/run.sh,
* it is usually 4.
*/
int32_t SubsamplingFactor() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// WeNet CTC models do not support batch size > 1
bool SupportBatchProcessing() const override { return false; }
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_
... ...
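The contract documented in Forward() can be exercised directly with onnxruntime. The sketch below feeds random 80-dim fbank-shaped features into an exported model just to illustrate the expected shapes and dtypes; input names are taken from the session rather than hard-coded, since they are not spelled out here.

import numpy as np
import onnxruntime as ort

# e.g., model.onnx from one of the sherpa-onnx-*-wenet-* repos above
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

N, T, C = 1, 200, 80  # batch, frames, feature dim (80-dim fbank)
features = np.random.randn(N, T, C).astype(np.float32)
features_length = np.array([T], dtype=np.int64)

input_names = [i.name for i in sess.get_inputs()]
log_probs, log_probs_length = sess.run(
    None,
    {input_names[0]: features, input_names[1]: features_length},
)
print(log_probs.shape)   # (N, T', vocab_size)
print(log_probs_length)  # number of valid frames in log_probs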
... ... @@ -17,6 +17,7 @@ pybind11_add_module(_sherpa_onnx
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
offline-tts.cc
offline-wenet-ctc-model-config.cc
offline-whisper-model-config.cc
offline-zipformer-ctc-model-config.cc
online-lm-config.cc
... ...
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
... ... @@ -24,6 +25,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineWhisperModelConfig(m);
PybindOfflineTdnnModelConfig(m);
PybindOfflineZipformerCtcModelConfig(m);
PybindOfflineWenetCtcModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
... ... @@ -32,7 +34,8 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &,
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &, const std::string &,
const OfflineZipformerCtcModelConfig &,
const OfflineWenetCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
... ... @@ -40,6 +43,7 @@ void PybindOfflineModelConfig(py::module *m) {
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
... ... @@ -48,6 +52,7 @@ void PybindOfflineModelConfig(py::module *m) {
.def_readwrite("whisper", &PyClass::whisper)
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
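With these bindings in place, the new field is reachable from Python. A minimal sketch, assuming the built extension is re-exported by the sherpa_onnx package:

import sherpa_onnx

wenet_ctc = sherpa_onnx.OfflineWenetCtcModelConfig(
    model="./sherpa-onnx-zh-wenet-wenetspeech/model.onnx"
)
model_config = sherpa_onnx.OfflineModelConfig(
    wenet_ctc=wenet_ctc,
    tokens="./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt",
    num_threads=2,
)
print(wenet_ctc)  # uses the __str__ binding defined in the file that follows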
// sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-wenet-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h"
namespace sherpa_onnx {
void PybindOfflineWenetCtcModelConfig(py::module *m) {
using PyClass = OfflineWenetCtcModelConfig;
py::class_<PyClass>(*m, "OfflineWenetCtcModelConfig")
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-wenet-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineWenetCtcModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_
... ...