Fangjun Kuang
Committed by GitHub

Limit number of tokens in fire red asr decoding. (#2459)

@@ -31,7 +31,8 @@ class OfflineFireRedAsrDecoder { @@ -31,7 +31,8 @@ class OfflineFireRedAsrDecoder {
31 * @return Return a vector of size `N` containing the decoded results. 31 * @return Return a vector of size `N` containing the decoded results.
32 */ 32 */
33 virtual std::vector<OfflineFireRedAsrDecoderResult> Decode( 33 virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
34 - Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; 34 + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
  35 + int32_t num_feature_frames) = 0;
35 }; 36 };
36 37
37 } // namespace sherpa_onnx 38 } // namespace sherpa_onnx
@@ -16,7 +16,8 @@ namespace sherpa_onnx { @@ -16,7 +16,8 @@ namespace sherpa_onnx {
16 // Note: this function works only for batch size == 1 at present 16 // Note: this function works only for batch size == 1 at present
17 std::vector<OfflineFireRedAsrDecoderResult> 17 std::vector<OfflineFireRedAsrDecoderResult>
18 OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, 18 OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
19 - Ort::Value cross_v) { 19 + Ort::Value cross_v,
  20 + int32_t num_feature_frames) {
20 const auto &meta_data = model_->GetModelMetadata(); 21 const auto &meta_data = model_->GetModelMetadata();
21 22
22 auto memory_info = 23 auto memory_info =
@@ -53,7 +54,12 @@ OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, @@ -53,7 +54,12 @@ OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
53 std::move(cross_v), 54 std::move(cross_v),
54 std::move(offset)}; 55 std::move(offset)};
55 56
56 - for (int32_t i = 0; i < meta_data.max_len; ++i) { 57 + // assume at most 6 tokens per second
  58 + int32_t num_possible_tokens = num_feature_frames / 100 * 6;
  59 + num_possible_tokens =
  60 + std::min<int32_t>(num_possible_tokens, meta_data.max_len / 2);
  61 +
  62 + for (int32_t i = 0; i < num_possible_tokens; ++i) {
57 decoder_out = model_->ForwardDecoder(View(&tokens), 63 decoder_out = model_->ForwardDecoder(View(&tokens),
58 std::move(std::get<1>(decoder_out)), 64 std::move(std::get<1>(decoder_out)),
59 std::move(std::get<2>(decoder_out)), 65 std::move(std::get<2>(decoder_out)),
@@ -18,7 +18,8 @@ class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder { @@ -18,7 +18,8 @@ class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
18 : model_(model) {} 18 : model_(model) {}
19 19
20 std::vector<OfflineFireRedAsrDecoderResult> Decode( 20 std::vector<OfflineFireRedAsrDecoderResult> Decode(
21 - Ort::Value cross_k, Ort::Value cross_v) override; 21 + Ort::Value cross_k, Ort::Value cross_v,
  22 + int32_t num_feature_frames) override;
22 23
23 private: 24 private:
24 OfflineFireRedAsrModel *model_; // not owned 25 OfflineFireRedAsrModel *model_; // not owned
@@ -119,8 +119,8 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { @@ -119,8 +119,8 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
119 119
120 auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len)); 120 auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len));
121 121
122 - auto results =  
123 - decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); 122 + auto results = decoder_->Decode(std::move(cross_kv.first),
  123 + std::move(cross_kv.second), num_frames);
124 124
125 auto r = Convert(results[0], symbol_table_); 125 auto r = Convert(results[0], symbol_table_);
126 126