ivan provalov
Committed by GitHub

Allow modify model config at decode time for ASR (#1124)

@@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream { @@ -308,8 +308,27 @@ struct SherpaOnnxOfflineStream {
308 : impl(std::move(p)) {} 308 : impl(std::move(p)) {}
309 }; 309 };
310 310
  311 +static sherpa_onnx::OfflineRecognizerConfig convertConfig(
  312 + const SherpaOnnxOfflineRecognizerConfig *config);
311 SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( 313 SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
312 const SherpaOnnxOfflineRecognizerConfig *config) { 314 const SherpaOnnxOfflineRecognizerConfig *config) {
  315 + sherpa_onnx::OfflineRecognizerConfig recognizer_config =
  316 + convertConfig(config);
  317 +
  318 + if (!recognizer_config.Validate()) {
  319 + SHERPA_ONNX_LOGE("Errors in config");
  320 + return nullptr;
  321 + }
  322 +
  323 + SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;
  324 +
  325 + recognizer->impl =
  326 + std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config);
  327 +
  328 + return recognizer;
  329 +}
  330 +sherpa_onnx::OfflineRecognizerConfig convertConfig(
  331 + const SherpaOnnxOfflineRecognizerConfig *config) {
313 sherpa_onnx::OfflineRecognizerConfig recognizer_config; 332 sherpa_onnx::OfflineRecognizerConfig recognizer_config;
314 333
315 recognizer_config.feat_config.sampling_rate = 334 recognizer_config.feat_config.sampling_rate =
@@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( @@ -398,17 +417,15 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
398 SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str()); 417 SHERPA_ONNX_LOGE("%s", recognizer_config.ToString().c_str());
399 } 418 }
400 419
401 - if (!recognizer_config.Validate()) {  
402 - SHERPA_ONNX_LOGE("Errors in config");  
403 - return nullptr;  
404 - }  
405 -  
406 - SherpaOnnxOfflineRecognizer *recognizer = new SherpaOnnxOfflineRecognizer;  
407 -  
408 - recognizer->impl =  
409 - std::make_unique<sherpa_onnx::OfflineRecognizer>(recognizer_config); 420 + return recognizer_config;
  421 +}
410 422
411 - return recognizer; 423 +void SherpaOnnxOfflineRecognizerSetConfig(
  424 + const SherpaOnnxOfflineRecognizer *recognizer,
  425 + const SherpaOnnxOfflineRecognizerConfig *config){
  426 + sherpa_onnx::OfflineRecognizerConfig recognizer_config =
  427 + convertConfig(config);
  428 + recognizer->impl->SetConfig(recognizer_config);
412 } 429 }
413 430
414 void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { 431 void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) {
@@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( @@ -461,6 +478,13 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult(
461 pText[text.size()] = 0; 478 pText[text.size()] = 0;
462 r->text = pText; 479 r->text = pText;
463 480
  481 + //lang
  482 + const auto &lang = result.lang;
  483 + char *c_lang = new char[lang.size() + 1];
  484 + std::copy(lang.begin(), lang.end(), c_lang);
  485 + c_lang[lang.size()] = '\0';
  486 + r->lang = c_lang;
  487 +
464 // copy json 488 // copy json
465 std::string json = result.AsJsonString(); 489 std::string json = result.AsJsonString();
466 char *pJson = new char[json.size() + 1]; 490 char *pJson = new char[json.size() + 1];
@@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult( @@ -517,6 +541,7 @@ void DestroyOfflineRecognizerResult(
517 delete[] r->tokens; 541 delete[] r->tokens;
518 delete[] r->tokens_arr; 542 delete[] r->tokens_arr;
519 delete[] r->json; 543 delete[] r->json;
  544 + delete[] r->lang;
520 delete r; 545 delete r;
521 } 546 }
522 } 547 }
@@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream; @@ -428,6 +428,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineStream SherpaOnnxOfflineStream;
428 SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( 428 SHERPA_ONNX_API SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
429 const SherpaOnnxOfflineRecognizerConfig *config); 429 const SherpaOnnxOfflineRecognizerConfig *config);
430 430
  431 +/// @param config Config for the recognizer.
  432 +SHERPA_ONNX_API void SherpaOnnxOfflineRecognizerSetConfig(
  433 + const SherpaOnnxOfflineRecognizer *recognizer,
  434 + const SherpaOnnxOfflineRecognizerConfig *config);
  435 +
431 /// Free a pointer returned by CreateOfflineRecognizer() 436 /// Free a pointer returned by CreateOfflineRecognizer()
432 /// 437 ///
433 /// @param p A pointer returned by CreateOfflineRecognizer() 438 /// @param p A pointer returned by CreateOfflineRecognizer()
@@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( @@ -491,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams(
491 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { 496 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
492 const char *text; 497 const char *text;
493 498
494 - // Pointer to continuous memory which holds timestamps 499 + // Pointer to continuous memory which holds timestamps
495 // 500 //
496 // It is NULL if the model does not support timestamps 501 // It is NULL if the model does not support timestamps
497 float *timestamps; 502 float *timestamps;
@@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { @@ -519,6 +524,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult {
519 * } 524 * }
520 */ 525 */
521 const char *json; 526 const char *json;
  527 +
  528 + //return recognized language
  529 + const char *lang;
  530 +
522 } SherpaOnnxOfflineRecognizerResult; 531 } SherpaOnnxOfflineRecognizerResult;
523 532
524 /// Get the result of the offline stream. 533 /// Get the result of the offline stream.
@@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { @@ -212,6 +212,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
212 } 212 }
213 } 213 }
214 214
  215 + OfflineRecognizerConfig GetConfig() const override {
  216 + return config_;
  217 + }
  218 +
  219 +
215 private: 220 private:
216 // Decode a single stream. 221 // Decode a single stream.
217 // Some models do not support batch size > 1, e.g., WeNet CTC models. 222 // Some models do not support batch size > 1, e.g., WeNet CTC models.
@@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( @@ -431,4 +431,8 @@ std::string OfflineRecognizerImpl::ApplyInverseTextNormalization(
431 return text; 431 return text;
432 } 432 }
433 433
  434 +void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) {
  435 + config_ = config;
  436 +}
  437 +
434 } // namespace sherpa_onnx 438 } // namespace sherpa_onnx
@@ -48,6 +48,10 @@ class OfflineRecognizerImpl { @@ -48,6 +48,10 @@ class OfflineRecognizerImpl {
48 48
49 virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; 49 virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0;
50 50
  51 + virtual void SetConfig(const OfflineRecognizerConfig &config);
  52 +
  53 + virtual OfflineRecognizerConfig GetConfig() const = 0;
  54 +
51 std::string ApplyInverseTextNormalization(std::string text) const; 55 std::string ApplyInverseTextNormalization(std::string text) const;
52 56
53 private: 57 private:
@@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { @@ -211,6 +211,10 @@ class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
211 } 211 }
212 } 212 }
213 213
  214 + OfflineRecognizerConfig GetConfig() const override {
  215 + return config_;
  216 + }
  217 +
214 private: 218 private:
215 std::vector<float> ApplyLFR(const std::vector<float> &in) const { 219 std::vector<float> ApplyLFR(const std::vector<float> &in) const {
216 int32_t lfr_window_size = model_->LfrWindowSize(); 220 int32_t lfr_window_size = model_->LfrWindowSize();
@@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { @@ -246,6 +246,11 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
246 } 246 }
247 } 247 }
248 248
  249 + OfflineRecognizerConfig GetConfig() const override {
  250 + return config_;
  251 + }
  252 +
  253 +
249 void InitHotwords() { 254 void InitHotwords() {
250 // each line in hotwords_file contains space-separated words 255 // each line in hotwords_file contains space-separated words
251 256
@@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { @@ -139,6 +139,10 @@ class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl {
139 } 139 }
140 } 140 }
141 141
  142 + OfflineRecognizerConfig GetConfig() const override {
  143 + return config_;
  144 + }
  145 +
142 private: 146 private:
143 void PostInit() { 147 void PostInit() {
144 config_.feat_config.nemo_normalize_type = 148 config_.feat_config.nemo_normalize_type =
@@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, @@ -45,6 +45,7 @@ static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src,
45 } 45 }
46 46
47 r.text = text; 47 r.text = text;
  48 + r.lang = src.lang;
48 49
49 return r; 50 return r;
50 } 51 }
@@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { @@ -100,8 +101,18 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
100 } 101 }
101 } 102 }
102 103
  104 + void SetConfig(const OfflineRecognizerConfig &config) override {
  105 + config_.model_config.whisper = config.model_config.whisper;
  106 + }
  107 +
  108 + OfflineRecognizerConfig GetConfig() const override {
  109 + return config_;
  110 + }
  111 +
103 private: 112 private:
104 void DecodeStream(OfflineStream *s) const { 113 void DecodeStream(OfflineStream *s) const {
  114 + decoder_->SetConfig(config_.model_config.whisper);
  115 +
105 int32_t max_num_frames = 3000; 116 int32_t max_num_frames = 3000;
106 auto memory_info = 117 auto memory_info =
107 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 118 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { @@ -156,4 +156,12 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
156 impl_->DecodeStreams(ss, n); 156 impl_->DecodeStreams(ss, n);
157 } 157 }
158 158
  159 +void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) {
  160 + impl_->SetConfig(config);
  161 +}
  162 +
  163 +OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {
  164 + return impl_->GetConfig();
  165 +}
  166 +
159 } // namespace sherpa_onnx 167 } // namespace sherpa_onnx
@@ -119,6 +119,15 @@ class OfflineRecognizer { @@ -119,6 +119,15 @@ class OfflineRecognizer {
119 */ 119 */
120 void DecodeStreams(OfflineStream **ss, int32_t n) const; 120 void DecodeStreams(OfflineStream **ss, int32_t n) const;
121 121
  122 + /** Onnxruntime Session objects are not affected by this method.
  123 + * The exact behavior can be defined by a specific recognizer impl.
  124 + * For instance, for the whisper recognizer, you can retrieve the language and task from
  125 + * the config and ignore any remaining fields in `config`.
  126 + */
  127 + void SetConfig(const OfflineRecognizerConfig &config);
  128 +
  129 + OfflineRecognizerConfig GetConfig() const;
  130 +
122 private: 131 private:
123 std::unique_ptr<OfflineRecognizerImpl> impl_; 132 std::unique_ptr<OfflineRecognizerImpl> impl_;
124 }; 133 };
@@ -26,7 +26,9 @@ struct OfflineRecognitionResult { @@ -26,7 +26,9 @@ struct OfflineRecognitionResult {
26 // For instance, for BPE-based models it consists of a list of BPE tokens. 26 // For instance, for BPE-based models it consists of a list of BPE tokens.
27 std::vector<std::string> tokens; 27 std::vector<std::string> tokens;
28 28
29 - /// timestamps.size() == tokens.size() 29 + std::string lang;
  30 +
  31 + /// timestamps.size() == tokens.size()
30 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 32 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
31 std::vector<float> timestamps; 33 std::vector<float> timestamps;
32 34
@@ -6,14 +6,17 @@ @@ -6,14 +6,17 @@
6 #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ 6 #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_
7 7
8 #include <vector> 8 #include <vector>
  9 +#include <string>
9 10
10 #include "onnxruntime_cxx_api.h" // NOLINT 11 #include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
11 13
12 namespace sherpa_onnx { 14 namespace sherpa_onnx {
13 15
14 struct OfflineWhisperDecoderResult { 16 struct OfflineWhisperDecoderResult {
15 /// The decoded token IDs 17 /// The decoded token IDs
16 std::vector<int32_t> tokens; 18 std::vector<int32_t> tokens;
  19 + std::string lang;
17 }; 20 };
18 21
19 class OfflineWhisperDecoder { 22 class OfflineWhisperDecoder {
@@ -31,6 +34,9 @@ class OfflineWhisperDecoder { @@ -31,6 +34,9 @@ class OfflineWhisperDecoder {
31 */ 34 */
32 virtual std::vector<OfflineWhisperDecoderResult> Decode( 35 virtual std::vector<OfflineWhisperDecoderResult> Decode(
33 Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; 36 Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
  37 +
  38 + virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
  39 +
34 }; 40 };
35 41
36 } // namespace sherpa_onnx 42 } // namespace sherpa_onnx
@@ -12,6 +12,10 @@ @@ -12,6 +12,10 @@
12 12
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
  15 +void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) {
  16 + config_ = config;
  17 +}
  18 +
15 std::vector<OfflineWhisperDecoderResult> 19 std::vector<OfflineWhisperDecoderResult>
16 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, 20 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
17 Ort::Value cross_v) { 21 Ort::Value cross_v) {
@@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, @@ -129,6 +133,13 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
129 133
130 std::vector<OfflineWhisperDecoderResult> ans(1); 134 std::vector<OfflineWhisperDecoderResult> ans(1);
131 135
  136 + const auto &id2lang = model_->GetID2Lang();
  137 + if (id2lang.count(initial_tokens[1])) {
  138 + ans[0].lang = id2lang.at(initial_tokens[1]);
  139 + } else {
  140 + ans[0].lang = "";
  141 + }
  142 +
132 ans[0].tokens = std::move(predicted_tokens); 143 ans[0].tokens = std::move(predicted_tokens);
133 144
134 return ans; 145 return ans;
@@ -8,7 +8,6 @@ @@ -8,7 +8,6 @@
8 #include <vector> 8 #include <vector>
9 9
10 #include "sherpa-onnx/csrc/offline-whisper-decoder.h" 10 #include "sherpa-onnx/csrc/offline-whisper-decoder.h"
11 -#include "sherpa-onnx/csrc/offline-whisper-model-config.h"  
12 #include "sherpa-onnx/csrc/offline-whisper-model.h" 11 #include "sherpa-onnx/csrc/offline-whisper-model.h"
13 12
14 namespace sherpa_onnx { 13 namespace sherpa_onnx {
@@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { @@ -22,6 +21,8 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
22 std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k, 21 std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
23 Ort::Value cross_v) override; 22 Ort::Value cross_v) override;
24 23
  24 + void SetConfig(const OfflineWhisperModelConfig &config) override;
  25 +
25 private: 26 private:
26 OfflineWhisperModelConfig config_; 27 OfflineWhisperModelConfig config_;
27 OfflineWhisperModel *model_; // not owned 28 OfflineWhisperModel *model_; // not owned