Adding warm up for Zipformer2 (#766)

Author: Manix
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
@@ -21,6 +21,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("num-threads", &num_threads,
                "Number of threads to run the neural network");

+  po->Register("warm-up", &warm_up,
+               "Number of warm-up iterations to run for onnxruntime. "
+               "Currently supported only for model_type: zipformer2");
+
   po->Register("debug", &debug,
                "true to print model information while loading it.");

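A quick sketch of how the new flag reaches the config through ParseOptions. This is not part of the commit; the usage string, binary name, and includes are assumptions based on how sherpa-onnx binaries typically wire up `Register()`:

```cpp
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  // e.g. ./some-binary --warm-up=5 --num-threads=2 ...
  sherpa_onnx::ParseOptions po("Usage: pass --warm-up=N (zipformer2 only)");
  sherpa_onnx::OnlineModelConfig model_config;
  model_config.Register(&po);  // registers --warm-up along with the rest
  po.Read(argc, argv);
  // model_config.warm_up now holds the parsed value (default 0).
  return 0;
}
```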
@@ -70,6 +74,7 @@ std::string OnlineModelConfig::ToString() const {
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
+  os << "warm_up=" << warm_up << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
   os << "provider=\"" << provider << "\", ";
   os << "model_type=\"" << model_type << "\")";
@@ -20,6 +20,7 @@ struct OnlineModelConfig {
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
   std::string tokens;
   int32_t num_threads = 1;
+  int32_t warm_up = 0;
   bool debug = false;
   std::string provider = "cpu";

@@ -38,14 +39,17 @@ struct OnlineModelConfig {
                     const OnlineParaformerModelConfig &paraformer,
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
-                    const std::string &tokens, int32_t num_threads, bool debug,
-                    const std::string &provider, const std::string &model_type)
+                    const std::string &tokens, int32_t num_threads,
+                    int32_t warm_up, bool debug,
+                    const std::string &provider,
+                    const std::string &model_type)
       : transducer(transducer),
         paraformer(paraformer),
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
         tokens(tokens),
         num_threads(num_threads),
+        warm_up(warm_up),
         debug(debug),
         provider(provider),
         model_type(model_type) {}
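Since the new `warm_up` parameter lands between `num_threads` and `debug`, existing call sites that use this constructor positionally must be updated. A hypothetical call site (file names and values are placeholders):

```cpp
sherpa_onnx::OnlineTransducerModelConfig transducer;
sherpa_onnx::OnlineParaformerModelConfig paraformer;
sherpa_onnx::OnlineWenetCtcModelConfig wenet_ctc;
sherpa_onnx::OnlineZipformer2CtcModelConfig zipformer2_ctc;

sherpa_onnx::OnlineModelConfig model_config(
    transducer, paraformer, wenet_ctc, zipformer2_ctc,
    /*tokens=*/"tokens.txt", /*num_threads=*/2,
    /*warm_up=*/5,  // new argument; 0 (the default) disables warm-up
    /*debug=*/false, /*provider=*/"cpu", /*model_type=*/"zipformer2");
```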
@@ -37,6 +37,12 @@ class OnlineRecognizerImpl {

   virtual bool IsReady(OnlineStream *s) const = 0;

+  virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+    // TODO: extend warm-up support to other models.
+    SHERPA_ONNX_LOGE("Only the zipformer2 model supports warm-up for now.");
+    exit(-1);
+  }
+
   virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;

   virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
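The base-class default is a hard failure (it logs and exits the process), so callers should gate on the model type before dispatching. A minimal guard, assumed here and mirroring the check the websocket server adds below:

```cpp
#include <string>

// In this commit only the online transducer implementation overrides
// WarmpUpRecognizer; the server additionally restricts it to zipformer2.
void MaybeWarmUp(sherpa_onnx::OnlineRecognizerImpl *impl,
                 const std::string &model_type, int32_t warmup, int32_t mbs) {
  if (model_type == "zipformer2") {
    impl->WarmpUpRecognizer(warmup, mbs);
  }
  // Otherwise skip: the default implementation above would exit(-1).
}
```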
@@ -32,6 +32,7 @@
 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 #include "sherpa-onnx/csrc/utils.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"

 namespace sherpa_onnx {

@@ -183,6 +184,41 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
            s->NumFramesReady();
   }

+  // Warm up the engine with the given warm-up count and max batch size (mbs).
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+    auto max_batch_size = mbs;
+    if (warmup <= 0 || warmup > 100) {
+      return;
+    }
+    int32_t chunk_size = model_->ChunkSize();
+    int32_t chunk_shift = model_->ChunkShift();
+    int32_t feature_dim = 80;
+    std::vector<OnlineTransducerDecoderResult> results(max_batch_size);
+    std::vector<float> features_vec(max_batch_size * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(max_batch_size);
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{max_batch_size, chunk_size, feature_dim};
+
+    for (int32_t i = 0; i != max_batch_size; ++i) {
+      states_vec[i] = model_->GetEncoderInitStates();
+      results[i] = decoder_->GetEmptyResult();
+    }
+
+    for (int32_t i = 0; i != warmup; ++i) {
+      auto states = model_->StackStates(states_vec);
+      Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                              features_vec.size(),
+                                              x_shape.data(), x_shape.size());
+      auto x_copy = Clone(model_->Allocator(), &x);
+      auto pair = model_->RunEncoder(std::move(x), std::move(states),
+                                     std::move(x_copy));
+      decoder_->Decode(std::move(pair.first), &results);
+    }
+  }
+
   void DecodeStreams(OnlineStream **ss, int32_t n) const override {
     int32_t chunk_size = model_->ChunkSize();
     int32_t chunk_shift = model_->ChunkShift();
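For scale: the warm-up input is a zero-filled (silence-like) batch, since `std::vector<float>` value-initializes its elements. A back-of-envelope under an assumed zipformer2 streaming setup (the chunk size below is an assumption; at runtime it comes from the model metadata via `ChunkSize()`):

```cpp
#include <cstdint>
#include <vector>

int32_t mbs = 2;           // max batch size passed to WarmpUpRecognizer
int32_t chunk_size = 45;   // assumed; read from the model at runtime
int32_t feature_dim = 80;  // hard-coded fbank dimension in the code above
// 2 * 45 * 80 = 7200 floats, i.e. roughly 28 KiB of zeros per iteration
std::vector<float> features_vec(mbs * chunk_size * feature_dim);
```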
@@ -171,6 +171,12 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }

+void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+  if (warmup > 0) {
+    impl_->WarmpUpRecognizer(warmup, mbs);
+  }
+}
+
 void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
   impl_->DecodeStreams(ss, n);
 }
@@ -162,6 +162,15 @@ class OnlineRecognizer {
     DecodeStreams(ss, 1);
   }

+  /** Warm up the onnxruntime sessions by running dummy batches, so that
+   * graph optimization and memory allocation happen before the first
+   * real decoding call.
+   *
+   * @param warmup Number of warm-up iterations to run.
+   * @param mbs Max batch size to use for the warm-up batches.
+   */
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
+
   /** Decode multiple streams in parallel
    *
    * @param ss Pointer array containing streams to be decoded.
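A hedged end-to-end sketch of the new public API (the model file names are placeholders; the config field names follow the hunks above):

```cpp
#include "sherpa-onnx/csrc/online-recognizer.h"

int main() {
  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.transducer.encoder = "encoder.onnx";
  config.model_config.transducer.decoder = "decoder.onnx";
  config.model_config.transducer.joiner = "joiner.onnx";
  config.model_config.tokens = "tokens.txt";
  config.model_config.model_type = "zipformer2";

  sherpa_onnx::OnlineRecognizer recognizer(config);
  // Run warm-up once, before the first real stream is decoded.
  recognizer.WarmpUpRecognizer(/*warmup=*/5, /*mbs=*/2);
  return 0;
}
```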
@@ -95,6 +95,11 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
   c->eof = true;
 }

+void OnlineWebsocketDecoder::Warmup() const {
+  recognizer_->WarmpUpRecognizer(config_.recognizer_config.model_config.warm_up,
+                                 config_.max_batch_size);
+}
+
 void OnlineWebsocketDecoder::Run() {
   timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));

@@ -242,6 +247,24 @@ void OnlineWebsocketServer::Run(uint16_t port) {
   server_.set_reuse_addr(true);
   server_.listen(asio::ip::tcp::v4(), port);
   server_.start_accept();
+  auto recognizer_config = config_.decoder_config.recognizer_config;
+  int32_t warm_up = recognizer_config.model_config.warm_up;
+  const std::string &model_type = recognizer_config.model_config.model_type;
+  if (0 < warm_up && warm_up < 100) {
+    if (model_type == "zipformer2") {
+      decoder_.Warmup();
+      SHERPA_ONNX_LOGE("Warm-up completed: %d times.", warm_up);
+    } else {
+      SHERPA_ONNX_LOGE("Only zipformer2 supports warm-up for now.");
+      SHERPA_ONNX_LOGE("Given: %s", model_type.c_str());
+      exit(0);
+    }
+  } else if (warm_up == 0) {
+    SHERPA_ONNX_LOGE("Starting without warm-up!");
+  } else {
+    SHERPA_ONNX_LOGE("Invalid warm-up value! Expected 0 < warm_up < 100.");
+    exit(0);
+  }
   decoder_.Run();
 }

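The gate above accepts `warm_up` in [1, 99] (zipformer2 only), treats 0 as "skip warm-up", and exits on anything else; note that the transducer implementation itself separately ignores values above 100. A self-contained restatement of the accept condition, useful for sanity-checking configs (an assumption, not code from this commit):

```cpp
#include <cstdint>
#include <string>

// True iff the server would actually run the warm-up pass.
bool ShouldWarmUp(int32_t warm_up, const std::string &model_type) {
  return 0 < warm_up && warm_up < 100 && model_type == "zipformer2";
}
```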
@@ -85,6 +85,8 @@ class OnlineWebsocketDecoder {
   // signal that there will be no more audio samples for a stream
   void InputFinished(std::shared_ptr<Connection> c);

+  void Warmup() const;
+
   void Run();

 private:
@@ -27,14 +27,16 @@ void PybindOnlineModelConfig(py::module *m) {
       .def(py::init<const OnlineTransducerModelConfig &,
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
-                    const OnlineZipformer2CtcModelConfig &, const std::string &,
-                    int32_t, bool, const std::string &, const std::string &>(),
+                    const OnlineZipformer2CtcModelConfig &,
+                    const std::string &, int32_t, int32_t,
+                    bool, const std::string &, const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
-          py::arg("provider") = "cpu", py::arg("model_type") = "")
+          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("debug") = false, py::arg("provider") = "cpu",
+          py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)