Committed by
GitHub
Allow modify model config at decode time for ASR (#1124)
正在显示
15 个修改的文件
包含
121 行增加
和
13 行删除
| @@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream { | @@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream { | ||
| 308 | : impl(std::move(p)) {} | 308 | : impl(std::move(p)) {} |
| 309 | }; | 309 | }; |
| 310 | 310 | ||
| 311 | +static sherpa_onnx::OfflineRecognizerConfig convertConfig( | ||
| 312 | + const SherpaOnnxOfflineRecognizerConfig *config); | ||
| 311 | SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | 313 | SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( |
| 312 | const SherpaOnnxOfflineRecognizerConfig *config) { | 314 | const SherpaOnnxOfflineRecognizerConfig *config) { |
| 315 | + sherpa_onnx::OfflineRecognizerConfig recognizer_config = | ||
| 316 | + convertConfig(config); | ||
| 317 | + | ||
| 318 | + if (!recognizer_config.Validate()) { | ||
| 319 | + SHERPA_ONNX_LOGE("Errors in config"); | ||
| 320 | + return nullptr; | ||
| 321 | + } | ||
| 322 | + | ||
| 323 | + SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer; | ||
| 324 | + | ||
| 325 | + recognizer->impl = | ||
| 326 | + std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config); | ||
| 327 | + | ||
| 328 | + return recognizer; | ||
| 329 | +} | ||
| 330 | +sherpa_onnx::OfflineRecognizerConfig convertConfig( | ||
| 331 | + const SherpaOnnxOfflineRecognizerConfig *config) { | ||
| 313 | sherpa_onnx::OfflineRecognizerConfig recognizer_config; | 332 | sherpa_onnx::OfflineRecognizerConfig recognizer_config; |
| 314 | 333 | ||
| 315 | recognizer_config.feat_config.sampling_rate = | 334 | recognizer_config.feat_config.sampling_rate = |
| @@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | @@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | ||
| 398 | SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str()); | 417 | SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str()); |
| 399 | } | 418 | } |
| 400 | 419 | ||
| 401 | - if (!recognizer_config.Validate()) { | ||
| 402 | - SHERPA_ONNX_LOGE("Errors in config"); | ||
| 403 | - return nullptr; | ||
| 404 | - } | ||
| 405 | - | ||
| 406 | - SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer; | ||
| 407 | - | ||
| 408 | - recognizer->impl = | ||
| 409 | - std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config); | 420 | + return recognizer_config; |
| 421 | +} | ||
| 410 | 422 | ||
| 411 | - return recognizer; | 423 | +void SherpaOnnxOfflineRecognizerSetConfig( |
| 424 | + const SherpaOnnxOfflineRecognizer *recognizer, | ||
| 425 | + const SherpaOnnxOfflineRecognizerConfig *config){ | ||
| 426 | + sherpa_onnx::OfflineRecognizerConfig recognizer_config = | ||
| 427 | + convertConfig(config); | ||
| 428 | + recognizer->impl->SetConfig(recognizer_config); | ||
| 412 | } | 429 | } |
| 413 | 430 | ||
| 414 | void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { | 431 | void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { |
| @@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( | @@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( | ||
| 461 | pText[text.size()] = 0; | 478 | pText[text.size()] = 0; |
| 462 | r->text = pText; | 479 | r->text = pText; |
| 463 | 480 | ||
| 481 | + //lang | ||
| 482 | + const auto &lang = result.lang; | ||
| 483 | + char *c_lang = new char[lang.size() + 1]; | ||
| 484 | + std::copy(lang.begin(), lang.end(), c_lang); | ||
| 485 | + c_lang[lang.size()] = '\0'; | ||
| 486 | + r->lang = c_lang; | ||
| 487 | + | ||
| 464 | // copy json | 488 | // copy json |
| 465 | std::string json = result.AsJsonString(); | 489 | std::string json = result.AsJsonString(); |
| 466 | char *pJson = new char[json.size() + 1]; | 490 | char *pJson = new char[json.size() + 1]; |
| @@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult( | @@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult( | ||
| 517 | delete[] r->tokens; | 541 | delete[] r->tokens; |
| 518 | delete[] r->tokens_arr; | 542 | delete[] r->tokens_arr; |
| 519 | delete[] r->json; | 543 | delete[] r->json; |
| 544 | + delete[] r->lang; | ||
| 520 | delete r; | 545 | delete r; |
| 521 | } | 546 | } |
| 522 | } | 547 | } |
| @@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream; | @@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream; | ||
| 428 | SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | 428 | SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( |
| 429 | const SherpaOnnxOfflineRecognizerConfig *config); | 429 | const SherpaOnnxOfflineRecognizerConfig *config); |
| 430 | 430 | ||
| 431 | +/// @param config Config for the recognizer. | ||
| 432 | +SHERPA_ONNX_API void SherpaOnnxOfflineRecognizerSetConfig( | ||
| 433 | + const SherpaOnnxOfflineRecognizer *recognizer, | ||
| 434 | + const SherpaOnnxOfflineRecognizerConfig *config); | ||
| 435 | + | ||
| 431 | /// Free a pointer returned by CreateOfflineRecognizer() | 436 | /// Free a pointer returned by CreateOfflineRecognizer() |
| 432 | /// | 437 | /// |
| 433 | /// @param p A pointer returned by CreateOfflineRecognizer() | 438 | /// @param p A pointer returned by CreateOfflineRecognizer() |
| @@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( | @@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( | ||
| 491 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | 496 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { |
| 492 | const char *text; | 497 | const char *text; |
| 493 | 498 | ||
| 494 | - // Pointer to continuous memory which holds timestamps | 499 | + // Pointer to continuous memory which holds timestamps |
| 495 | // | 500 | // |
| 496 | // It is NULL if the model does not support timestamps | 501 | // It is NULL if the model does not support timestamps |
| 497 | float *timestamps; | 502 | float *timestamps; |
| @@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | @@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { | ||
| 519 | * } | 524 | * } |
| 520 | */ | 525 | */ |
| 521 | const char *json; | 526 | const char *json; |
| 527 | + | ||
| 528 | + //return recognized language | ||
| 529 | + const char *lang; | ||
| 530 | + | ||
| 522 | } SherpaOnnxOfflineRecognizerResult; | 531 | } SherpaOnnxOfflineRecognizerResult; |
| 523 | 532 | ||
| 524 | /// Get the result of the offline stream. | 533 | /// Get the result of the offline stream. |
| @@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 212 | } | 212 | } |
| 213 | } | 213 | } |
| 214 | 214 | ||
| 215 | + OfflineRecognizerConfig GetConfig() const override { | ||
| 216 | + return config_; | ||
| 217 | + } | ||
| 218 | + | ||
| 219 | + | ||
| 215 | private: | 220 | private: |
| 216 | // Decode a single stream. | 221 | // Decode a single stream. |
| 217 | // Some models do not support batch size > 1, e.g., WeNet CTC models. | 222 | // Some models do not support batch size > 1, e.g., WeNet CTC models. |
| @@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( | @@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( | ||
| 431 | return text; | 431 | return text; |
| 432 | } | 432 | } |
| 433 | 433 | ||
| 434 | +void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) { | ||
| 435 | + config_ = config; | ||
| 436 | +} | ||
| 437 | + | ||
| 434 | } // namespace sherpa_onnx | 438 | } // namespace sherpa_onnx |
| @@ -48,6 +48,10 @@ class OfflineRecognizerImpl { | @@ -48,6 +48,10 @@ class OfflineRecognizerImpl { | ||
| 48 | 48 | ||
| 49 | virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; | 49 | virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; |
| 50 | 50 | ||
| 51 | + virtual void SetConfig(const OfflineRecognizerConfig &config); | ||
| 52 | + | ||
| 53 | + virtual OfflineRecognizerConfig GetConfig() const = 0; | ||
| 54 | + | ||
| 51 | std::string ApplyInverseTextNormalization(std::string text) const; | 55 | std::string ApplyInverseTextNormalization(std::string text) const; |
| 52 | 56 | ||
| 53 | private: | 57 | private: |
| @@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { | @@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { | ||
| 211 | } | 211 | } |
| 212 | } | 212 | } |
| 213 | 213 | ||
| 214 | + OfflineRecognizerConfig GetConfig() const override { | ||
| 215 | + return config_; | ||
| 216 | + } | ||
| 217 | + | ||
| 214 | private: | 218 | private: |
| 215 | std::vector<float> ApplyLFR(const std::vector<float> &in) const { | 219 | std::vector<float> ApplyLFR(const std::vector<float> &in) const { |
| 216 | int32_t lfr_window_size = model_->LfrWindowSize(); | 220 | int32_t lfr_window_size = model_->LfrWindowSize(); |
| @@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 246 | } | 246 | } |
| 247 | } | 247 | } |
| 248 | 248 | ||
| 249 | + OfflineRecognizerConfig GetConfig() const override { | ||
| 250 | + return config_; | ||
| 251 | + } | ||
| 252 | + | ||
| 253 | + | ||
| 249 | void InitHotwords() { | 254 | void InitHotwords() { |
| 250 | // each line in hotwords_file contains space-separated words | 255 | // each line in hotwords_file contains space-separated words |
| 251 | 256 |
| @@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | @@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { | ||
| 139 | } | 139 | } |
| 140 | } | 140 | } |
| 141 | 141 | ||
| 142 | + OfflineRecognizerConfig GetConfig() const override { | ||
| 143 | + return config_; | ||
| 144 | + } | ||
| 145 | + | ||
| 142 | private: | 146 | private: |
| 143 | void PostInit() { | 147 | void PostInit() { |
| 144 | config_.feat_config.nemo_normalize_type = | 148 | config_.feat_config.nemo_normalize_type = |
| @@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, | @@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, | ||
| 45 | } | 45 | } |
| 46 | 46 | ||
| 47 | r.text = text; | 47 | r.text = text; |
| 48 | + r.lang = src.lang; | ||
| 48 | 49 | ||
| 49 | return r; | 50 | return r; |
| 50 | } | 51 | } |
| @@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | @@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { | ||
| 100 | } | 101 | } |
| 101 | } | 102 | } |
| 102 | 103 | ||
| 104 | + void SetConfig(const OfflineRecognizerConfig &config) override { | ||
| 105 | + config_.model_config.whisper = config.model_config.whisper; | ||
| 106 | + } | ||
| 107 | + | ||
| 108 | + OfflineRecognizerConfig GetConfig() const override { | ||
| 109 | + return config_; | ||
| 110 | + } | ||
| 111 | + | ||
| 103 | private: | 112 | private: |
| 104 | void DecodeStream(OfflineStream *s) const { | 113 | void DecodeStream(OfflineStream *s) const { |
| 114 | + decoder_->SetConfig(config_.model_config.whisper); | ||
| 115 | + | ||
| 105 | int32_t max_num_frames = 3000; | 116 | int32_t max_num_frames = 3000; |
| 106 | auto memory_info = | 117 | auto memory_info = |
| 107 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 118 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| @@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { | @@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { | ||
| 156 | impl_->DecodeStreams(ss, n); | 156 | impl_->DecodeStreams(ss, n); |
| 157 | } | 157 | } |
| 158 | 158 | ||
| 159 | +void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) { | ||
| 160 | + impl_->SetConfig(config); | ||
| 161 | +} | ||
| 162 | + | ||
| 163 | +OfflineRecognizerConfig OfflineRecognizer::GetConfig() const { | ||
| 164 | + return impl_->GetConfig(); | ||
| 165 | +} | ||
| 166 | + | ||
| 159 | } // namespace sherpa_onnx | 167 | } // namespace sherpa_onnx |
| @@ -119,6 +119,15 @@ class OfflineRecognizer { | @@ -119,6 +119,15 @@ class OfflineRecognizer { | ||
| 119 | */ | 119 | */ |
| 120 | void DecodeStreams(OfflineStream **ss, int32_t n) const; | 120 | void DecodeStreams(OfflineStream **ss, int32_t n) const; |
| 121 | 121 | ||
| 122 | + /** Onnxruntime Session objects are not affected by this method. | ||
| 123 | + * The exact behavior can be defined by a specific recognizer impl. | ||
| 124 | + * For instance, for the whisper recognizer, you can retrieve the language and task from | ||
| 125 | + * the config and ignore any remaining fields in `config`. | ||
| 126 | + */ | ||
| 127 | + void SetConfig(const OfflineRecognizerConfig &config); | ||
| 128 | + | ||
| 129 | + OfflineRecognizerConfig GetConfig() const; | ||
| 130 | + | ||
| 122 | private: | 131 | private: |
| 123 | std::unique_ptr<OfflineRecognizerImpl> impl_; | 132 | std::unique_ptr<OfflineRecognizerImpl> impl_; |
| 124 | }; | 133 | }; |
| @@ -26,7 +26,9 @@ struct OfflineRecognitionResult { | @@ -26,7 +26,9 @@ struct OfflineRecognitionResult { | ||
| 26 | // For instance, for BPE-based models it consists of a list of BPE tokens. | 26 | // For instance, for BPE-based models it consists of a list of BPE tokens. |
| 27 | std::vector<std::string> tokens; | 27 | std::vector<std::string> tokens; |
| 28 | 28 | ||
| 29 | - /// timestamps.size() == tokens.size() | 29 | + std::string lang; |
| 30 | + | ||
| 31 | + /// timestamps.size() == tokens.size() | ||
| 30 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. | 32 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. |
| 31 | std::vector<float> timestamps; | 33 | std::vector<float> timestamps; |
| 32 | 34 |
| @@ -6,14 +6,17 @@ | @@ -6,14 +6,17 @@ | ||
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ |
| 7 | 7 | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | +#include <string> | ||
| 9 | 10 | ||
| 10 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | +#include "sherpa-onnx/csrc/offline-whisper-model-config.h" | ||
| 11 | 13 | ||
| 12 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 13 | 15 | ||
| 14 | struct OfflineWhisperDecoderResult { | 16 | struct OfflineWhisperDecoderResult { |
| 15 | /// The decoded token IDs | 17 | /// The decoded token IDs |
| 16 | std::vector<int32_t> tokens; | 18 | std::vector<int32_t> tokens; |
| 19 | + std::string lang; | ||
| 17 | }; | 20 | }; |
| 18 | 21 | ||
| 19 | class OfflineWhisperDecoder { | 22 | class OfflineWhisperDecoder { |
| @@ -31,6 +34,9 @@ class OfflineWhisperDecoder { | @@ -31,6 +34,9 @@ class OfflineWhisperDecoder { | ||
| 31 | */ | 34 | */ |
| 32 | virtual std::vector<OfflineWhisperDecoderResult> Decode( | 35 | virtual std::vector<OfflineWhisperDecoderResult> Decode( |
| 33 | Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; | 36 | Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; |
| 37 | + | ||
| 38 | + virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; | ||
| 39 | + | ||
| 34 | }; | 40 | }; |
| 35 | 41 | ||
| 36 | } // namespace sherpa_onnx | 42 | } // namespace sherpa_onnx |
| @@ -12,6 +12,10 @@ | @@ -12,6 +12,10 @@ | ||
| 12 | 12 | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | +void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) { | ||
| 16 | + config_ = config; | ||
| 17 | +} | ||
| 18 | + | ||
| 15 | std::vector<OfflineWhisperDecoderResult> | 19 | std::vector<OfflineWhisperDecoderResult> |
| 16 | OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | 20 | OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, |
| 17 | Ort::Value cross_v) { | 21 | Ort::Value cross_v) { |
| @@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | @@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, | ||
| 129 | 133 | ||
| 130 | std::vector<OfflineWhisperDecoderResult> ans(1); | 134 | std::vector<OfflineWhisperDecoderResult> ans(1); |
| 131 | 135 | ||
| 136 | + const auto &id2lang = model_->GetID2Lang(); | ||
| 137 | + if (id2lang.count(initial_tokens[1])) { | ||
| 138 | + ans[0].lang = id2lang.at(initial_tokens[1]); | ||
| 139 | + } else { | ||
| 140 | + ans[0].lang = ""; | ||
| 141 | + } | ||
| 142 | + | ||
| 132 | ans[0].tokens = std::move(predicted_tokens); | 143 | ans[0].tokens = std::move(predicted_tokens); |
| 133 | 144 | ||
| 134 | return ans; | 145 | return ans; |
| @@ -8,7 +8,6 @@ | @@ -8,7 +8,6 @@ | ||
| 8 | #include <vector> | 8 | #include <vector> |
| 9 | 9 | ||
| 10 | #include "sherpa-onnx/csrc/offline-whisper-decoder.h" | 10 | #include "sherpa-onnx/csrc/offline-whisper-decoder.h" |
| 11 | -#include "sherpa-onnx/csrc/offline-whisper-model-config.h" | ||
| 12 | #include "sherpa-onnx/csrc/offline-whisper-model.h" | 11 | #include "sherpa-onnx/csrc/offline-whisper-model.h" |
| 13 | 12 | ||
| 14 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| @@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | @@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { | ||
| 22 | std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, | 21 | std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, |
| 23 | Ort::Value cross_v) override; | 22 | Ort::Value cross_v) override; |
| 24 | 23 | ||
| 24 | + void SetConfig(const OfflineWhisperModelConfig &config) override; | ||
| 25 | + | ||
| 25 | private: | 26 | private: |
| 26 | OfflineWhisperModelConfig config_; | 27 | OfflineWhisperModelConfig config_; |
| 27 | OfflineWhisperModel *model_; // not owned | 28 | OfflineWhisperModel *model_; // not owned |
-
请 注册 或 登录 后发表评论