正在显示
6 个修改的文件
包含
122 行增加
和
6 行删除
| @@ -207,6 +207,7 @@ def main(): | @@ -207,6 +207,7 @@ def main(): | ||
| 207 | for line in f: | 207 | for line in f: |
| 208 | t, idx = line.split() | 208 | t, idx = line.split() |
| 209 | id2token[int(idx)] = t | 209 | id2token[int(idx)] = t |
| 210 | + vocab_size = len(id2token) | ||
| 210 | 211 | ||
| 211 | start = time.time() | 212 | start = time.time() |
| 212 | fbank = create_fbank() | 213 | fbank = create_fbank() |
| @@ -242,12 +243,21 @@ def main(): | @@ -242,12 +243,21 @@ def main(): | ||
| 242 | 243 | ||
| 243 | encoder_out = model.run_encoder(features) | 244 | encoder_out = model.run_encoder(features) |
| 244 | # encoder_out:[batch_size, dim, T) | 245 | # encoder_out:[batch_size, dim, T) |
| 245 | - for t in range(encoder_out.shape[2]): | 246 | + t = 0 |
| 247 | + while t < encoder_out.shape[2]: | ||
| 246 | encoder_out_t = encoder_out[:, :, t : t + 1] | 248 | encoder_out_t = encoder_out[:, :, t : t + 1] |
| 247 | logits = model.run_joiner(encoder_out_t, decoder_out) | 249 | logits = model.run_joiner(encoder_out_t, decoder_out) |
| 248 | logits = torch.from_numpy(logits) | 250 | logits = torch.from_numpy(logits) |
| 249 | logits = logits.squeeze() | 251 | logits = logits.squeeze() |
| 250 | - idx = torch.argmax(logits, dim=-1).item() | 252 | + |
| 253 | + token_logits = logits[:vocab_size] | ||
| 254 | + duration_logits = logits[vocab_size:] | ||
| 255 | + | ||
| 256 | + idx = torch.argmax(token_logits, dim=-1).item() | ||
| 257 | + skip = torch.argmax(duration_logits, dim=-1).item() | ||
| 258 | + if skip == 0: | ||
| 259 | + skip = 1 | ||
| 260 | + | ||
| 251 | if idx != blank: | 261 | if idx != blank: |
| 252 | ans.append(idx) | 262 | ans.append(idx) |
| 253 | state0 = state0_next | 263 | state0 = state0_next |
| @@ -255,6 +265,7 @@ def main(): | @@ -255,6 +265,7 @@ def main(): | ||
| 255 | decoder_out, state0_next, state1_next = model.run_decoder( | 265 | decoder_out, state0_next, state1_next = model.run_decoder( |
| 256 | ans[-1], state0, state1 | 266 | ans[-1], state0, state1 |
| 257 | ) | 267 | ) |
| 268 | + t += skip | ||
| 258 | 269 | ||
| 259 | end = time.time() | 270 | end = time.time() |
| 260 | 271 |
| @@ -43,7 +43,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | @@ -43,7 +43,7 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | ||
| 43 | config_.model_config)) { | 43 | config_.model_config)) { |
| 44 | if (config_.decoding_method == "greedy_search") { | 44 | if (config_.decoding_method == "greedy_search") { |
| 45 | decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>( | 45 | decoder_ = std::make_unique<OfflineTransducerGreedySearchNeMoDecoder>( |
| 46 | - model_.get(), config_.blank_penalty); | 46 | + model_.get(), config_.blank_penalty, model_->IsTDT()); |
| 47 | } else { | 47 | } else { |
| 48 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 48 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 49 | config_.decoding_method.c_str()); | 49 | config_.decoding_method.c_str()); |
| @@ -94,6 +94,72 @@ static OfflineTransducerDecoderResult DecodeOne( | @@ -94,6 +94,72 @@ static OfflineTransducerDecoderResult DecodeOne( | ||
| 94 | return ans; | 94 | return ans; |
| 95 | } | 95 | } |
| 96 | 96 | ||
| 97 | +static OfflineTransducerDecoderResult DecodeOneTDT( | ||
| 98 | + const float *p, int32_t num_rows, int32_t num_cols, | ||
| 99 | + OfflineTransducerNeMoModel *model, float blank_penalty) { | ||
| 100 | + auto memory_info = | ||
| 101 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 102 | + | ||
| 103 | + OfflineTransducerDecoderResult ans; | ||
| 104 | + | ||
| 105 | + int32_t vocab_size = model->VocabSize(); | ||
| 106 | + int32_t blank_id = vocab_size - 1; | ||
| 107 | + | ||
| 108 | + auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); | ||
| 109 | + | ||
| 110 | + std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair = | ||
| 111 | + model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 112 | + std::move(decoder_input_pair.second), | ||
| 113 | + model->GetDecoderInitStates(1)); | ||
| 114 | + | ||
| 115 | + std::array<int64_t, 3> encoder_shape{1, num_cols, 1}; | ||
| 116 | + | ||
| 117 | + int32_t skip = 0; | ||
| 118 | + for (int32_t t = 0; t < num_rows; t += skip) { | ||
| 119 | + Ort::Value cur_encoder_out = Ort::Value::CreateTensor( | ||
| 120 | + memory_info, const_cast<float *>(p) + t * num_cols, num_cols, | ||
| 121 | + encoder_shape.data(), encoder_shape.size()); | ||
| 122 | + | ||
| 123 | + Ort::Value logit = model->RunJoiner(View(&cur_encoder_out), | ||
| 124 | + View(&decoder_output_pair.first)); | ||
| 125 | + | ||
| 126 | + auto shape = logit.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 127 | + | ||
| 128 | + float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 129 | + if (blank_penalty > 0) { | ||
| 130 | + p_logit[blank_id] -= blank_penalty; | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + auto y = static_cast<int32_t>(std::distance( | ||
| 134 | + static_cast<const float *>(p_logit), | ||
| 135 | + std::max_element(static_cast<const float *>(p_logit), | ||
| 136 | + static_cast<const float *>(p_logit) + vocab_size))); | ||
| 137 | + | ||
| 138 | + skip = static_cast<int32_t>(std::distance( | ||
| 139 | + static_cast<const float *>(p_logit) + vocab_size, | ||
| 140 | + std::max_element(static_cast<const float *>(p_logit) + vocab_size, | ||
| 141 | + static_cast<const float *>(p_logit) + shape.back()))); | ||
| 142 | + | ||
| 143 | + if (skip == 0) { | ||
| 144 | + skip = 1; | ||
| 145 | + } | ||
| 146 | + | ||
| 147 | + if (y != blank_id) { | ||
| 148 | + ans.tokens.push_back(y); | ||
| 149 | + ans.timestamps.push_back(t); | ||
| 150 | + | ||
| 151 | + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); | ||
| 152 | + | ||
| 153 | + decoder_output_pair = | ||
| 154 | + model->RunDecoder(std::move(decoder_input_pair.first), | ||
| 155 | + std::move(decoder_input_pair.second), | ||
| 156 | + std::move(decoder_output_pair.second)); | ||
| 157 | + } | ||
| 158 | + } // for (int32_t t = 0; t < num_rows; ++t) { | ||
| 159 | + | ||
| 160 | + return ans; | ||
| 161 | +} | ||
| 162 | + | ||
| 97 | std::vector<OfflineTransducerDecoderResult> | 163 | std::vector<OfflineTransducerDecoderResult> |
| 98 | OfflineTransducerGreedySearchNeMoDecoder::Decode( | 164 | OfflineTransducerGreedySearchNeMoDecoder::Decode( |
| 99 | Ort::Value encoder_out, Ort::Value encoder_out_length, | 165 | Ort::Value encoder_out, Ort::Value encoder_out_length, |
| @@ -123,7 +189,11 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | @@ -123,7 +189,11 @@ OfflineTransducerGreedySearchNeMoDecoder::Decode( | ||
| 123 | ? encoder_out_length.GetTensorData<int32_t>()[i] | 189 | ? encoder_out_length.GetTensorData<int32_t>()[i] |
| 124 | : encoder_out_length.GetTensorData<int64_t>()[i]; | 190 | : encoder_out_length.GetTensorData<int64_t>()[i]; |
| 125 | 191 | ||
| 126 | - ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); | 192 | + if (is_tdt_) { |
| 193 | + ans[i] = DecodeOneTDT(this_p, this_len, dim2, model_, blank_penalty_); | ||
| 194 | + } else { | ||
| 195 | + ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); | ||
| 196 | + } | ||
| 127 | } | 197 | } |
| 128 | 198 | ||
| 129 | return ans; | 199 | return ans; |
| @@ -16,8 +16,8 @@ class OfflineTransducerGreedySearchNeMoDecoder | @@ -16,8 +16,8 @@ class OfflineTransducerGreedySearchNeMoDecoder | ||
| 16 | : public OfflineTransducerDecoder { | 16 | : public OfflineTransducerDecoder { |
| 17 | public: | 17 | public: |
| 18 | OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model, | 18 | OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model, |
| 19 | - float blank_penalty) | ||
| 20 | - : model_(model), blank_penalty_(blank_penalty) {} | 19 | + float blank_penalty, bool is_tdt) |
| 20 | + : model_(model), blank_penalty_(blank_penalty), is_tdt_(is_tdt) {} | ||
| 21 | 21 | ||
| 22 | std::vector<OfflineTransducerDecoderResult> Decode( | 22 | std::vector<OfflineTransducerDecoderResult> Decode( |
| 23 | Ort::Value encoder_out, Ort::Value encoder_out_length, | 23 | Ort::Value encoder_out, Ort::Value encoder_out_length, |
| @@ -26,6 +26,7 @@ class OfflineTransducerGreedySearchNeMoDecoder | @@ -26,6 +26,7 @@ class OfflineTransducerGreedySearchNeMoDecoder | ||
| 26 | private: | 26 | private: |
| 27 | OfflineTransducerNeMoModel *model_; // Not owned | 27 | OfflineTransducerNeMoModel *model_; // Not owned |
| 28 | float blank_penalty_; | 28 | float blank_penalty_; |
| 29 | + bool is_tdt_; | ||
| 29 | }; | 30 | }; |
| 30 | 31 | ||
| 31 | } // namespace sherpa_onnx | 32 | } // namespace sherpa_onnx |
| @@ -163,6 +163,7 @@ class OfflineTransducerNeMoModel::Impl { | @@ -163,6 +163,7 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 163 | std::string FeatureNormalizationMethod() const { return normalize_type_; } | 163 | std::string FeatureNormalizationMethod() const { return normalize_type_; } |
| 164 | 164 | ||
| 165 | bool IsGigaAM() const { return is_giga_am_; } | 165 | bool IsGigaAM() const { return is_giga_am_; } |
| 166 | + bool IsTDT() const { return is_tdt_; } | ||
| 166 | 167 | ||
| 167 | int32_t FeatureDim() const { return feat_dim_; } | 168 | int32_t FeatureDim() const { return feat_dim_; } |
| 168 | 169 | ||
| @@ -208,6 +209,12 @@ class OfflineTransducerNeMoModel::Impl { | @@ -208,6 +209,12 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 208 | if (normalize_type_ == "NA") { | 209 | if (normalize_type_ == "NA") { |
| 209 | normalize_type_ = ""; | 210 | normalize_type_ = ""; |
| 210 | } | 211 | } |
| 212 | + | ||
| 213 | + std::string url; | ||
| 214 | + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(url, "url"); | ||
| 215 | + if (url.find("tdt") != std::string::npos) { | ||
| 216 | + is_tdt_ = 1; | ||
| 217 | + } | ||
| 211 | } | 218 | } |
| 212 | 219 | ||
| 213 | void InitDecoder(void *model_data, size_t model_data_length) { | 220 | void InitDecoder(void *model_data, size_t model_data_length) { |
| @@ -230,6 +237,26 @@ class OfflineTransducerNeMoModel::Impl { | @@ -230,6 +237,26 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 230 | 237 | ||
| 231 | GetOutputNames(joiner_sess_.get(), &joiner_output_names_, | 238 | GetOutputNames(joiner_sess_.get(), &joiner_output_names_, |
| 232 | &joiner_output_names_ptr_); | 239 | &joiner_output_names_ptr_); |
| 240 | + | ||
| 241 | + auto shape = joiner_sess_->GetOutputTypeInfo(0) | ||
| 242 | + .GetTensorTypeAndShapeInfo() | ||
| 243 | + .GetShape(); | ||
| 244 | + int32_t output_size = shape.back(); | ||
| 245 | + if (is_tdt_) { | ||
| 246 | + if (vocab_size_ == output_size) { | ||
| 247 | + SHERPA_ONNX_LOGE("It is not a TDT model!"); | ||
| 248 | + SHERPA_ONNX_EXIT(-1); | ||
| 249 | + } | ||
| 250 | + | ||
| 251 | + if (config_.debug) { | ||
| 252 | + SHERPA_ONNX_LOGE("TDT model. vocab_size: %d, num_durations: %d", | ||
| 253 | + vocab_size_, output_size - vocab_size_); | ||
| 254 | + } | ||
| 255 | + } else if (vocab_size_ != output_size) { | ||
| 256 | + SHERPA_ONNX_LOGE("vocab_size: %d != output_size: %d", vocab_size_, | ||
| 257 | + output_size); | ||
| 258 | + SHERPA_ONNX_EXIT(-1); | ||
| 259 | + } | ||
| 233 | } | 260 | } |
| 234 | 261 | ||
| 235 | private: | 262 | private: |
| @@ -266,6 +293,7 @@ class OfflineTransducerNeMoModel::Impl { | @@ -266,6 +293,7 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 266 | int32_t pred_rnn_layers_ = -1; | 293 | int32_t pred_rnn_layers_ = -1; |
| 267 | int32_t pred_hidden_ = -1; | 294 | int32_t pred_hidden_ = -1; |
| 268 | int32_t is_giga_am_ = 0; | 295 | int32_t is_giga_am_ = 0; |
| 296 | + int32_t is_tdt_ = 0; | ||
| 269 | 297 | ||
| 270 | // giga am uses 64 | 298 | // giga am uses 64 |
| 271 | // parakeet-tdt-0.6b-v2 uses 128 | 299 | // parakeet-tdt-0.6b-v2 uses 128 |
| @@ -325,6 +353,8 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { | @@ -325,6 +353,8 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { | ||
| 325 | 353 | ||
| 326 | bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } | 354 | bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } |
| 327 | 355 | ||
| 356 | +bool OfflineTransducerNeMoModel::IsTDT() const { return impl_->IsTDT(); } | ||
| 357 | + | ||
| 328 | int32_t OfflineTransducerNeMoModel::FeatureDim() const { | 358 | int32_t OfflineTransducerNeMoModel::FeatureDim() const { |
| 329 | return impl_->FeatureDim(); | 359 | return impl_->FeatureDim(); |
| 330 | } | 360 | } |
| @@ -88,6 +88,10 @@ class OfflineTransducerNeMoModel { | @@ -88,6 +88,10 @@ class OfflineTransducerNeMoModel { | ||
| 88 | 88 | ||
| 89 | bool IsGigaAM() const; | 89 | bool IsGigaAM() const; |
| 90 | 90 | ||
| 91 | + // true if it is a Token-and-Duration Transducer model | ||
| 92 | + // false otherwise | ||
| 93 | + bool IsTDT() const; | ||
| 94 | + | ||
| 91 | int32_t FeatureDim() const; | 95 | int32_t FeatureDim() const; |
| 92 | 96 | ||
| 93 | private: | 97 | private: |
-
请 注册 或 登录 后发表评论