Fangjun Kuang
Committed by GitHub

Limit number of tokens in fire red asr decoding. (#2459)

... ... @@ -31,7 +31,8 @@ class OfflineFireRedAsrDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineFireRedAsrDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
int32_t num_feature_frames) = 0;
};
} // namespace sherpa_onnx
... ...
... ... @@ -16,7 +16,8 @@ namespace sherpa_onnx {
// Note: this function works only for batch size == 1 at present
std::vector<OfflineFireRedAsrDecoderResult>
OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
Ort::Value cross_v,
int32_t num_feature_frames) {
const auto &meta_data = model_->GetModelMetadata();
auto memory_info =
... ... @@ -53,7 +54,12 @@ OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k,
std::move(cross_v),
std::move(offset)};
for (int32_t i = 0; i < meta_data.max_len; ++i) {
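// Assume at most 6 tokens per second of audio; the formula below treats
// 100 feature frames as one second.
// NOTE(review): `/ 100` is an integer division, so an input with fewer than
// 100 feature frames yields 0 possible tokens -- confirm this is intended.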
int32_t num_possible_tokens = num_feature_frames / 100 * 6;
num_possible_tokens =
std::min<int32_t>(num_possible_tokens, meta_data.max_len / 2);
for (int32_t i = 0; i < num_possible_tokens; ++i) {
decoder_out = model_->ForwardDecoder(View(&tokens),
std::move(std::get<1>(decoder_out)),
std::move(std::get<2>(decoder_out)),
... ...
... ... @@ -18,7 +18,8 @@ class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder {
: model_(model) {}
std::vector<OfflineFireRedAsrDecoderResult> Decode(
Ort::Value cross_k, Ort::Value cross_v) override;
Ort::Value cross_k, Ort::Value cross_v,
int32_t num_feature_frames) override;
private:
OfflineFireRedAsrModel *model_; // not owned
... ...
... ... @@ -119,8 +119,8 @@ class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl {
auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len));
auto results =
decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second));
auto results = decoder_->Decode(std::move(cross_kv.first),
std::move(cross_kv.second), num_frames);
auto r = Convert(results[0], symbol_table_);
... ...