Fangjun Kuang
Committed by GitHub

Add C++ runtime for parakeet-tdt-0.6b-v2. (#2181)

@@ -138,6 +138,12 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { @@ -138,6 +138,12 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
138 138
139 private: 139 private:
140 void PostInit() { 140 void PostInit() {
  141 + int32_t feat_dim = model_->FeatureDim();
  142 +
  143 + if (feat_dim > 0) {
  144 + config_.feat_config.feature_dim = feat_dim;
  145 + }
  146 +
141 config_.feat_config.nemo_normalize_type = 147 config_.feat_config.nemo_normalize_type =
142 model_->FeatureNormalizationMethod(); 148 model_->FeatureNormalizationMethod();
143 149
@@ -164,6 +164,8 @@ class OfflineTransducerNeMoModel::Impl { @@ -164,6 +164,8 @@ class OfflineTransducerNeMoModel::Impl {
164 164
165 bool IsGigaAM() const { return is_giga_am_; } 165 bool IsGigaAM() const { return is_giga_am_; }
166 166
  167 + int32_t FeatureDim() const { return feat_dim_; }
  168 +
167 private: 169 private:
168 void InitEncoder(void *model_data, size_t model_data_length) { 170 void InitEncoder(void *model_data, size_t model_data_length) {
169 encoder_sess_ = std::make_unique<Ort::Session>( 171 encoder_sess_ = std::make_unique<Ort::Session>(
@@ -201,6 +203,7 @@ class OfflineTransducerNeMoModel::Impl { @@ -201,6 +203,7 @@ class OfflineTransducerNeMoModel::Impl {
201 SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); 203 SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers");
202 SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); 204 SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden");
203 SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); 205 SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0);
  206 + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(feat_dim_, "feat_dim", -1);
204 207
205 if (normalize_type_ == "NA") { 208 if (normalize_type_ == "NA") {
206 normalize_type_ = ""; 209 normalize_type_ = "";
@@ -263,6 +266,11 @@ class OfflineTransducerNeMoModel::Impl { @@ -263,6 +266,11 @@ class OfflineTransducerNeMoModel::Impl {
263 int32_t pred_rnn_layers_ = -1; 266 int32_t pred_rnn_layers_ = -1;
264 int32_t pred_hidden_ = -1; 267 int32_t pred_hidden_ = -1;
265 int32_t is_giga_am_ = 0; 268 int32_t is_giga_am_ = 0;
  269 +
  270 + // giga am uses 64
  271 + // parakeet-tdt-0.6b-v2 uses 128
  272 + // others use 80
  273 + int32_t feat_dim_ = -1; // -1 means to use default values.
266 }; 274 };
267 275
268 OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( 276 OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
@@ -317,6 +325,10 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { @@ -317,6 +325,10 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const {
317 325
318 bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } 326 bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); }
319 327
  328 +int32_t OfflineTransducerNeMoModel::FeatureDim() const {
  329 + return impl_->FeatureDim();
  330 +}
  331 +
320 #if __ANDROID_API__ >= 9 332 #if __ANDROID_API__ >= 9
321 template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( 333 template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel(
322 AAssetManager *mgr, const OfflineModelConfig &config); 334 AAssetManager *mgr, const OfflineModelConfig &config);
@@ -88,6 +88,8 @@ class OfflineTransducerNeMoModel { @@ -88,6 +88,8 @@ class OfflineTransducerNeMoModel {
88 88
89 bool IsGigaAM() const; 89 bool IsGigaAM() const;
90 90
  91 + int32_t FeatureDim() const;
  92 +
91 private: 93 private:
92 class Impl; 94 class Impl;
93 std::unique_ptr<Impl> impl_; 95 std::unique_ptr<Impl> impl_;