Committed by GitHub
Adding warm up for Zipformer2 (#766)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com>
Showing 9 changed files with 99 additions and 6 deletions.
@@ -21,6 +21,10 @@ void OnlineModelConfig::Register(ParseOptions *po) {
   po->Register("num-threads", &num_threads,
                "Number of threads to run the neural network");
 
+  po->Register("warm-up", &warm_up,
+               "Number of warm-up runs for onnxruntime. "
+               "Valid only for model type: zipformer2");
+
   po->Register("debug", &debug,
                "true to print model information while loading it.");
 
@@ -70,6 +74,7 @@ std::string OnlineModelConfig::ToString() const {
   os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
   os << "tokens=\"" << tokens << "\", ";
   os << "num_threads=" << num_threads << ", ";
+  os << "warm_up=" << warm_up << ", ";
   os << "debug=" << (debug ? "True" : "False") << ", ";
   os << "provider=\"" << provider << "\", ";
   os << "model_type=\"" << model_type << "\")";
@@ -20,6 +20,7 @@ struct OnlineModelConfig {
   OnlineZipformer2CtcModelConfig zipformer2_ctc;
   std::string tokens;
   int32_t num_threads = 1;
+  int32_t warm_up = 0;
   bool debug = false;
   std::string provider = "cpu";
 
@@ -38,14 +39,17 @@ struct OnlineModelConfig {
                     const OnlineParaformerModelConfig &paraformer,
                     const OnlineWenetCtcModelConfig &wenet_ctc,
                     const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
-                    const std::string &tokens, int32_t num_threads, bool debug,
-                    const std::string &provider, const std::string &model_type)
+                    const std::string &tokens, int32_t num_threads,
+                    int32_t warm_up, bool debug,
+                    const std::string &provider,
+                    const std::string &model_type)
       : transducer(transducer),
         paraformer(paraformer),
         wenet_ctc(wenet_ctc),
         zipformer2_ctc(zipformer2_ctc),
         tokens(tokens),
         num_threads(num_threads),
+        warm_up(warm_up),
         debug(debug),
         provider(provider),
         model_type(model_type) {}
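Not part of the diff, just an orientation sketch: with the new field in place, C++ application code could request warm-up straight from the config. The field and type names follow the struct above; the paths and numbers are placeholders.

    #include "sherpa-onnx/csrc/online-model-config.h"

    // Hypothetical set-up; only warm_up and model_type matter for this PR.
    sherpa_onnx::OnlineModelConfig MakeModelConfig() {
      sherpa_onnx::OnlineModelConfig config;
      config.tokens = "tokens.txt";      // placeholder path
      config.num_threads = 2;
      config.warm_up = 5;                // new field: number of warm-up runs
      config.model_type = "zipformer2";  // warm-up is only honored for zipformer2
      // transducer/paraformer/... sub-configs would be filled in as usual
      return config;
    }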
@@ -37,6 +37,12 @@ class OnlineRecognizerImpl {
 
   virtual bool IsReady(OnlineStream *s) const = 0;
 
+  virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+    // TODO: extend warm-up support to other models
+    SHERPA_ONNX_LOGE("Only the zipformer2 model supports warm-up for now.");
+    exit(-1);
+  }
+
   virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
 
   virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0;
@@ -32,6 +32,7 @@
 #include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
 #include "sherpa-onnx/csrc/symbol-table.h"
 #include "sherpa-onnx/csrc/utils.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace sherpa_onnx {
 
@@ -183,6 +184,41 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
            s->NumFramesReady();
   }
 
+  // Warm up the engine; warmup is the warm-up count, mbs the max batch size
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+    auto max_batch_size = mbs;
+    if (warmup <= 0 || warmup > 100) {
+      return;
+    }
+    int32_t chunk_size = model_->ChunkSize();
+    int32_t chunk_shift = model_->ChunkShift();
+    int32_t feature_dim = 80;
+    std::vector<OnlineTransducerDecoderResult> results(max_batch_size);
+    std::vector<float> features_vec(max_batch_size * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(max_batch_size);
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    std::array<int64_t, 3> x_shape{max_batch_size, chunk_size, feature_dim};
+
+    for (int32_t i = 0; i != max_batch_size; ++i) {
+      states_vec[i] = model_->GetEncoderInitStates();
+      results[i] = decoder_->GetEmptyResult();
+    }
+
+    for (int32_t i = 0; i != warmup; ++i) {
+      auto states = model_->StackStates(states_vec);
+      Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                              features_vec.size(), x_shape.data(),
+                                              x_shape.size());
+      auto x_copy = Clone(model_->Allocator(), &x);
+      auto pair = model_->RunEncoder(std::move(x), std::move(states),
+                                     std::move(x_copy));
+      decoder_->Decode(std::move(pair.first), &results);
+    }
+  }
+
   void DecodeStreams(OnlineStream **ss, int32_t n) const override {
     int32_t chunk_size = model_->ChunkSize();
     int32_t chunk_shift = model_->ChunkShift();
@@ -171,6 +171,12 @@ bool OnlineRecognizer::IsReady(OnlineStream *s) const {
   return impl_->IsReady(s);
 }
 
+void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
+  if (warmup > 0) {
+    impl_->WarmpUpRecognizer(warmup, mbs);
+  }
+}
+
 void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const {
   impl_->DecodeStreams(ss, n);
 }
@@ -162,6 +162,15 @@ class OnlineRecognizer {
     DecodeStreams(ss, 1);
   }
 
+  /**
+   * Warm up the onnxruntime sessions by applying optimizations and
+   * allocating memory ahead of real decoding.
+   *
+   * @param warmup Number of warm-up runs.
+   * @param mbs Max batch size for the models.
+   */
+  void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
+
   /** Decode multiple streams in parallel
    *
    * @param ss Pointer array containing streams to be decoded.
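Also outside the diff, a rough caller-side sketch of the new public method, assuming a fully populated OnlineRecognizerConfig; the batch size is illustrative, and the model-type guard mirrors what the websocket server adds further below.

    #include "sherpa-onnx/csrc/online-recognizer.h"

    // Hypothetical helper: warm a freshly constructed recognizer up once,
    // before any real streams are decoded.
    void MaybeWarmUp(const sherpa_onnx::OnlineRecognizer &recognizer,
                     const sherpa_onnx::OnlineRecognizerConfig &config) {
      // Only the zipformer2 transducer implementation overrides the warm-up
      // hook; the default impl logs an error and exits for other model types.
      if (config.model_config.model_type == "zipformer2" &&
          config.model_config.warm_up > 0) {
        recognizer.WarmpUpRecognizer(config.model_config.warm_up,
                                     /*mbs=*/8);  // illustrative max batch size
      }
    }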
@@ -95,6 +95,11 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
   c->eof = true;
 }
 
+void OnlineWebsocketDecoder::Warmup() const {
+  recognizer_->WarmpUpRecognizer(config_.recognizer_config.model_config.warm_up,
+                                 config_.max_batch_size);
+}
+
 void OnlineWebsocketDecoder::Run() {
   timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms));
 
@@ -242,6 +247,24 @@ void OnlineWebsocketServer::Run(uint16_t port) {
   server_.set_reuse_addr(true);
   server_.listen(asio::ip::tcp::v4(), port);
   server_.start_accept();
+  auto recognizer_config = config_.decoder_config.recognizer_config;
+  int32_t warm_up = recognizer_config.model_config.warm_up;
+  const std::string &model_type = recognizer_config.model_config.model_type;
+  if (0 < warm_up && warm_up < 100) {
+    if (model_type == "zipformer2") {
+      decoder_.Warmup();
+      SHERPA_ONNX_LOGE("Warm-up completed: %d times.", warm_up);
+    } else {
+      SHERPA_ONNX_LOGE("Only zipformer2 has warm-up support for now.");
+      SHERPA_ONNX_LOGE("Given: %s", model_type.c_str());
+      exit(0);
+    }
+  } else if (warm_up == 0) {
+    SHERPA_ONNX_LOGE("Starting without warm-up!");
+  } else {
+    SHERPA_ONNX_LOGE("Invalid warm-up value! Expected 0 < warm_up < 100");
+    exit(0);
+  }
   decoder_.Run();
 }
 
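In practice the new behavior is reached through the existing flag machinery: --warm-up (registered in OnlineModelConfig::Register above) selects the number of warm-up runs, the server checks that the value lies strictly between 0 and 100 and that model_type is zipformer2, and only then warms the recognizer up with the server's configured max batch size before accepting work. The server binary's name and its remaining flags are outside this diff.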
@@ -85,6 +85,8 @@ class OnlineWebsocketDecoder {
   // signal that there will be no more audio samples for a stream
   void InputFinished(std::shared_ptr<Connection> c);
 
+  void Warmup() const;
+
   void Run();
 
  private:
@@ -27,14 +27,16 @@ void PybindOnlineModelConfig(py::module *m) {
       .def(py::init<const OnlineTransducerModelConfig &,
                     const OnlineParaformerModelConfig &,
                     const OnlineWenetCtcModelConfig &,
-                    const OnlineZipformer2CtcModelConfig &, const std::string &,
-                    int32_t, bool, const std::string &, const std::string &>(),
+                    const OnlineZipformer2CtcModelConfig &,
+                    const std::string &, int32_t, int32_t,
+                    bool, const std::string &, const std::string &>(),
           py::arg("transducer") = OnlineTransducerModelConfig(),
           py::arg("paraformer") = OnlineParaformerModelConfig(),
           py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
           py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
-          py::arg("provider") = "cpu", py::arg("model_type") = "")
+          py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
+          py::arg("debug") = false, py::arg("provider") = "cpu",
+          py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
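A compatibility note on the binding change: the new int32_t parameter is inserted between num_threads and debug, but py::arg("warm_up") defaults to 0, so Python callers that construct OnlineModelConfig with keyword arguments are unaffected; only callers passing positional arguments after num_threads would need updating.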