Fangjun Kuang
Committed by GitHub

Support TDT transducer decoding (#2495)

... ... @@ -207,6 +207,7 @@ def main():
for line in f:
t, idx = line.split()
id2token[int(idx)] = t
vocab_size = len(id2token)
start = time.time()
fbank = create_fbank()
... ... @@ -242,12 +243,21 @@ def main():
encoder_out = model.run_encoder(features)
# encoder_out:[batch_size, dim, T)
for t in range(encoder_out.shape[2]):
t = 0
while t < encoder_out.shape[2]:
encoder_out_t = encoder_out[:, :, t : t + 1]
logits = model.run_joiner(encoder_out_t, decoder_out)
logits = torch.from_numpy(logits)
logits = logits.squeeze()
idx = torch.argmax(logits, dim=-1).item()
token_logits = logits[:vocab_size]
duration_logits = logits[vocab_size:]
idx = torch.argmax(token_logits, dim=-1).item()
skip = torch.argmax(duration_logits, dim=-1).item()
if skip == 0:
skip = 1
if idx != blank:
ans.append(idx)
state0 = state0_next
... ... @@ -255,6 +265,7 @@ def main():
decoder_out, state0_next, state1_next = model.run_decoder(
ans[-1], state0, state1
)
t += skip
end = time.time()
... ...
... ... @@ -43,7 +43,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
model_.get(), config_.blank_penalty, model_->IsTDT());
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
... ...
... ... @@ -94,6 +94,72 @@ static OfflineTransducerDecoderResult DecodeOne(
return ans;
}
// Greedy search for a Token-and-Duration Transducer (TDT) model on a single
// utterance.
//
// @param p             Encoder output of this utterance: num_rows consecutive
//                      frames, each holding num_cols floats.
// @param num_rows      Number of valid encoder frames to decode.
// @param num_cols      Encoder output dimension of a single frame.
// @param model         Model used to run the decoder and the joiner.
//                      Not owned.
// @param blank_penalty If > 0, it is subtracted from the blank logit before
//                      the token argmax.
// @return The decoded token IDs and, for each token, the index of the encoder
//         frame on which it was emitted.
//
// The joiner output is read as `vocab_size` token logits followed by the
// duration logits. The argmax index over the duration logits is used as the
// number of frames to advance.
static OfflineTransducerDecoderResult DecodeOneTDT(
    const float *p, int32_t num_rows, int32_t num_cols,
    OfflineTransducerNeMoModel *model, float blank_penalty) {
  // Upper bound on the number of tokens emitted without advancing in time.
  // It guarantees termination when the model keeps predicting duration 0.
  constexpr int32_t kMaxTokensPerFrame = 5;

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  OfflineTransducerDecoderResult ans;

  const int32_t vocab_size = model->VocabSize();
  // The last token ID is used as the blank symbol.
  const int32_t blank_id = vocab_size - 1;

  // Prime the decoder with blank and the initial states.
  auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());

  std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
      model->RunDecoder(std::move(decoder_input_pair.first),
                        std::move(decoder_input_pair.second),
                        model->GetDecoderInitStates(1));

  std::array<int64_t, 3> encoder_shape{1, num_cols, 1};

  int32_t tokens_this_frame = 0;
  int32_t skip = 0;
  for (int32_t t = 0; t < num_rows; t += skip) {
    // Wrap frame t in a tensor without copying; the joiner only reads it.
    Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
        memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
        encoder_shape.data(), encoder_shape.size());

    Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
                                        View(&decoder_output_pair.first));

    auto shape = logit.GetTensorTypeAndShapeInfo().GetShape();

    float *p_logit = logit.GetTensorMutableData<float>();
    if (blank_penalty > 0) {
      p_logit[blank_id] -= blank_penalty;
    }

    const float *token_begin = p_logit;
    const float *token_end = p_logit + vocab_size;  // also: duration begin
    const float *duration_end = p_logit + shape.back();

    auto y = static_cast<int32_t>(std::distance(
        token_begin, std::max_element(token_begin, token_end)));

    skip = static_cast<int32_t>(std::distance(
        token_end, std::max_element(token_end, duration_end)));

    if (y != blank_id) {
      ans.tokens.push_back(y);
      ans.timestamps.push_back(t);

      decoder_input_pair = BuildDecoderInput(y, model->Allocator());

      decoder_output_pair =
          model->RunDecoder(std::move(decoder_input_pair.first),
                            std::move(decoder_input_pair.second),
                            std::move(decoder_output_pair.second));
      ++tokens_this_frame;
    }

    if (skip > 0) {
      // We are moving to a later frame; start counting from scratch.
      tokens_this_frame = 0;
    } else if (y == blank_id || tokens_this_frame >= kMaxTokensPerFrame) {
      // A duration of 0 means "stay on this frame". That is only useful after
      // a non-blank token, since the decoder state has changed and the joiner
      // may emit another token for the same frame. For blank nothing has
      // changed, and once the per-frame cap is hit we must move on, so force
      // an advance of one frame in both cases.
      skip = 1;
      tokens_this_frame = 0;
    }
    // Otherwise: non-blank token with duration 0, so re-run the joiner on the
    // same frame with the updated decoder output.
  }  // for (int32_t t = 0; t < num_rows; t += skip)

  return ans;
}
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerGreedySearchNeMoDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
... ... @@ -123,7 +189,11 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
? encoder_out_length.GetTensorData<int32_t>()[i]
: encoder_out_length.GetTensorData<int64_t>()[i];
ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
if (is_tdt_) {
ans[i] = DecodeOneTDT(this_p, this_len, dim2, model_, blank_penalty_);
} else {
ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
}
}
return ans;
... ...
... ... @@ -16,8 +16,8 @@ class OfflineTransducerGreedySearchNeMoDecoder
: public OfflineTransducerDecoder {
public:
OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
float blank_penalty)
: model_(model), blank_penalty_(blank_penalty) {}
float blank_penalty, bool is_tdt)
: model_(model), blank_penalty_(blank_penalty), is_tdt_(is_tdt) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
... ... @@ -26,6 +26,7 @@ class OfflineTransducerGreedySearchNeMoDecoder
private:
OfflineTransducerNeMoModel *model_; // Not owned
float blank_penalty_;
bool is_tdt_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -163,6 +163,7 @@ class OfflineTransducerNeMoModel::Impl {
// Returns a copy of the stored feature normalization method string.
std::string FeatureNormalizationMethod() const {
  return normalize_type_;
}
// True if the is_giga_am_ flag is non-zero.
bool IsGigaAM() const {
  return is_giga_am_ != 0;
}
// True if the is_tdt_ flag is non-zero, i.e., the model was detected as a
// Token-and-Duration Transducer.
bool IsTDT() const {
  return is_tdt_ != 0;
}
// Returns the stored feature dimension (feat_dim_).
int32_t FeatureDim() const {
  return feat_dim_;
}
... ... @@ -208,6 +209,12 @@ class OfflineTransducerNeMoModel::Impl {
if (normalize_type_ == "NA") {
normalize_type_ = "";
}
std::string url;
SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(url, "url");
if (url.find("tdt") != std::string::npos) {
is_tdt_ = 1;
}
}
void InitDecoder(void *model_data, size_t model_data_length) {
... ... @@ -230,6 +237,26 @@ class OfflineTransducerNeMoModel::Impl {
GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
&joiner_output_names_ptr_);
auto shape = joiner_sess_->GetOutputTypeInfo(0)
.GetTensorTypeAndShapeInfo()
.GetShape();
int32_t output_size = shape.back();
if (is_tdt_) {
if (vocab_size_ == output_size) {
SHERPA_ONNX_LOGE("It is not a TDT model!");
SHERPA_ONNX_EXIT(-1);
}
if (config_.debug) {
SHERPA_ONNX_LOGE("TDT model. vocab_size: %d, num_durations: %d",
vocab_size_, output_size - vocab_size_);
}
} else if (vocab_size_ != output_size) {
SHERPA_ONNX_LOGE("vocab_size: %d != output_size: %d", vocab_size_,
output_size);
SHERPA_ONNX_EXIT(-1);
}
}
private:
... ... @@ -266,6 +293,7 @@ class OfflineTransducerNeMoModel::Impl {
int32_t pred_rnn_layers_ = -1;
int32_t pred_hidden_ = -1;
int32_t is_giga_am_ = 0;
int32_t is_tdt_ = 0;
// giga am uses 64
// parakeet-tdt-0.6b-v2 uses 128
... ... @@ -325,6 +353,8 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
// Forwards to the implementation object.
bool OfflineTransducerNeMoModel::IsGigaAM() const {
  return impl_->IsGigaAM();
}
// Forwards to the implementation object; true for a Token-and-Duration
// Transducer model.
bool OfflineTransducerNeMoModel::IsTDT() const {
  return impl_->IsTDT();
}
// Forwards to the implementation object and returns its feature dimension.
int32_t OfflineTransducerNeMoModel::FeatureDim() const {
  const int32_t feature_dim = impl_->FeatureDim();
  return feature_dim;
}
... ...
... ... @@ -88,6 +88,10 @@ class OfflineTransducerNeMoModel {
bool IsGigaAM() const;
// true if it is a Token-and-Duration Transducer model
// false otherwise
bool IsTDT() const;
int32_t FeatureDim() const;
private:
... ...