Committed by
GitHub
Add C++ runtime for parakeet-tdt-0.6b-v2. (#2181)
正在显示 3 个修改的文件，包含 20 行增加和 0 行删除
| @@ -138,6 +138,12 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | @@ -138,6 +138,12 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | ||
| 138 | 138 | ||
| 139 | private: | 139 | private: |
| 140 | void PostInit() { | 140 | void PostInit() { |
| 141 | + int32_t feat_dim = model_->FeatureDim(); | ||
| 142 | + | ||
| 143 | + if (feat_dim > 0) { | ||
| 144 | + config_.feat_config.feature_dim = feat_dim; | ||
| 145 | + } | ||
| 146 | + | ||
| 141 | config_.feat_config.nemo_normalize_type = | 147 | config_.feat_config.nemo_normalize_type = |
| 142 | model_->FeatureNormalizationMethod(); | 148 | model_->FeatureNormalizationMethod(); |
| 143 | 149 |
| @@ -164,6 +164,8 @@ class OfflineTransducerNeMoModel::Impl { | @@ -164,6 +164,8 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 164 | 164 | ||
| 165 | bool IsGigaAM() const { return is_giga_am_; } | 165 | bool IsGigaAM() const { return is_giga_am_; } |
| 166 | 166 | ||
| 167 | + int32_t FeatureDim() const { return feat_dim_; } | ||
| 168 | + | ||
| 167 | private: | 169 | private: |
| 168 | void InitEncoder(void *model_data, size_t model_data_length) { | 170 | void InitEncoder(void *model_data, size_t model_data_length) { |
| 169 | encoder_sess_ = std::make_unique<Ort::Session>( | 171 | encoder_sess_ = std::make_unique<Ort::Session>( |
| @@ -201,6 +203,7 @@ class OfflineTransducerNeMoModel::Impl { | @@ -201,6 +203,7 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 201 | SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); | 203 | SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); |
| 202 | SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); | 204 | SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); |
| 203 | SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); | 205 | SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); |
| 206 | + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(feat_dim_, "feat_dim", -1); | ||
| 204 | 207 | ||
| 205 | if (normalize_type_ == "NA") { | 208 | if (normalize_type_ == "NA") { |
| 206 | normalize_type_ = ""; | 209 | normalize_type_ = ""; |
| @@ -263,6 +266,11 @@ class OfflineTransducerNeMoModel::Impl { | @@ -263,6 +266,11 @@ class OfflineTransducerNeMoModel::Impl { | ||
| 263 | int32_t pred_rnn_layers_ = -1; | 266 | int32_t pred_rnn_layers_ = -1; |
| 264 | int32_t pred_hidden_ = -1; | 267 | int32_t pred_hidden_ = -1; |
| 265 | int32_t is_giga_am_ = 0; | 268 | int32_t is_giga_am_ = 0; |
| 269 | + | ||
| 270 | + // Feature dimension read from the optional "feat_dim" model metadata: GigaAM uses 64, | ||
| 271 | + // parakeet-tdt-0.6b-v2 uses 128, | ||
| 272 | + // and other models use 80. | ||
| 273 | + int32_t feat_dim_ = -1; // -1: no "feat_dim" in the metadata; the recognizer keeps its configured feature_dim. | ||
| 266 | }; | 274 | }; |
| 267 | 275 | ||
| 268 | OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( | 276 | OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( |
| @@ -317,6 +325,10 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { | @@ -317,6 +325,10 @@ std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { | ||
| 317 | 325 | ||
| 318 | bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } | 326 | bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } |
| 319 | 327 | ||
| 328 | +int32_t OfflineTransducerNeMoModel::FeatureDim() const { | ||
| 329 | + return impl_->FeatureDim(); | ||
| 330 | +} | ||
| 331 | + | ||
| 320 | #if __ANDROID_API__ >= 9 | 332 | #if __ANDROID_API__ >= 9 |
| 321 | template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( | 333 | template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( |
| 322 | AAssetManager *mgr, const OfflineModelConfig &config); | 334 | AAssetManager *mgr, const OfflineModelConfig &config); |
| @@ -88,6 +88,8 @@ class OfflineTransducerNeMoModel { | @@ -88,6 +88,8 @@ class OfflineTransducerNeMoModel { | ||
| 88 | 88 | ||
| 89 | bool IsGigaAM() const; | 89 | bool IsGigaAM() const; |
| 90 | 90 | ||
| 91 | + int32_t FeatureDim() const; | ||
| 92 | + | ||
| 91 | private: | 93 | private: |
| 92 | class Impl; | 94 | class Impl; |
| 93 | std::unique_ptr<Impl> impl_; | 95 | std::unique_ptr<Impl> impl_; |
-
请注册或登录后发表评论