Fangjun Kuang
Committed by GitHub

Fix RKNN for multi-threaded use (#2274)

@@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl { @@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl {
86 } 86 }
87 87
88 std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run( 88 std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
89 - std::vector<float> features,  
90 - std::vector<std::vector<uint8_t>> states) const { 89 + std::vector<float> features, std::vector<std::vector<uint8_t>> states) {
91 std::vector<rknn_input> inputs(input_attrs_.size()); 90 std::vector<rknn_input> inputs(input_attrs_.size());
92 91
93 for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { 92 for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
@@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl { @@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl {
147 } 146 }
148 } 147 }
149 148
150 - auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); 149 + rknn_context ctx = 0;
  150 + auto ret = rknn_dup_context(&ctx_, &ctx);
  151 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx");
  152 +
  153 + ret = rknn_inputs_set(ctx, inputs.size(), inputs.data());
151 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); 154 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
152 155
153 - ret = rknn_run(ctx_, nullptr); 156 + ret = rknn_run(ctx, nullptr);
154 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); 157 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
155 158
156 - ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); 159 + ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr);
157 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); 160 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
158 161
159 for (int32_t i = 0; i < next_states.size(); ++i) { 162 for (int32_t i = 0; i < next_states.size(); ++i) {
@@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl { @@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl {
174 } 177 }
175 } 178 }
176 179
  180 + rknn_destroy(ctx);
  181 +
177 return {std::move(out), std::move(next_states)}; 182 return {std::move(out), std::move(next_states)};
178 } 183 }
179 184
@@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
120 } 120 }
121 121
122 std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder( 122 std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder(
123 - std::vector<float> features,  
124 - std::vector<std::vector<uint8_t>> states) const { 123 + std::vector<float> features, std::vector<std::vector<uint8_t>> states) {
125 std::vector<rknn_input> inputs(encoder_input_attrs_.size()); 124 std::vector<rknn_input> inputs(encoder_input_attrs_.size());
126 125
127 for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { 126 for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
@@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
181 } 180 }
182 } 181 }
183 182
184 - auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data()); 183 + rknn_context encoder_ctx = 0;
  184 +
  185 + // https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h#L444C1-L444C75
  186 + // rknn_dup_context(rknn_context* context_in, rknn_context* context_out);
  187 + auto ret = rknn_dup_context(&encoder_ctx_, &encoder_ctx);
  188 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the encoder ctx");
  189 +
  190 + ret = rknn_inputs_set(encoder_ctx, inputs.size(), inputs.data());
185 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs"); 191 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs");
186 192
187 - ret = rknn_run(encoder_ctx_, nullptr); 193 + ret = rknn_run(encoder_ctx, nullptr);
188 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder"); 194 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder");
189 195
190 ret = 196 ret =
191 - rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr); 197 + rknn_outputs_get(encoder_ctx, outputs.size(), outputs.data(), nullptr);
192 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output"); 198 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output");
193 199
194 for (int32_t i = 0; i < next_states.size(); ++i) { 200 for (int32_t i = 0; i < next_states.size(); ++i) {
@@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl {
209 } 215 }
210 } 216 }
211 217
  218 + rknn_destroy(encoder_ctx);
  219 +
212 return {std::move(encoder_out), std::move(next_states)}; 220 return {std::move(encoder_out), std::move(next_states)};
213 } 221 }
214 222
215 - std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) const { 223 + std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) {
216 auto &attr = decoder_input_attrs_[0]; 224 auto &attr = decoder_input_attrs_[0];
217 rknn_input input; 225 rknn_input input;
218 226
@@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl {
230 output.size = decoder_out.size() * sizeof(float); 238 output.size = decoder_out.size() * sizeof(float);
231 output.buf = decoder_out.data(); 239 output.buf = decoder_out.data();
232 240
233 - auto ret = rknn_inputs_set(decoder_ctx_, 1, &input); 241 + rknn_context decoder_ctx = 0;
  242 + auto ret = rknn_dup_context(&decoder_ctx_, &decoder_ctx);
  243 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the decoder ctx");
  244 +
  245 + ret = rknn_inputs_set(decoder_ctx, 1, &input);
234 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs"); 246 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs");
235 247
236 - ret = rknn_run(decoder_ctx_, nullptr); 248 + ret = rknn_run(decoder_ctx, nullptr);
237 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder"); 249 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder");
238 250
239 - ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr); 251 + ret = rknn_outputs_get(decoder_ctx, 1, &output, nullptr);
240 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output"); 252 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output");
241 253
  254 + rknn_destroy(decoder_ctx);
  255 +
242 return decoder_out; 256 return decoder_out;
243 } 257 }
244 258
245 std::vector<float> RunJoiner(const float *encoder_out, 259 std::vector<float> RunJoiner(const float *encoder_out,
246 - const float *decoder_out) const { 260 + const float *decoder_out) {
247 std::vector<rknn_input> inputs(2); 261 std::vector<rknn_input> inputs(2);
248 inputs[0].index = 0; 262 inputs[0].index = 0;
249 inputs[0].type = RKNN_TENSOR_FLOAT32; 263 inputs[0].type = RKNN_TENSOR_FLOAT32;
@@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl { @@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
265 output.size = joiner_out.size() * sizeof(float); 279 output.size = joiner_out.size() * sizeof(float);
266 output.buf = joiner_out.data(); 280 output.buf = joiner_out.data();
267 281
268 - auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data()); 282 + rknn_context joiner_ctx = 0;
  283 + auto ret = rknn_dup_context(&joiner_ctx_, &joiner_ctx);
  284 + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the joiner ctx");
  285 +
  286 + ret = rknn_inputs_set(joiner_ctx, inputs.size(), inputs.data());
269 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs"); 287 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs");
270 288
271 - ret = rknn_run(joiner_ctx_, nullptr); 289 + ret = rknn_run(joiner_ctx, nullptr);
272 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner"); 290 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner");
273 291
274 - ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr); 292 + ret = rknn_outputs_get(joiner_ctx, 1, &output, nullptr);
275 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output"); 293 SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output");
276 294
  295 + rknn_destroy(joiner_ctx);
  296 +
277 return joiner_out; 297 return joiner_out;
278 } 298 }
279 299