Manix
Committed by GitHub

Adding warm up for Zipformer2 (#766)

Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
... ... @@ -21,6 +21,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("warm-up", &warm_up,
"Number of warm-up to run the onnxruntime"
"Valid vales are: zipformer2");
po->Register("debug", &debug,
"true to print model information while loading it.");
... ... @@ -70,6 +74,7 @@ std::string OnlineModelConfig::ToString() const {
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "warm_up=" << warm_up << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\")";
... ...
... ... @@ -20,6 +20,7 @@ struct OnlineModelConfig {
OnlineZipformer2CtcModelConfig zipformer2_ctc;
std::string tokens;
int32_t num_threads = 1;
int32_t warm_up = 0;
bool debug = false;
std::string provider = "cpu";
... ... @@ -38,14 +39,17 @@ struct OnlineModelConfig {
const OnlineParaformerModelConfig &paraformer,
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug,
const std::string &provider,
const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
tokens(tokens),
num_threads(num_threads),
warm_up(warm_up),
debug(debug),
provider(provider),
model_type(model_type) {}
... ...
... ... @@ -37,6 +37,12 @@ class OnlineRecognizerImpl {
virtual bool IsReady(OnlineStream *s) const = 0;
// Default warm-up fallback for model types without a warm-up
// implementation. Only the zipformer2 transducer implementation
// overrides this; reaching this base version logs an error and
// terminates the process with exit(-1).
// @param warmup Requested number of warm-up iterations (unused here).
// @param mbs Requested max batch size (unused here).
virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
// ToDo extending to other models
SHERPA_ONNX_LOGE("Only zipformer2 model supports Warm up for now.");
exit(-1);
}
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
... ...
... ... @@ -32,6 +32,7 @@
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/utils.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
... ... @@ -183,6 +184,41 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->NumFramesReady();
}
// Warmping up engine with wp: warm_up count and max-batch-size
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
auto max_batch_size = mbs;
if (warmup <= 0 || warmup > 100) {
return;
}
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
int32_t feature_dim = 80;
std::vector<OnlineTransducerDecoderResult> results(max_batch_size);
std::vector<float> features_vec(max_batch_size * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(max_batch_size);
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{max_batch_size, chunk_size, feature_dim};
for (int32_t i = 0; i != max_batch_size; ++i) {
states_vec[i] = model_->GetEncoderInitStates();
results[i] = decoder_->GetEmptyResult();
}
for (int32_t i = 0; i != warmup; ++i) {
auto states = model_->StackStates(states_vec);
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
features_vec.size(), x_shape.data(),
x_shape.size());
auto x_copy = Clone(model_->Allocator(), &x);
auto pair = model_->RunEncoder(std::move(x), std::move(states),
std::move(x_copy));
decoder_->Decode(std::move(pair.first), &results);
}
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
... ...
... ... @@ -171,6 +171,12 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
// Forward the warm-up request to the underlying implementation.
// A non-positive warm-up count is treated as a no-op.
void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
  if (warmup <= 0) {
    return;
  }
  impl_->WarmpUpRecognizer(warmup, mbs);
}
// Decode a batch of streams by delegating to the implementation.
// @param ss Pointer array containing the streams to decode.
// @param n Number of entries in ss.
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
impl_->DecodeStreams(ss, n);
}
... ...
... ... @@ -162,6 +162,15 @@ class OnlineRecognizer {
DecodeStreams(ss, 1);
}
/**
* Warms up the onnxruntime sessions by applying graph optimizations and
* allocating memory ahead of the first real decode.
*
* @param warmup Number of warm-up iterations to run.
* @param mbs Maximum batch size (max-batch-size) to warm the models up with.
*/
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
/** Decode multiple streams in parallel
*
* @param ss Pointer array containing streams to be decoded.
... ...
... ... @@ -95,6 +95,11 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
c->eof = true;
}
// Run recognizer warm-up with the configured warm-up count and the
// server's maximum batch size.
void OnlineWebsocketDecoder::Warmup() const {
  const auto &model_config = config_.recognizer_config.model_config;
  recognizer_->WarmpUpRecognizer(model_config.warm_up, config_.max_batch_size);
}
void OnlineWebsocketDecoder::Run() {
timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
... ... @@ -242,6 +247,24 @@ void OnlineWebsocketServer::Run(uint16_t port) {
server_.set_reuse_addr(true);
server_.listen(asio::ip::tcp::v4(), port);
server_.start_accept();
auto recognizer_config = config_.decoder_config.recognizer_config;
int32_t warm_up = recognizer_config.model_config.warm_up;
const std::string &model_type = recognizer_config.model_config.model_type;
if (0 < warm_up && warm_up < 100) {
if (model_type == "zipformer2") {
decoder_.Warmup();
SHERPA_ONNX_LOGE("Warm up completed : %d times.", warm_up);
} else {
SHERPA_ONNX_LOGE("Only Zipformer2 has warmup support for now.");
SHERPA_ONNX_LOGE("Given: %s", model_type.c_str());
exit(0);
}
} else if (warm_up == 0) {
SHERPA_ONNX_LOGE("Starting without warmup!");
} else {
SHERPA_ONNX_LOGE("Invalid Warm up Value!. Expected 0 < warm_up < 100");
exit(0);
}
decoder_.Run();
}
... ...
... ... @@ -85,6 +85,8 @@ class OnlineWebsocketDecoder {
// signal that there will be no more audio samples for a stream
void InputFinished(std::shared_ptr<Connection> c);
void Warmup() const;
void Run();
private:
... ...
... ... @@ -27,14 +27,16 @@ void PybindOnlineModelConfig(py::module *m) {
.def(py::init<const OnlineTransducerModelConfig &,
const OnlineParaformerModelConfig &,
const OnlineWenetCtcModelConfig &,
const OnlineZipformer2CtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
const OnlineZipformer2CtcModelConfig &,
const std::string &, int32_t, int32_t,
bool, const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("debug") = false, py::arg("provider") = "cpu",
py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
... ...