Fangjun Kuang
Committed by GitHub

Support TDT transducer decoding (#2495)

@@ -207,6 +207,7 @@ def main(): @@ -207,6 +207,7 @@ def main():
207 for line in f: 207 for line in f:
208 t, idx = line.split() 208 t, idx = line.split()
209 id2token[int(idx)] = t 209 id2token[int(idx)] = t
  210 + vocab_size = len(id2token)
210 211
211 start = time.time() 212 start = time.time()
212 fbank = create_fbank() 213 fbank = create_fbank()
@@ -242,12 +243,21 @@ def main(): @@ -242,12 +243,21 @@ def main():
242 243
243 encoder_out = model.run_encoder(features) 244 encoder_out = model.run_encoder(features)
244 # encoder_out:[batch_size, dim, T) 245 # encoder_out:[batch_size, dim, T)
245 - for t in range(encoder_out.shape[2]): 246 + t = 0
  247 + while t < encoder_out.shape[2]:
246 encoder_out_t = encoder_out[:, :, t : t + 1] 248 encoder_out_t = encoder_out[:, :, t : t + 1]
247 logits = model.run_joiner(encoder_out_t, decoder_out) 249 logits = model.run_joiner(encoder_out_t, decoder_out)
248 logits = torch.from_numpy(logits) 250 logits = torch.from_numpy(logits)
249 logits = logits.squeeze() 251 logits = logits.squeeze()
250 - idx = torch.argmax(logits, dim=-1).item() 252 +
  253 + token_logits = logits[:vocab_size]
  254 + duration_logits = logits[vocab_size:]
  255 +
  256 + idx = torch.argmax(token_logits, dim=-1).item()
  257 + skip = torch.argmax(duration_logits, dim=-1).item()
  258 + if skip == 0:
  259 + skip = 1
  260 +
251 if idx != blank: 261 if idx != blank:
252 ans.append(idx) 262 ans.append(idx)
253 state0 = state0_next 263 state0 = state0_next
@@ -255,6 +265,7 @@ def main(): @@ -255,6 +265,7 @@ def main():
255 decoder_out, state0_next, state1_next = model.run_decoder( 265 decoder_out, state0_next, state1_next = model.run_decoder(
256 ans[-1], state0, state1 266 ans[-1], state0, state1
257 ) 267 )
  268 + t += skip
258 269
259 end = time.time() 270 end = time.time()
260 271
@@ -43,7 +43,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { @@ -43,7 +43,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
43 config_.model_config)) { 43 config_.model_config)) {
44 if (config_.decoding_method == "greedy_search") { 44 if (config_.decoding_method == "greedy_search") {
45 decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>( 45 decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>(
46 - model_.get(), config_.blank_penalty); 46 + model_.get(), config_.blank_penalty, model_->IsTDT());
47 } else { 47 } else {
48 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 48 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
49 config_.decoding_method.c_str()); 49 config_.decoding_method.c_str());
@@ -94,6 +94,72 @@ static OfflineTransducerDecoderResult DecodeOne( @@ -94,6 +94,72 @@ static OfflineTransducerDecoderResult DecodeOne(
94 return ans; 94 return ans;
95 } 95 }
96 96
  97 +static OfflineTransducerDecoderResult DecodeOneTDT(
  98 + const float *p, int32_t num_rows, int32_t num_cols,
  99 + OfflineTransducerNeMoModel *model, float blank_penalty) {
  100 + auto memory_info =
  101 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  102 +
  103 + OfflineTransducerDecoderResult ans;
  104 +
  105 + int32_t vocab_size = model->VocabSize();
  106 + int32_t blank_id = vocab_size - 1;
  107 +
  108 + auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator());
  109 +
  110 + std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
  111 + model->RunDecoder(std::move(decoder_input_pair.first),
  112 + std::move(decoder_input_pair.second),
  113 + model->GetDecoderInitStates(1));
  114 +
  115 + std::array<int64_t, 3> encoder_shape{1, num_cols, 1};
  116 +
  117 + int32_t skip = 0;
  118 + for (int32_t t = 0; t < num_rows; t += skip) {
  119 + Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
  120 + memory_info, const_cast<float *>(p) + t * num_cols, num_cols,
  121 + encoder_shape.data(), encoder_shape.size());
  122 +
  123 + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out),
  124 + View(&decoder_output_pair.first));
  125 +
  126 + auto shape = logit.GetTensorTypeAndShapeInfo().GetShape();
  127 +
  128 + float *p_logit = logit.GetTensorMutableData<float>();
  129 + if (blank_penalty > 0) {
  130 + p_logit[blank_id] -= blank_penalty;
  131 + }
  132 +
  133 + auto y = static_cast<int32_t>(std::distance(
  134 + static_cast<const float *>(p_logit),
  135 + std::max_element(static_cast<const float *>(p_logit),
  136 + static_cast<const float *>(p_logit) + vocab_size)));
  137 +
  138 + skip = static_cast<int32_t>(std::distance(
  139 + static_cast<const float *>(p_logit) + vocab_size,
  140 + std::max_element(static_cast<const float *>(p_logit) + vocab_size,
  141 + static_cast<const float *>(p_logit) + shape.back())));
  142 +
  143 + if (skip == 0) {
  144 + skip = 1;
  145 + }
  146 +
  147 + if (y != blank_id) {
  148 + ans.tokens.push_back(y);
  149 + ans.timestamps.push_back(t);
  150 +
  151 + decoder_input_pair = BuildDecoderInput(y, model->Allocator());
  152 +
  153 + decoder_output_pair =
  154 + model->RunDecoder(std::move(decoder_input_pair.first),
  155 + std::move(decoder_input_pair.second),
  156 + std::move(decoder_output_pair.second));
  157 + }
  158 + } // for (int32_t t = 0; t < num_rows; ++t) {
  159 +
  160 + return ans;
  161 +}
  162 +
97 std::vector<OfflineTransducerDecoderResult> 163 std::vector<OfflineTransducerDecoderResult>
98 OfflineTransducerGreedySearchNeMoDecoder::Decode( 164 OfflineTransducerGreedySearchNeMoDecoder::Decode(
99 Ort::Value encoder_out, Ort::Value encoder_out_length, 165 Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -123,7 +189,11 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( @@ -123,7 +189,11 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode(
123 ? encoder_out_length.GetTensorData<int32_t>()[i] 189 ? encoder_out_length.GetTensorData<int32_t>()[i]
124 : encoder_out_length.GetTensorData<int64_t>()[i]; 190 : encoder_out_length.GetTensorData<int64_t>()[i];
125 191
126 - ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); 192 + if (is_tdt_) {
  193 + ans[i] = DecodeOneTDT(this_p, this_len, dim2, model_, blank_penalty_);
  194 + } else {
  195 + ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_);
  196 + }
127 } 197 }
128 198
129 return ans; 199 return ans;
@@ -16,8 +16,8 @@ class OfflineTransducerGreedySearchNeMoDecoder @@ -16,8 +16,8 @@ class OfflineTransducerGreedySearchNeMoDecoder
16 : public OfflineTransducerDecoder { 16 : public OfflineTransducerDecoder {
17 public: 17 public:
18 OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model, 18 OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model,
19 - float blank_penalty)  
20 - : model_(model), blank_penalty_(blank_penalty) {} 19 + float blank_penalty, bool is_tdt)
  20 + : model_(model), blank_penalty_(blank_penalty), is_tdt_(is_tdt) {}
21 21
22 std::vector<OfflineTransducerDecoderResult> Decode( 22 std::vector<OfflineTransducerDecoderResult> Decode(
23 Ort::Value encoder_out, Ort::Value encoder_out_length, 23 Ort::Value encoder_out, Ort::Value encoder_out_length,
@@ -26,6 +26,7 @@ class OfflineTransducerGreedySearchNeMoDecoder @@ -26,6 +26,7 @@ class OfflineTransducerGreedySearchNeMoDecoder
26 private: 26 private:
27 OfflineTransducerNeMoModel *model_; // Not owned 27 OfflineTransducerNeMoModel *model_; // Not owned
28 float blank_penalty_; 28 float blank_penalty_;
  29 + bool is_tdt_;
29 }; 30 };
30 31
31 } // namespace sherpa_onnx 32 } // namespace sherpa_onnx
@@ -163,6 +163,7 @@ class OfflineTransducerNeMoModel::Impl { @@ -163,6 +163,7 @@ class OfflineTransducerNeMoModel::Impl {
163 std::string FeatureNormalizationMethod() const { return normalize_type_; } 163 std::string FeatureNormalizationMethod() const { return normalize_type_; }
164 164
165 bool IsGigaAM() const { return is_giga_am_; } 165 bool IsGigaAM() const { return is_giga_am_; }
  166 + bool IsTDT() const { return is_tdt_; }
166 167
167 int32_t FeatureDim() const { return feat_dim_; } 168 int32_t FeatureDim() const { return feat_dim_; }
168 169
@@ -208,6 +209,12 @@ class OfflineTransducerNeMoModel::Impl { @@ -208,6 +209,12 @@ class OfflineTransducerNeMoModel::Impl {
208 if (normalize_type_ == "NA") { 209 if (normalize_type_ == "NA") {
209 normalize_type_ = ""; 210 normalize_type_ = "";
210 } 211 }
  212 +
  213 + std::string url;
  214 + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(url, "url");
  215 + if (url.find("tdt") != std::string::npos) {
  216 + is_tdt_ = 1;
  217 + }
211 } 218 }
212 219
213 void InitDecoder(void *model_data, size_t model_data_length) { 220 void InitDecoder(void *model_data, size_t model_data_length) {
@@ -230,6 +237,26 @@ class OfflineTransducerNeMoModel::Impl { @@ -230,6 +237,26 @@ class OfflineTransducerNeMoModel::Impl {
230 237
231 GetOutputNames(joiner_sess_.get(), &joiner_output_names_, 238 GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
232 &joiner_output_names_ptr_); 239 &joiner_output_names_ptr_);
  240 +
  241 + auto shape = joiner_sess_->GetOutputTypeInfo(0)
  242 + .GetTensorTypeAndShapeInfo()
  243 + .GetShape();
  244 + int32_t output_size = shape.back();
  245 + if (is_tdt_) {
  246 + if (vocab_size_ == output_size) {
  247 + SHERPA_ONNX_LOGE("It is not a TDT model!");
  248 + SHERPA_ONNX_EXIT(-1);
  249 + }
  250 +
  251 + if (config_.debug) {
  252 + SHERPA_ONNX_LOGE("TDT model. vocab_size: %d, num_durations: %d",
  253 + vocab_size_, output_size - vocab_size_);
  254 + }
  255 + } else if (vocab_size_ != output_size) {
  256 + SHERPA_ONNX_LOGE("vocab_size: %d != output_size: %d", vocab_size_,
  257 + output_size);
  258 + SHERPA_ONNX_EXIT(-1);
  259 + }
233 } 260 }
234 261
235 private: 262 private:
@@ -266,6 +293,7 @@ class OfflineTransducerNeMoModel::Impl { @@ -266,6 +293,7 @@ class OfflineTransducerNeMoModel::Impl {
266 int32_t pred_rnn_layers_ = -1; 293 int32_t pred_rnn_layers_ = -1;
267 int32_t pred_hidden_ = -1; 294 int32_t pred_hidden_ = -1;
268 int32_t is_giga_am_ = 0; 295 int32_t is_giga_am_ = 0;
  296 + int32_t is_tdt_ = 0;
269 297
270 // giga am uses 64 298 // giga am uses 64
271 // parakeet-tdt-0.6b-v2 uses 128 299 // parakeet-tdt-0.6b-v2 uses 128
@@ -325,6 +353,8 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { @@ -325,6 +353,8 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
325 353
326 bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } 354 bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); }
327 355
  356 +bool OfflineTransducerNeMoModel::IsTDT() const { return impl_->IsTDT(); }
  357 +
328 int32_t OfflineTransducerNeMoModel::FeatureDim() const { 358 int32_t OfflineTransducerNeMoModel::FeatureDim() const {
329 return impl_->FeatureDim(); 359 return impl_->FeatureDim();
330 } 360 }
@@ -88,6 +88,10 @@ class OfflineTransducerNeMoModel { @@ -88,6 +88,10 @@ class OfflineTransducerNeMoModel {
88 88
89 bool IsGigaAM() const; 89 bool IsGigaAM() const;
90 90
  91 + // true if it is a Token-and-Duration Transducer model
  92 + // false otherwise
  93 + bool IsTDT() const;
  94 +
91 int32_t FeatureDim() const; 95 int32_t FeatureDim() const;
92 96
93 private: 97 private: