Fangjun Kuang
Committed by GitHub

Support whisper language/task in various language bindings. (#679)

@@ -40,6 +40,12 @@ class OfflineDecodeFiles @@ -40,6 +40,12 @@ class OfflineDecodeFiles
40 [Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")] 40 [Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")]
41 public string WhisperDecoder { get; set; } 41 public string WhisperDecoder { get; set; }
42 42
  43 + [Option("whisper-language", Required = false, Default = "", HelpText = "Language of the input file. Can be empty")]
  44 + public string WhisperLanguage{ get; set; }
  45 +
  46 + [Option("whisper-task", Required = false, Default = "transcribe", HelpText = "transcribe or translate")]
  47 + public string WhisperTask{ get; set; }
  48 +
43 [Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")] 49 [Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")]
44 public string TdnnModel { get; set; } 50 public string TdnnModel { get; set; }
45 51
@@ -193,6 +199,8 @@ to download pre-trained Tdnn models. @@ -193,6 +199,8 @@ to download pre-trained Tdnn models.
193 { 199 {
194 config.ModelConfig.Whisper.Encoder = options.WhisperEncoder; 200 config.ModelConfig.Whisper.Encoder = options.WhisperEncoder;
195 config.ModelConfig.Whisper.Decoder = options.WhisperDecoder; 201 config.ModelConfig.Whisper.Decoder = options.WhisperDecoder;
  202 + config.ModelConfig.Whisper.Language = options.WhisperLanguage;
  203 + config.ModelConfig.Whisper.Task = options.WhisperTask;
196 } 204 }
197 else if (!String.IsNullOrEmpty(options.TdnnModel)) 205 else if (!String.IsNullOrEmpty(options.TdnnModel))
198 { 206 {
@@ -29,6 +29,8 @@ func main() { @@ -29,6 +29,8 @@ func main() {
29 29
30 flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model") 30 flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model")
31 flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model") 31 flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model")
  32 + flag.StringVar(&config.ModelConfig.Whisper.Language, "whisper-language", "", "Language of the input wave. You can leave it empty ")
  33 + flag.StringVar(&config.ModelConfig.Whisper.Task, "whisper-task", "transcribe", "transcribe or translate")
32 34
33 flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model") 35 flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model")
34 36
@@ -27,6 +27,8 @@ function createOfflineRecognizer() { @@ -27,6 +27,8 @@ function createOfflineRecognizer() {
27 whisper: { 27 whisper: {
28 encoder: '', 28 encoder: '',
29 decoder: '', 29 decoder: '',
  30 + language: '',
  31 + task: '',
30 }, 32 },
31 tdnn: { 33 tdnn: {
32 model: '', 34 model: '',
@@ -27,6 +27,8 @@ function createOfflineRecognizer() { @@ -27,6 +27,8 @@ function createOfflineRecognizer() {
27 whisper: { 27 whisper: {
28 encoder: '', 28 encoder: '',
29 decoder: '', 29 decoder: '',
  30 + language: '',
  31 + task: '',
30 }, 32 },
31 tdnn: { 33 tdnn: {
32 model: '', 34 model: '',
@@ -30,6 +30,8 @@ function createOfflineRecognizer() { @@ -30,6 +30,8 @@ function createOfflineRecognizer() {
30 whisper: { 30 whisper: {
31 encoder: '', 31 encoder: '',
32 decoder: '', 32 decoder: '',
  33 + language: '',
  34 + task: '',
33 }, 35 },
34 tdnn: { 36 tdnn: {
35 model: '', 37 model: '',
@@ -27,6 +27,8 @@ function createOfflineRecognizer() { @@ -27,6 +27,8 @@ function createOfflineRecognizer() {
27 whisper: { 27 whisper: {
28 encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx', 28 encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx',
29 decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', 29 decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
  30 + language: '',
  31 + task: 'transcribe',
30 }, 32 },
31 tdnn: { 33 tdnn: {
32 model: '', 34 model: '',
@@ -279,12 +279,20 @@ namespace SherpaOnnx @@ -279,12 +279,20 @@ namespace SherpaOnnx
279 { 279 {
280 Encoder = ""; 280 Encoder = "";
281 Decoder = ""; 281 Decoder = "";
  282 + Language = "";
  283 + Task = "transcribe";
282 } 284 }
283 [MarshalAs(UnmanagedType.LPStr)] 285 [MarshalAs(UnmanagedType.LPStr)]
284 public string Encoder; 286 public string Encoder;
285 287
286 [MarshalAs(UnmanagedType.LPStr)] 288 [MarshalAs(UnmanagedType.LPStr)]
287 public string Decoder; 289 public string Decoder;
  290 +
  291 + [MarshalAs(UnmanagedType.LPStr)]
  292 + public string Language;
  293 +
  294 + [MarshalAs(UnmanagedType.LPStr)]
  295 + public string Task;
288 } 296 }
289 297
290 [StructLayout(LayoutKind.Sequential)] 298 [StructLayout(LayoutKind.Sequential)]
@@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct { @@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct {
326 } 326 }
327 327
328 type OfflineWhisperModelConfig struct { 328 type OfflineWhisperModelConfig struct {
329 - Encoder string  
330 - Decoder string 329 + Encoder string
  330 + Decoder string
  331 + Language string
  332 + Task string
331 } 333 }
332 334
333 type OfflineTdnnModelConfig struct { 335 type OfflineTdnnModelConfig struct {
@@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { @@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
423 c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) 425 c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder)
424 defer C.free(unsafe.Pointer(c.model_config.whisper.decoder)) 426 defer C.free(unsafe.Pointer(c.model_config.whisper.decoder))
425 427
  428 + c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language)
  429 + defer C.free(unsafe.Pointer(c.model_config.whisper.language))
  430 +
  431 + c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
  432 + defer C.free(unsafe.Pointer(c.model_config.whisper.task))
  433 +
426 c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) 434 c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
427 defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) 435 defer C.free(unsafe.Pointer(c.model_config.tdnn.model))
428 436
@@ -11,13 +11,13 @@ @@ -11,13 +11,13 @@
11 11
12 #include "sherpa-onnx/csrc/circular-buffer.h" 12 #include "sherpa-onnx/csrc/circular-buffer.h"
13 #include "sherpa-onnx/csrc/display.h" 13 #include "sherpa-onnx/csrc/display.h"
  14 +#include "sherpa-onnx/csrc/keyword-spotter.h"
14 #include "sherpa-onnx/csrc/macros.h" 15 #include "sherpa-onnx/csrc/macros.h"
15 #include "sherpa-onnx/csrc/offline-recognizer.h" 16 #include "sherpa-onnx/csrc/offline-recognizer.h"
16 #include "sherpa-onnx/csrc/offline-tts.h" 17 #include "sherpa-onnx/csrc/offline-tts.h"
17 #include "sherpa-onnx/csrc/online-recognizer.h" 18 #include "sherpa-onnx/csrc/online-recognizer.h"
18 #include "sherpa-onnx/csrc/voice-activity-detector.h" 19 #include "sherpa-onnx/csrc/voice-activity-detector.h"
19 #include "sherpa-onnx/csrc/wave-writer.h" 20 #include "sherpa-onnx/csrc/wave-writer.h"
20 -#include "sherpa-onnx/csrc/keyword-spotter.h"  
21 21
22 struct SherpaOnnxOnlineRecognizer { 22 struct SherpaOnnxOnlineRecognizer {
23 std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl; 23 std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
@@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( @@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
301 recognizer_config.model_config.whisper.language = 301 recognizer_config.model_config.whisper.language =
302 SHERPA_ONNX_OR(config->model_config.whisper.language, ""); 302 SHERPA_ONNX_OR(config->model_config.whisper.language, "");
303 303
  304 + recognizer_config.model_config.whisper.task =
  305 + SHERPA_ONNX_OR(config->model_config.whisper.task, "transcribe");
  306 +
304 recognizer_config.model_config.tdnn.model = 307 recognizer_config.model_config.tdnn.model =
305 SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); 308 SHERPA_ONNX_OR(config->model_config.tdnn.model, "");
306 309
@@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter { @@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter {
422 std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; 425 std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
423 }; 426 };
424 427
425 -SherpaOnnxKeywordSpotter* CreateKeywordSpotter(  
426 - const SherpaOnnxKeywordSpotterConfig* config) { 428 +SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
  429 + const SherpaOnnxKeywordSpotterConfig *config) {
427 sherpa_onnx::KeywordSpotterConfig spotter_config; 430 sherpa_onnx::KeywordSpotterConfig spotter_config;
428 431
429 spotter_config.feat_config.sampling_rate = 432 spotter_config.feat_config.sampling_rate =
@@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( @@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
457 spotter_config.model_config.debug = 460 spotter_config.model_config.debug =
458 SHERPA_ONNX_OR(config->model_config.debug, 0); 461 SHERPA_ONNX_OR(config->model_config.debug, 0);
459 462
460 - spotter_config.max_active_paths =  
461 - SHERPA_ONNX_OR(config->max_active_paths, 4); 463 + spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);
462 464
463 spotter_config.num_trailing_blanks = 465 spotter_config.num_trailing_blanks =
464 - SHERPA_ONNX_OR(config->num_trailing_blanks , 1); 466 + SHERPA_ONNX_OR(config->num_trailing_blanks, 1);
465 467
466 - spotter_config.keywords_score =  
467 - SHERPA_ONNX_OR(config->keywords_score, 1.0); 468 + spotter_config.keywords_score = SHERPA_ONNX_OR(config->keywords_score, 1.0);
468 469
469 spotter_config.keywords_threshold = 470 spotter_config.keywords_threshold =
470 SHERPA_ONNX_OR(config->keywords_threshold, 0.25); 471 SHERPA_ONNX_OR(config->keywords_threshold, 0.25);
471 472
472 - spotter_config.keywords_file =  
473 - SHERPA_ONNX_OR(config->keywords_file, ""); 473 + spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, "");
474 474
475 if (config->model_config.debug) { 475 if (config->model_config.debug) {
476 SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); 476 SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
@@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( @@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
481 return nullptr; 481 return nullptr;
482 } 482 }
483 483
484 - SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; 484 + SherpaOnnxKeywordSpotter *spotter = new SherpaOnnxKeywordSpotter;
485 485
486 - spotter->impl =  
487 - std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config); 486 + spotter->impl = std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);
488 487
489 return spotter; 488 return spotter;
490 } 489 }
491 490
492 -void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { 491 +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) {
493 delete spotter; 492 delete spotter;
494 } 493 }
495 494
496 -SherpaOnnxOnlineStream* CreateKeywordStream(  
497 - const SherpaOnnxKeywordSpotter* spotter) {  
498 - SherpaOnnxOnlineStream* stream = 495 +SherpaOnnxOnlineStream *CreateKeywordStream(
  496 + const SherpaOnnxKeywordSpotter *spotter) {
  497 + SherpaOnnxOnlineStream *stream =
499 new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); 498 new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
500 return stream; 499 return stream;
501 } 500 }
502 501
503 -int32_t IsKeywordStreamReady(  
504 - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { 502 +int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
  503 + SherpaOnnxOnlineStream *stream) {
505 return spotter->impl->IsReady(stream->impl.get()); 504 return spotter->impl->IsReady(stream->impl.get());
506 } 505 }
507 506
508 -void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,  
509 - SherpaOnnxOnlineStream* stream) { 507 +void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
  508 + SherpaOnnxOnlineStream *stream) {
510 return spotter->impl->DecodeStream(stream->impl.get()); 509 return spotter->impl->DecodeStream(stream->impl.get());
511 } 510 }
512 511
513 -void DecodeMultipleKeywordStreams(  
514 - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,  
515 - int32_t n) {  
516 - std::vector<sherpa_onnx::OnlineStream*> ss(n); 512 +void DecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
  513 + SherpaOnnxOnlineStream **streams, int32_t n) {
  514 + std::vector<sherpa_onnx::OnlineStream *> ss(n);
517 for (int32_t i = 0; i != n; ++i) { 515 for (int32_t i = 0; i != n; ++i) {
518 ss[i] = streams[i]->impl.get(); 516 ss[i] = streams[i]->impl.get();
519 } 517 }
@@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams( @@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams(
522 520
523 const SherpaOnnxKeywordResult *GetKeywordResult( 521 const SherpaOnnxKeywordResult *GetKeywordResult(
524 SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { 522 SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
525 - const sherpa_onnx::KeywordResult& result = 523 + const sherpa_onnx::KeywordResult &result =
526 spotter->impl->GetResult(stream->impl.get()); 524 spotter->impl->GetResult(stream->impl.get());
527 const auto &keyword = result.keyword; 525 const auto &keyword = result.keyword;
528 526
@@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { @@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
333 const char *encoder; 333 const char *encoder;
334 const char *decoder; 334 const char *decoder;
335 const char *language; 335 const char *language;
  336 + const char *task;
336 } SherpaOnnxOfflineWhisperModelConfig; 337 } SherpaOnnxOfflineWhisperModelConfig;
337 338
338 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { 339 SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
@@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { @@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
483 /// For Chinese, it consists of Chinese words without spaces. 484 /// For Chinese, it consists of Chinese words without spaces.
484 /// Example 1: "hello world" 485 /// Example 1: "hello world"
485 /// Example 2: "你好世界" 486 /// Example 2: "你好世界"
486 - const char* keyword; 487 + const char *keyword;
487 488
488 /// Decoded results at the token level. 489 /// Decoded results at the token level.
489 /// For instance, for BPE-based models it consists of a list of BPE tokens. 490 /// For instance, for BPE-based models it consists of a list of BPE tokens.
490 - const char* tokens; 491 + const char *tokens;
491 492
492 - const char* const* tokens_arr; 493 + const char *const *tokens_arr;
493 494
494 int32_t count; 495 int32_t count;
495 496
496 /// timestamps.size() == tokens.size() 497 /// timestamps.size() == tokens.size()
497 /// timestamps[i] records the time in seconds when tokens[i] is decoded. 498 /// timestamps[i] records the time in seconds when tokens[i] is decoded.
498 - float* timestamps; 499 + float *timestamps;
499 500
500 /// Starting time of this segment. 501 /// Starting time of this segment.
501 /// When an endpoint is detected, it will change 502 /// When an endpoint is detected, it will change
@@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { @@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
511 * "start_time": x, 512 * "start_time": x,
512 * } 513 * }
513 */ 514 */
514 - const char* json; 515 + const char *json;
515 } SherpaOnnxKeywordResult; 516 } SherpaOnnxKeywordResult;
516 517
517 SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { 518 SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
@@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { @@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
521 int32_t num_trailing_blanks; 522 int32_t num_trailing_blanks;
522 float keywords_score; 523 float keywords_score;
523 float keywords_threshold; 524 float keywords_threshold;
524 - const char* keywords_file; 525 + const char *keywords_file;
525 } SherpaOnnxKeywordSpotterConfig; 526 } SherpaOnnxKeywordSpotterConfig;
526 527
527 SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter 528 SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
@@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter @@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
530 /// @param config Config for the keyword spotter. 531 /// @param config Config for the keyword spotter.
531 /// @return Return a pointer to the spotter. The user has to invoke 532 /// @return Return a pointer to the spotter. The user has to invoke
532 /// DestroyKeywordSpotter() to free it to avoid memory leak. 533 /// DestroyKeywordSpotter() to free it to avoid memory leak.
533 -SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter(  
534 - const SherpaOnnxKeywordSpotterConfig* config); 534 +SHERPA_ONNX_API SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
  535 + const SherpaOnnxKeywordSpotterConfig *config);
535 536
536 /// Free a pointer returned by CreateKeywordSpotter() 537 /// Free a pointer returned by CreateKeywordSpotter()
537 /// 538 ///
538 /// @param p A pointer returned by CreateKeywordSpotter() 539 /// @param p A pointer returned by CreateKeywordSpotter()
539 -SHERPA_ONNX_API void DestroyKeywordSpotter(  
540 - SherpaOnnxKeywordSpotter* spotter); 540 +SHERPA_ONNX_API void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter);
541 541
542 /// Create an online stream for accepting wave samples. 542 /// Create an online stream for accepting wave samples.
543 /// 543 ///
544 /// @param spotter A pointer returned by CreateKeywordSpotter() 544 /// @param spotter A pointer returned by CreateKeywordSpotter()
545 /// @return Return a pointer to an OnlineStream. The user has to invoke 545 /// @return Return a pointer to an OnlineStream. The user has to invoke
546 /// DestroyOnlineStream() to free it to avoid memory leak. 546 /// DestroyOnlineStream() to free it to avoid memory leak.
547 -SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream(  
548 - const SherpaOnnxKeywordSpotter* spotter); 547 +SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateKeywordStream(
  548 + const SherpaOnnxKeywordSpotter *spotter);
549 549
550 /// Return 1 if there are enough number of feature frames for decoding. 550 /// Return 1 if there are enough number of feature frames for decoding.
551 /// Return 0 otherwise. 551 /// Return 0 otherwise.
552 /// 552 ///
553 /// @param spotter A pointer returned by CreateKeywordSpotter 553 /// @param spotter A pointer returned by CreateKeywordSpotter
554 /// @param stream A pointer returned by CreateKeywordStream 554 /// @param stream A pointer returned by CreateKeywordStream
555 -SHERPA_ONNX_API int32_t IsKeywordStreamReady(  
556 - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream); 555 +SHERPA_ONNX_API int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
  556 + SherpaOnnxOnlineStream *stream);
557 557
558 /// Call this function to run the neural network model and decoding. 558 /// Call this function to run the neural network model and decoding.
559 // 559 //
560 /// Precondition for this function: IsKeywordStreamReady() MUST return 1. 560 /// Precondition for this function: IsKeywordStreamReady() MUST return 1.
561 -SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,  
562 - SherpaOnnxOnlineStream* stream); 561 +SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
  562 + SherpaOnnxOnlineStream *stream);
563 563
564 /// This function is similar to DecodeKeywordStream(). It decodes multiple 564 /// This function is similar to DecodeKeywordStream(). It decodes multiple
565 /// OnlineStream in parallel. 565 /// OnlineStream in parallel.
@@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( @@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult(
588 /// Destroy the pointer returned by GetKeywordResult(). 588 /// Destroy the pointer returned by GetKeywordResult().
589 /// 589 ///
590 /// @param r A pointer returned by GetKeywordResult() 590 /// @param r A pointer returned by GetKeywordResult()
591 -SHERPA_ONNX_API void DestroyKeywordResult(  
592 - const SherpaOnnxKeywordResult *r); 591 +SHERPA_ONNX_API void DestroyKeywordResult(const SherpaOnnxKeywordResult *r);
593 592
594 // ============================================================ 593 // ============================================================
595 // For VAD 594 // For VAD
@@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl { @@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl {
223 inputs.push_back(std::move(length_scale_tensor)); 223 inputs.push_back(std::move(length_scale_tensor));
224 inputs.push_back(std::move(noise_scale_w_tensor)); 224 inputs.push_back(std::move(noise_scale_w_tensor));
225 225
226 - if (input_names_.size() == 6 && input_names_.back() == "sid") { 226 + if (input_names_.size() == 6 &&
  227 + (input_names_.back() == "sid" || input_names_.back() == "speaker")) {
227 inputs.push_back(std::move(sid_tensor)); 228 inputs.push_back(std::move(sid_tensor));
228 } 229 }
229 230
@@ -2,14 +2,16 @@ @@ -2,14 +2,16 @@
2 // 2 //
3 // Copyright (c) 2023-2024 Xiaomi Corporation 3 // Copyright (c) 2023-2024 Xiaomi Corporation
4 4
  5 +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
  6 +
5 #include <algorithm> 7 #include <algorithm>
6 #include <cmath> 8 #include <cmath>
  9 +#include <cstring>
7 #include <utility> 10 #include <utility>
8 #include <vector> 11 #include <vector>
9 12
10 #include "sherpa-onnx/csrc/log.h" 13 #include "sherpa-onnx/csrc/log.h"
11 #include "sherpa-onnx/csrc/onnx-utils.h" 14 #include "sherpa-onnx/csrc/onnx-utils.h"
12 -#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"  
13 15
14 namespace sherpa_onnx { 16 namespace sherpa_onnx {
15 17
@@ -242,17 +242,17 @@ class SherpaOnnxRecognizer { @@ -242,17 +242,17 @@ class SherpaOnnxRecognizer {
242 /// the given hotWords appended to the default hotwords. 242 /// the given hotWords appended to the default hotwords.
243 func reset(hotwords: String? = nil) { 243 func reset(hotwords: String? = nil) {
244 guard let words = hotwords, !words.isEmpty else { 244 guard let words = hotwords, !words.isEmpty else {
245 - Reset(recognizer, stream)  
246 - return 245 + Reset(recognizer, stream)
  246 + return
247 } 247 }
248 - 248 +
249 words.withCString { cString in 249 words.withCString { cString in
250 - let newStream = CreateOnlineStreamWithHotwords(recognizer, cString)  
251 - // lock while release and replace stream  
252 - objc_sync_enter(self)  
253 - DestroyOnlineStream(stream)  
254 - stream = newStream  
255 - objc_sync_exit(self) 250 + let newStream = CreateOnlineStreamWithHotwords(recognizer, cString)
  251 + // lock while release and replace stream
  252 + objc_sync_enter(self)
  253 + DestroyOnlineStream(stream)
  254 + stream = newStream
  255 + objc_sync_exit(self)
256 } 256 }
257 } 257 }
258 258
@@ -300,11 +300,15 @@ func sherpaOnnxOfflineNemoEncDecCtcModelConfig( @@ -300,11 +300,15 @@ func sherpaOnnxOfflineNemoEncDecCtcModelConfig(
300 300
301 func sherpaOnnxOfflineWhisperModelConfig( 301 func sherpaOnnxOfflineWhisperModelConfig(
302 encoder: String = "", 302 encoder: String = "",
303 - decoder: String = "" 303 + decoder: String = "",
  304 + language: String = "",
  305 + task: String = "transcribe"
304 ) -> SherpaOnnxOfflineWhisperModelConfig { 306 ) -> SherpaOnnxOfflineWhisperModelConfig {
305 return SherpaOnnxOfflineWhisperModelConfig( 307 return SherpaOnnxOfflineWhisperModelConfig(
306 encoder: toCPointer(encoder), 308 encoder: toCPointer(encoder),
307 - decoder: toCPointer(decoder) 309 + decoder: toCPointer(decoder),
  310 + language: toCPointer(language),
  311 + task: toCPointer(task)
308 ) 312 )
309 } 313 }
310 314
@@ -393,11 +393,13 @@ function initSherpaOnnxOfflineNemoEncDecCtcModelConfig(config, Module) { @@ -393,11 +393,13 @@ function initSherpaOnnxOfflineNemoEncDecCtcModelConfig(config, Module) {
393 function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { 393 function initSherpaOnnxOfflineWhisperModelConfig(config, Module) {
394 const encoderLen = Module.lengthBytesUTF8(config.encoder) + 1; 394 const encoderLen = Module.lengthBytesUTF8(config.encoder) + 1;
395 const decoderLen = Module.lengthBytesUTF8(config.decoder) + 1; 395 const decoderLen = Module.lengthBytesUTF8(config.decoder) + 1;
  396 + const languageLen = Module.lengthBytesUTF8(config.language) + 1;
  397 + const taskLen = Module.lengthBytesUTF8(config.task) + 1;
396 398
397 - const n = encoderLen + decoderLen; 399 + const n = encoderLen + decoderLen + languageLen + taskLen;
398 const buffer = Module._malloc(n); 400 const buffer = Module._malloc(n);
399 401
400 - const len = 2 * 4; // 2 pointers 402 + const len = 4 * 4; // 4 pointers
401 const ptr = Module._malloc(len); 403 const ptr = Module._malloc(len);
402 404
403 let offset = 0; 405 let offset = 0;
@@ -405,12 +407,25 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { @@ -405,12 +407,25 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) {
405 offset += encoderLen; 407 offset += encoderLen;
406 408
407 Module.stringToUTF8(config.decoder, buffer + offset, decoderLen); 409 Module.stringToUTF8(config.decoder, buffer + offset, decoderLen);
  410 + offset += decoderLen;
  411 +
  412 + Module.stringToUTF8(config.language, buffer + offset, languageLen);
  413 + offset += languageLen;
  414 +
  415 + Module.stringToUTF8(config.task, buffer + offset, taskLen);
408 416
409 offset = 0; 417 offset = 0;
410 Module.setValue(ptr, buffer + offset, 'i8*'); 418 Module.setValue(ptr, buffer + offset, 'i8*');
411 offset += encoderLen; 419 offset += encoderLen;
412 420
413 Module.setValue(ptr + 4, buffer + offset, 'i8*'); 421 Module.setValue(ptr + 4, buffer + offset, 'i8*');
  422 + offset += decoderLen;
  423 +
  424 + Module.setValue(ptr + 8, buffer + offset, 'i8*');
  425 + offset += languageLen;
  426 +
  427 + Module.setValue(ptr + 12, buffer + offset, 'i8*');
  428 + offset += taskLen;
414 429
415 return { 430 return {
416 buffer: buffer, ptr: ptr, len: len, 431 buffer: buffer, ptr: ptr, len: len,
@@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, "");
14 static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); 14 static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, "");
15 15
16 static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); 16 static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, "");
17 -static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 2 * 4, ""); 17 +static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, "");
18 static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); 18 static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, "");
19 static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); 19 static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, "");
20 20
@@ -77,6 +77,8 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { @@ -77,6 +77,8 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) {
77 fprintf(stdout, "----------offline whisper model config----------\n"); 77 fprintf(stdout, "----------offline whisper model config----------\n");
78 fprintf(stdout, "encoder: %s\n", whisper->encoder); 78 fprintf(stdout, "encoder: %s\n", whisper->encoder);
79 fprintf(stdout, "decoder: %s\n", whisper->decoder); 79 fprintf(stdout, "decoder: %s\n", whisper->decoder);
  80 + fprintf(stdout, "language: %s\n", whisper->language);
  81 + fprintf(stdout, "task: %s\n", whisper->task);
80 82
81 fprintf(stdout, "----------offline tdnn model config----------\n"); 83 fprintf(stdout, "----------offline tdnn model config----------\n");
82 fprintf(stdout, "model: %s\n", tdnn->model); 84 fprintf(stdout, "model: %s\n", tdnn->model);