正在显示 2 个修改的文件，包含 43 行增加和 18 行删除
| @@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl { | @@ -86,8 +86,7 @@ class OnlineZipformerCtcModelRknn::Impl { | ||
| 86 | } | 86 | } |
| 87 | 87 | ||
| 88 | std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run( | 88 | std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> Run( |
| 89 | - std::vector<float> features, | ||
| 90 | - std::vector<std::vector<uint8_t>> states) const { | 89 | + std::vector<float> features, std::vector<std::vector<uint8_t>> states) { |
| 91 | std::vector<rknn_input> inputs(input_attrs_.size()); | 90 | std::vector<rknn_input> inputs(input_attrs_.size()); |
| 92 | 91 | ||
| 93 | for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { | 92 | for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { |
| @@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl { | @@ -147,13 +146,17 @@ class OnlineZipformerCtcModelRknn::Impl { | ||
| 147 | } | 146 | } |
| 148 | } | 147 | } |
| 149 | 148 | ||
| 150 | - auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); | 149 | + rknn_context ctx = 0; |
| 150 | + auto ret = rknn_dup_context(&ctx_, &ctx); | ||
| 151 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx"); | ||
| 152 | + | ||
| 153 | + ret = rknn_inputs_set(ctx, inputs.size(), inputs.data()); | ||
| 151 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); | 154 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); |
| 152 | 155 | ||
| 153 | - ret = rknn_run(ctx_, nullptr); | 156 | + ret = rknn_run(ctx, nullptr); |
| 154 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); | 157 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); |
| 155 | 158 | ||
| 156 | - ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); | 159 | + ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr); |
| 157 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); | 160 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); |
| 158 | 161 | ||
| 159 | for (int32_t i = 0; i < next_states.size(); ++i) { | 162 | for (int32_t i = 0; i < next_states.size(); ++i) { |
| @@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl { | @@ -174,6 +177,8 @@ class OnlineZipformerCtcModelRknn::Impl { | ||
| 174 | } | 177 | } |
| 175 | } | 178 | } |
| 176 | 179 | ||
| 180 | + rknn_destroy(ctx); | ||
| 181 | + | ||
| 177 | return {std::move(out), std::move(next_states)}; | 182 | return {std::move(out), std::move(next_states)}; |
| 178 | } | 183 | } |
| 179 | 184 |
| @@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -120,8 +120,7 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 120 | } | 120 | } |
| 121 | 121 | ||
| 122 | std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder( | 122 | std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder( |
| 123 | - std::vector<float> features, | ||
| 124 | - std::vector<std::vector<uint8_t>> states) const { | 123 | + std::vector<float> features, std::vector<std::vector<uint8_t>> states) { |
| 125 | std::vector<rknn_input> inputs(encoder_input_attrs_.size()); | 124 | std::vector<rknn_input> inputs(encoder_input_attrs_.size()); |
| 126 | 125 | ||
| 127 | for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { | 126 | for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) { |
| @@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -181,14 +180,21 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 181 | } | 180 | } |
| 182 | } | 181 | } |
| 183 | 182 | ||
| 184 | - auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data()); | 183 | + rknn_context encoder_ctx = 0; |
| 184 | + | ||
| 185 | + // https://github.com/rockchip-linux/rknpu2/blob/master/runtime/RK3588/Linux/librknn_api/include/rknn_api.h#L444C1-L444C75 | ||
| 186 | + // rknn_dup_context(rknn_context* context_in, rknn_context* context_out); | ||
| 187 | + auto ret = rknn_dup_context(&encoder_ctx_, &encoder_ctx); | ||
| 188 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the encoder ctx"); | ||
| 189 | + | ||
| 190 | + ret = rknn_inputs_set(encoder_ctx, inputs.size(), inputs.data()); | ||
| 185 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs"); | 191 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs"); |
| 186 | 192 | ||
| 187 | - ret = rknn_run(encoder_ctx_, nullptr); | 193 | + ret = rknn_run(encoder_ctx, nullptr); |
| 188 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder"); | 194 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder"); |
| 189 | 195 | ||
| 190 | ret = | 196 | ret = |
| 191 | - rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr); | 197 | + rknn_outputs_get(encoder_ctx, outputs.size(), outputs.data(), nullptr); |
| 192 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output"); | 198 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output"); |
| 193 | 199 | ||
| 194 | for (int32_t i = 0; i < next_states.size(); ++i) { | 200 | for (int32_t i = 0; i < next_states.size(); ++i) { |
| @@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -209,10 +215,12 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 209 | } | 215 | } |
| 210 | } | 216 | } |
| 211 | 217 | ||
| 218 | + rknn_destroy(encoder_ctx); | ||
| 219 | + | ||
| 212 | return {std::move(encoder_out), std::move(next_states)}; | 220 | return {std::move(encoder_out), std::move(next_states)}; |
| 213 | } | 221 | } |
| 214 | 222 | ||
| 215 | - std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) const { | 223 | + std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) { |
| 216 | auto &attr = decoder_input_attrs_[0]; | 224 | auto &attr = decoder_input_attrs_[0]; |
| 217 | rknn_input input; | 225 | rknn_input input; |
| 218 | 226 | ||
| @@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -230,20 +238,26 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 230 | output.size = decoder_out.size() * sizeof(float); | 238 | output.size = decoder_out.size() * sizeof(float); |
| 231 | output.buf = decoder_out.data(); | 239 | output.buf = decoder_out.data(); |
| 232 | 240 | ||
| 233 | - auto ret = rknn_inputs_set(decoder_ctx_, 1, &input); | 241 | + rknn_context decoder_ctx = 0; |
| 242 | + auto ret = rknn_dup_context(&decoder_ctx_, &decoder_ctx); | ||
| 243 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the decoder ctx"); | ||
| 244 | + | ||
| 245 | + ret = rknn_inputs_set(decoder_ctx, 1, &input); | ||
| 234 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs"); | 246 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs"); |
| 235 | 247 | ||
| 236 | - ret = rknn_run(decoder_ctx_, nullptr); | 248 | + ret = rknn_run(decoder_ctx, nullptr); |
| 237 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder"); | 249 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder"); |
| 238 | 250 | ||
| 239 | - ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr); | 251 | + ret = rknn_outputs_get(decoder_ctx, 1, &output, nullptr); |
| 240 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output"); | 252 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output"); |
| 241 | 253 | ||
| 254 | + rknn_destroy(decoder_ctx); | ||
| 255 | + | ||
| 242 | return decoder_out; | 256 | return decoder_out; |
| 243 | } | 257 | } |
| 244 | 258 | ||
| 245 | std::vector<float> RunJoiner(const float *encoder_out, | 259 | std::vector<float> RunJoiner(const float *encoder_out, |
| 246 | - const float *decoder_out) const { | 260 | + const float *decoder_out) { |
| 247 | std::vector<rknn_input> inputs(2); | 261 | std::vector<rknn_input> inputs(2); |
| 248 | inputs[0].index = 0; | 262 | inputs[0].index = 0; |
| 249 | inputs[0].type = RKNN_TENSOR_FLOAT32; | 263 | inputs[0].type = RKNN_TENSOR_FLOAT32; |
| @@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl { | @@ -265,15 +279,21 @@ class OnlineZipformerTransducerModelRknn::Impl { | ||
| 265 | output.size = joiner_out.size() * sizeof(float); | 279 | output.size = joiner_out.size() * sizeof(float); |
| 266 | output.buf = joiner_out.data(); | 280 | output.buf = joiner_out.data(); |
| 267 | 281 | ||
| 268 | - auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data()); | 282 | + rknn_context joiner_ctx = 0; |
| 283 | + auto ret = rknn_dup_context(&joiner_ctx_, &joiner_ctx); | ||
| 284 | + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the joiner ctx"); | ||
| 285 | + | ||
| 286 | + ret = rknn_inputs_set(joiner_ctx, inputs.size(), inputs.data()); | ||
| 269 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs"); | 287 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs"); |
| 270 | 288 | ||
| 271 | - ret = rknn_run(joiner_ctx_, nullptr); | 289 | + ret = rknn_run(joiner_ctx, nullptr); |
| 272 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner"); | 290 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner"); |
| 273 | 291 | ||
| 274 | - ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr); | 292 | + ret = rknn_outputs_get(joiner_ctx, 1, &output, nullptr); |
| 275 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output"); | 293 | SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output"); |
| 276 | 294 | ||
| 295 | + rknn_destroy(joiner_ctx); | ||
| 296 | + | ||
| 277 | return joiner_out; | 297 | return joiner_out; |
| 278 | } | 298 | } |
| 279 | 299 |
-
请注册或登录后发表评论