Fangjun Kuang
Committed by GitHub

Fix RKNN for multi-threaded use (#2274)

... ... @@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl {
}
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run(
std::vector<float> features,
std::vector<std::vector<uint8_t>> states) const {
std::vector<float> features, std::vector<std::vector<uint8_t>> states) {
std::vector<rknn_input> inputs(input_attrs_.size());
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
... ... @@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl {
}
}
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
rknn_context ctx = 0;
auto ret = rknn_dup_context(&ctx_, &ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx");
ret = rknn_inputs_set(ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
ret = rknn_run(ctx_, nullptr);
ret = rknn_run(ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
for (int32_t i = 0; i < next_states.size(); ++i) {
... ... @@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl {
}
}
rknn_destroy(ctx);
return {std::move(out), std::move(next_states)};
}
... ...
... ... @@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder(
std::vector<float> features,
std::vector<std::vector<uint8_t>> states) const {
std::vector<float> features, std::vector<std::vector<uint8_t>> states) {
std::vector<rknn_input> inputs(encoder_input_attrs_.size());
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
... ... @@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
}
auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data());
rknn_context encoder_ctx = 0;
// https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h#L444C1-L444C75
// rknn_dup_context(rknn_context* context_in, rknn_context* context_out);
auto ret = rknn_dup_context(&encoder_ctx_, &encoder_ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the encoder ctx");
ret = rknn_inputs_set(encoder_ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs");
ret = rknn_run(encoder_ctx_, nullptr);
ret = rknn_run(encoder_ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder");
ret =
rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr);
rknn_outputs_get(encoder_ctx, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output");
for (int32_t i = 0; i < next_states.size(); ++i) {
... ... @@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl {
}
}
rknn_destroy(encoder_ctx);
return {std::move(encoder_out), std::move(next_states)};
}
std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) const {
std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) {
auto &attr = decoder_input_attrs_[0];
rknn_input input;
... ... @@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl {
output.size = decoder_out.size() * sizeof(float);
output.buf = decoder_out.data();
auto ret = rknn_inputs_set(decoder_ctx_, 1, &input);
rknn_context decoder_ctx = 0;
auto ret = rknn_dup_context(&decoder_ctx_, &decoder_ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the decoder ctx");
ret = rknn_inputs_set(decoder_ctx, 1, &input);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs");
ret = rknn_run(decoder_ctx_, nullptr);
ret = rknn_run(decoder_ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder");
ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr);
ret = rknn_outputs_get(decoder_ctx, 1, &output, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output");
rknn_destroy(decoder_ctx);
return decoder_out;
}
std::vector<float> RunJoiner(const float *encoder_out,
const float *decoder_out) const {
const float *decoder_out) {
std::vector<rknn_input> inputs(2);
inputs[0].index = 0;
inputs[0].type = RKNN_TENSOR_FLOAT32;
... ... @@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl {
output.size = joiner_out.size() * sizeof(float);
output.buf = joiner_out.data();
auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data());
rknn_context joiner_ctx = 0;
auto ret = rknn_dup_context(&joiner_ctx_, &joiner_ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the joiner ctx");
ret = rknn_inputs_set(joiner_ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs");
ret = rknn_run(joiner_ctx_, nullptr);
ret = rknn_run(joiner_ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner");
ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr);
ret = rknn_outputs_get(joiner_ctx, 1, &output, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output");
rknn_destroy(joiner_ctx);
return joiner_out;
}
... ...