Committed by GitHub
Limit number of tokens in fire red asr decoding. (#2459)
正在显示 4 个修改的文件，包含 14 行增加和 6 行删除
| @@ -31,7 +31,8 @@ class OfflineFireRedAsrDecoder { | @@ -31,7 +31,8 @@ class OfflineFireRedAsrDecoder { | ||
| 31 | * @return Return a vector of size `N` containing the decoded results. | 31 | * @return Return a vector of size `N` containing the decoded results. |
| 32 | */ | 32 | */ |
| 33 | virtual std::vector<OfflineFireRedAsrDecoderResult> Decode( | 33 | virtual std::vector<OfflineFireRedAsrDecoderResult> Decode( |
| 34 | - Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; | 34 | + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v, |
| 35 | + int32_t num_feature_frames) = 0; | ||
| 35 | }; | 36 | }; |
| 36 | 37 | ||
| 37 | } // namespace sherpa_onnx | 38 | } // namespace sherpa_onnx |
| @@ -16,7 +16,8 @@ namespace sherpa_onnx { | @@ -16,7 +16,8 @@ namespace sherpa_onnx { | ||
| 16 | // Note: this functions works only for batch size == 1 at present | 16 | // Note: this functions works only for batch size == 1 at present |
| 17 | std::vector<OfflineFireRedAsrDecoderResult> | 17 | std::vector<OfflineFireRedAsrDecoderResult> |
| 18 | OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, | 18 | OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, |
| 19 | - Ort::Value cross_v) { | 19 | + Ort::Value cross_v, |
| 20 | + int32_t num_feature_frames) { | ||
| 20 | const auto &meta_data = model_->GetModelMetadata(); | 21 | const auto &meta_data = model_->GetModelMetadata(); |
| 21 | 22 | ||
| 22 | auto memory_info = | 23 | auto memory_info = |
| @@ -53,7 +54,12 @@ OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, | @@ -53,7 +54,12 @@ OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 53 | std::move(cross_v), | 54 | std::move(cross_v), |
| 54 | std::move(offset)}; | 55 | std::move(offset)}; |
| 55 | 56 | ||
| 56 | - for (int32_t i = 0; i < meta_data.max_len; ++i) { | 57 | + // assume at most 6 tokens per second |
| 58 | + int32_t num_possible_tokens = num_feature_frames / 100 * 6; | ||
| 59 | + num_possible_tokens = | ||
| 60 | + std::min<int32_t>(num_possible_tokens, meta_data.max_len / 2); | ||
| 61 | + | ||
| 62 | + for (int32_t i = 0; i < num_possible_tokens; ++i) { | ||
| 57 | decoder_out = model_->ForwardDecoder(View(&tokens), | 63 | decoder_out = model_->ForwardDecoder(View(&tokens), |
| 58 | std::move(std::get<1>(decoder_out)), | 64 | std::move(std::get<1>(decoder_out)), |
| 59 | std::move(std::get<2>(decoder_out)), | 65 | std::move(std::get<2>(decoder_out)), |
| @@ -18,7 +18,8 @@ class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder { | @@ -18,7 +18,8 @@ class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder { | ||
| 18 | : model_(model) {} | 18 | : model_(model) {} |
| 19 | 19 | ||
| 20 | std::vector<OfflineFireRedAsrDecoderResult> Decode( | 20 | std::vector<OfflineFireRedAsrDecoderResult> Decode( |
| 21 | - Ort::Value cross_k, Ort::Value cross_v) override; | 21 | + Ort::Value cross_k, Ort::Value cross_v, |
| 22 | + int32_t num_feature_frames) override; | ||
| 22 | 23 | ||
| 23 | private: | 24 | private: |
| 24 | OfflineFireRedAsrModel *model_; // not owned | 25 | OfflineFireRedAsrModel *model_; // not owned |
| @@ -119,8 +119,8 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { | @@ -119,8 +119,8 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { | ||
| 119 | 119 | ||
| 120 | auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len)); | 120 | auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len)); |
| 121 | 121 | ||
| 122 | - auto results = | ||
| 123 | - decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); | 122 | + auto results = decoder_->Decode(std::move(cross_kv.first), |
| 123 | + std::move(cross_kv.second), num_frames); | ||
| 124 | 124 | ||
| 125 | auto r = Convert(results[0], symbol_table_); | 125 | auto r = Convert(results[0], symbol_table_); |
| 126 | 126 |
-
请 注册 或 登录 后发表评论