Committed by
GitHub
Support whisper language/task in various language bindings. (#679)
正在显示
15 个修改的文件
包含
117 行增加
和
62 行删除
| @@ -40,6 +40,12 @@ class OfflineDecodeFiles | @@ -40,6 +40,12 @@ class OfflineDecodeFiles | ||
| 40 | [Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")] | 40 | [Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")] |
| 41 | public string WhisperDecoder { get; set; } | 41 | public string WhisperDecoder { get; set; } |
| 42 | 42 | ||
| 43 | + [Option("whisper-language", Required = false, Default = "", HelpText = "Language of the input file. Can be empty")] | ||
| 44 | + public string WhisperLanguage{ get; set; } | ||
| 45 | + | ||
| 46 | + [Option("whisper-task", Required = false, Default = "transcribe", HelpText = "transcribe or translate")] | ||
| 47 | + public string WhisperTask{ get; set; } | ||
| 48 | + | ||
| 43 | [Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")] | 49 | [Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")] |
| 44 | public string TdnnModel { get; set; } | 50 | public string TdnnModel { get; set; } |
| 45 | 51 | ||
| @@ -193,6 +199,8 @@ to download pre-trained Tdnn models. | @@ -193,6 +199,8 @@ to download pre-trained Tdnn models. | ||
| 193 | { | 199 | { |
| 194 | config.ModelConfig.Whisper.Encoder = options.WhisperEncoder; | 200 | config.ModelConfig.Whisper.Encoder = options.WhisperEncoder; |
| 195 | config.ModelConfig.Whisper.Decoder = options.WhisperDecoder; | 201 | config.ModelConfig.Whisper.Decoder = options.WhisperDecoder; |
| 202 | + config.ModelConfig.Whisper.Language = options.WhisperLanguage; | ||
| 203 | + config.ModelConfig.Whisper.Task = options.WhisperTask; | ||
| 196 | } | 204 | } |
| 197 | else if (!String.IsNullOrEmpty(options.TdnnModel)) | 205 | else if (!String.IsNullOrEmpty(options.TdnnModel)) |
| 198 | { | 206 | { |
| @@ -29,6 +29,8 @@ func main() { | @@ -29,6 +29,8 @@ func main() { | ||
| 29 | 29 | ||
| 30 | flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model") | 30 | flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model") |
| 31 | flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model") | 31 | flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model") |
| 32 | + flag.StringVar(&config.ModelConfig.Whisper.Language, "whisper-language", "", "Language of the input wave. You can leave it empty ") | ||
| 33 | + flag.StringVar(&config.ModelConfig.Whisper.Task, "whisper-task", "transcribe", "transcribe or translate") | ||
| 32 | 34 | ||
| 33 | flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model") | 35 | flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model") |
| 34 | 36 |
| @@ -27,6 +27,8 @@ function createOfflineRecognizer() { | @@ -27,6 +27,8 @@ function createOfflineRecognizer() { | ||
| 27 | whisper: { | 27 | whisper: { |
| 28 | encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx', | 28 | encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx', |
| 29 | decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', | 29 | decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx', |
| 30 | + language: '', | ||
| 31 | + task: 'transcribe', | ||
| 30 | }, | 32 | }, |
| 31 | tdnn: { | 33 | tdnn: { |
| 32 | model: '', | 34 | model: '', |
| @@ -279,12 +279,20 @@ namespace SherpaOnnx | @@ -279,12 +279,20 @@ namespace SherpaOnnx | ||
| 279 | { | 279 | { |
| 280 | Encoder = ""; | 280 | Encoder = ""; |
| 281 | Decoder = ""; | 281 | Decoder = ""; |
| 282 | + Language = ""; | ||
| 283 | + Task = "transcribe"; | ||
| 282 | } | 284 | } |
| 283 | [MarshalAs(UnmanagedType.LPStr)] | 285 | [MarshalAs(UnmanagedType.LPStr)] |
| 284 | public string Encoder; | 286 | public string Encoder; |
| 285 | 287 | ||
| 286 | [MarshalAs(UnmanagedType.LPStr)] | 288 | [MarshalAs(UnmanagedType.LPStr)] |
| 287 | public string Decoder; | 289 | public string Decoder; |
| 290 | + | ||
| 291 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 292 | + public string Language; | ||
| 293 | + | ||
| 294 | + [MarshalAs(UnmanagedType.LPStr)] | ||
| 295 | + public string Task; | ||
| 288 | } | 296 | } |
| 289 | 297 | ||
| 290 | [StructLayout(LayoutKind.Sequential)] | 298 | [StructLayout(LayoutKind.Sequential)] |
| @@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct { | @@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct { | ||
| 326 | } | 326 | } |
| 327 | 327 | ||
| 328 | type OfflineWhisperModelConfig struct { | 328 | type OfflineWhisperModelConfig struct { |
| 329 | - Encoder string | ||
| 330 | - Decoder string | 329 | + Encoder string |
| 330 | + Decoder string | ||
| 331 | + Language string | ||
| 332 | + Task string | ||
| 331 | } | 333 | } |
| 332 | 334 | ||
| 333 | type OfflineTdnnModelConfig struct { | 335 | type OfflineTdnnModelConfig struct { |
| @@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { | @@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { | ||
| 423 | c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) | 425 | c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) |
| 424 | defer C.free(unsafe.Pointer(c.model_config.whisper.decoder)) | 426 | defer C.free(unsafe.Pointer(c.model_config.whisper.decoder)) |
| 425 | 427 | ||
| 428 | + c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language) | ||
| 429 | + defer C.free(unsafe.Pointer(c.model_config.whisper.language)) | ||
| 430 | + | ||
| 431 | + c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) | ||
| 432 | + defer C.free(unsafe.Pointer(c.model_config.whisper.task)) | ||
| 433 | + | ||
| 426 | c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) | 434 | c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) |
| 427 | defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) | 435 | defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) |
| 428 | 436 |
| @@ -11,13 +11,13 @@ | @@ -11,13 +11,13 @@ | ||
| 11 | 11 | ||
| 12 | #include "sherpa-onnx/csrc/circular-buffer.h" | 12 | #include "sherpa-onnx/csrc/circular-buffer.h" |
| 13 | #include "sherpa-onnx/csrc/display.h" | 13 | #include "sherpa-onnx/csrc/display.h" |
| 14 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 14 | #include "sherpa-onnx/csrc/macros.h" | 15 | #include "sherpa-onnx/csrc/macros.h" |
| 15 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 16 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 16 | #include "sherpa-onnx/csrc/offline-tts.h" | 17 | #include "sherpa-onnx/csrc/offline-tts.h" |
| 17 | #include "sherpa-onnx/csrc/online-recognizer.h" | 18 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 18 | #include "sherpa-onnx/csrc/voice-activity-detector.h" | 19 | #include "sherpa-onnx/csrc/voice-activity-detector.h" |
| 19 | #include "sherpa-onnx/csrc/wave-writer.h" | 20 | #include "sherpa-onnx/csrc/wave-writer.h" |
| 20 | -#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 21 | 21 | ||
| 22 | struct SherpaOnnxOnlineRecognizer { | 22 | struct SherpaOnnxOnlineRecognizer { |
| 23 | std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl; | 23 | std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl; |
| @@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | @@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( | ||
| 301 | recognizer_config.model_config.whisper.language = | 301 | recognizer_config.model_config.whisper.language = |
| 302 | SHERPA_ONNX_OR(config->model_config.whisper.language, ""); | 302 | SHERPA_ONNX_OR(config->model_config.whisper.language, ""); |
| 303 | 303 | ||
| 304 | + recognizer_config.model_config.whisper.task = | ||
| 305 | + SHERPA_ONNX_OR(config->model_config.whisper.task, "transcribe"); | ||
| 306 | + | ||
| 304 | recognizer_config.model_config.tdnn.model = | 307 | recognizer_config.model_config.tdnn.model = |
| 305 | SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); | 308 | SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); |
| 306 | 309 | ||
| @@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter { | @@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter { | ||
| 422 | std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; | 425 | std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; |
| 423 | }; | 426 | }; |
| 424 | 427 | ||
| 425 | -SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | ||
| 426 | - const SherpaOnnxKeywordSpotterConfig* config) { | 428 | +SherpaOnnxKeywordSpotter *CreateKeywordSpotter( |
| 429 | + const SherpaOnnxKeywordSpotterConfig *config) { | ||
| 427 | sherpa_onnx::KeywordSpotterConfig spotter_config; | 430 | sherpa_onnx::KeywordSpotterConfig spotter_config; |
| 428 | 431 | ||
| 429 | spotter_config.feat_config.sampling_rate = | 432 | spotter_config.feat_config.sampling_rate = |
| @@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | @@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | ||
| 457 | spotter_config.model_config.debug = | 460 | spotter_config.model_config.debug = |
| 458 | SHERPA_ONNX_OR(config->model_config.debug, 0); | 461 | SHERPA_ONNX_OR(config->model_config.debug, 0); |
| 459 | 462 | ||
| 460 | - spotter_config.max_active_paths = | ||
| 461 | - SHERPA_ONNX_OR(config->max_active_paths, 4); | 463 | + spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); |
| 462 | 464 | ||
| 463 | spotter_config.num_trailing_blanks = | 465 | spotter_config.num_trailing_blanks = |
| 464 | - SHERPA_ONNX_OR(config->num_trailing_blanks , 1); | 466 | + SHERPA_ONNX_OR(config->num_trailing_blanks, 1); |
| 465 | 467 | ||
| 466 | - spotter_config.keywords_score = | ||
| 467 | - SHERPA_ONNX_OR(config->keywords_score, 1.0); | 468 | + spotter_config.keywords_score = SHERPA_ONNX_OR(config->keywords_score, 1.0); |
| 468 | 469 | ||
| 469 | spotter_config.keywords_threshold = | 470 | spotter_config.keywords_threshold = |
| 470 | SHERPA_ONNX_OR(config->keywords_threshold, 0.25); | 471 | SHERPA_ONNX_OR(config->keywords_threshold, 0.25); |
| 471 | 472 | ||
| 472 | - spotter_config.keywords_file = | ||
| 473 | - SHERPA_ONNX_OR(config->keywords_file, ""); | 473 | + spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); |
| 474 | 474 | ||
| 475 | if (config->model_config.debug) { | 475 | if (config->model_config.debug) { |
| 476 | SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); | 476 | SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); |
| @@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | @@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | ||
| 481 | return nullptr; | 481 | return nullptr; |
| 482 | } | 482 | } |
| 483 | 483 | ||
| 484 | - SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; | 484 | + SherpaOnnxKeywordSpotter *spotter = new SherpaOnnxKeywordSpotter; |
| 485 | 485 | ||
| 486 | - spotter->impl = | ||
| 487 | - std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config); | 486 | + spotter->impl = std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config); |
| 488 | 487 | ||
| 489 | return spotter; | 488 | return spotter; |
| 490 | } | 489 | } |
| 491 | 490 | ||
| 492 | -void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { | 491 | +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) { |
| 493 | delete spotter; | 492 | delete spotter; |
| 494 | } | 493 | } |
| 495 | 494 | ||
| 496 | -SherpaOnnxOnlineStream* CreateKeywordStream( | ||
| 497 | - const SherpaOnnxKeywordSpotter* spotter) { | ||
| 498 | - SherpaOnnxOnlineStream* stream = | 495 | +SherpaOnnxOnlineStream *CreateKeywordStream( |
| 496 | + const SherpaOnnxKeywordSpotter *spotter) { | ||
| 497 | + SherpaOnnxOnlineStream *stream = | ||
| 499 | new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); | 498 | new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); |
| 500 | return stream; | 499 | return stream; |
| 501 | } | 500 | } |
| 502 | 501 | ||
| 503 | -int32_t IsKeywordStreamReady( | ||
| 504 | - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { | 502 | +int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter, |
| 503 | + SherpaOnnxOnlineStream *stream) { | ||
| 505 | return spotter->impl->IsReady(stream->impl.get()); | 504 | return spotter->impl->IsReady(stream->impl.get()); |
| 506 | } | 505 | } |
| 507 | 506 | ||
| 508 | -void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, | ||
| 509 | - SherpaOnnxOnlineStream* stream) { | 507 | +void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter, |
| 508 | + SherpaOnnxOnlineStream *stream) { | ||
| 510 | return spotter->impl->DecodeStream(stream->impl.get()); | 509 | return spotter->impl->DecodeStream(stream->impl.get()); |
| 511 | } | 510 | } |
| 512 | 511 | ||
| 513 | -void DecodeMultipleKeywordStreams( | ||
| 514 | - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, | ||
| 515 | - int32_t n) { | ||
| 516 | - std::vector<sherpa_onnx::OnlineStream*> ss(n); | 512 | +void DecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter, |
| 513 | + SherpaOnnxOnlineStream **streams, int32_t n) { | ||
| 514 | + std::vector<sherpa_onnx::OnlineStream *> ss(n); | ||
| 517 | for (int32_t i = 0; i != n; ++i) { | 515 | for (int32_t i = 0; i != n; ++i) { |
| 518 | ss[i] = streams[i]->impl.get(); | 516 | ss[i] = streams[i]->impl.get(); |
| 519 | } | 517 | } |
| @@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams( | @@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams( | ||
| 522 | 520 | ||
| 523 | const SherpaOnnxKeywordResult *GetKeywordResult( | 521 | const SherpaOnnxKeywordResult *GetKeywordResult( |
| 524 | SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { | 522 | SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { |
| 525 | - const sherpa_onnx::KeywordResult& result = | 523 | + const sherpa_onnx::KeywordResult &result = |
| 526 | spotter->impl->GetResult(stream->impl.get()); | 524 | spotter->impl->GetResult(stream->impl.get()); |
| 527 | const auto &keyword = result.keyword; | 525 | const auto &keyword = result.keyword; |
| 528 | 526 |
| @@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { | @@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig { | ||
| 333 | const char *encoder; | 333 | const char *encoder; |
| 334 | const char *decoder; | 334 | const char *decoder; |
| 335 | const char *language; | 335 | const char *language; |
| 336 | + const char *task; | ||
| 336 | } SherpaOnnxOfflineWhisperModelConfig; | 337 | } SherpaOnnxOfflineWhisperModelConfig; |
| 337 | 338 | ||
| 338 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { | 339 | SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig { |
| @@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { | @@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { | ||
| 483 | /// For Chinese, it consists of Chinese words without spaces. | 484 | /// For Chinese, it consists of Chinese words without spaces. |
| 484 | /// Example 1: "hello world" | 485 | /// Example 1: "hello world" |
| 485 | /// Example 2: "你好世界" | 486 | /// Example 2: "你好世界" |
| 486 | - const char* keyword; | 487 | + const char *keyword; |
| 487 | 488 | ||
| 488 | /// Decoded results at the token level. | 489 | /// Decoded results at the token level. |
| 489 | /// For instance, for BPE-based models it consists of a list of BPE tokens. | 490 | /// For instance, for BPE-based models it consists of a list of BPE tokens. |
| 490 | - const char* tokens; | 491 | + const char *tokens; |
| 491 | 492 | ||
| 492 | - const char* const* tokens_arr; | 493 | + const char *const *tokens_arr; |
| 493 | 494 | ||
| 494 | int32_t count; | 495 | int32_t count; |
| 495 | 496 | ||
| 496 | /// timestamps.size() == tokens.size() | 497 | /// timestamps.size() == tokens.size() |
| 497 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. | 498 | /// timestamps[i] records the time in seconds when tokens[i] is decoded. |
| 498 | - float* timestamps; | 499 | + float *timestamps; |
| 499 | 500 | ||
| 500 | /// Starting time of this segment. | 501 | /// Starting time of this segment. |
| 501 | /// When an endpoint is detected, it will change | 502 | /// When an endpoint is detected, it will change |
| @@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { | @@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { | ||
| 511 | * "start_time": x, | 512 | * "start_time": x, |
| 512 | * } | 513 | * } |
| 513 | */ | 514 | */ |
| 514 | - const char* json; | 515 | + const char *json; |
| 515 | } SherpaOnnxKeywordResult; | 516 | } SherpaOnnxKeywordResult; |
| 516 | 517 | ||
| 517 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { | 518 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { |
| @@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { | @@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { | ||
| 521 | int32_t num_trailing_blanks; | 522 | int32_t num_trailing_blanks; |
| 522 | float keywords_score; | 523 | float keywords_score; |
| 523 | float keywords_threshold; | 524 | float keywords_threshold; |
| 524 | - const char* keywords_file; | 525 | + const char *keywords_file; |
| 525 | } SherpaOnnxKeywordSpotterConfig; | 526 | } SherpaOnnxKeywordSpotterConfig; |
| 526 | 527 | ||
| 527 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | 528 | SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter |
| @@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | @@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | ||
| 530 | /// @param config Config for the keyword spotter. | 531 | /// @param config Config for the keyword spotter. |
| 531 | /// @return Return a pointer to the spotter. The user has to invoke | 532 | /// @return Return a pointer to the spotter. The user has to invoke |
| 532 | /// DestroyKeywordSpotter() to free it to avoid memory leak. | 533 | /// DestroyKeywordSpotter() to free it to avoid memory leak. |
| 533 | -SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | ||
| 534 | - const SherpaOnnxKeywordSpotterConfig* config); | 534 | +SHERPA_ONNX_API SherpaOnnxKeywordSpotter *CreateKeywordSpotter( |
| 535 | + const SherpaOnnxKeywordSpotterConfig *config); | ||
| 535 | 536 | ||
| 536 | /// Free a pointer returned by CreateKeywordSpotter() | 537 | /// Free a pointer returned by CreateKeywordSpotter() |
| 537 | /// | 538 | /// |
| 538 | /// @param p A pointer returned by CreateKeywordSpotter() | 539 | /// @param p A pointer returned by CreateKeywordSpotter() |
| 539 | -SHERPA_ONNX_API void DestroyKeywordSpotter( | ||
| 540 | - SherpaOnnxKeywordSpotter* spotter); | 540 | +SHERPA_ONNX_API void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter); |
| 541 | 541 | ||
| 542 | /// Create an online stream for accepting wave samples. | 542 | /// Create an online stream for accepting wave samples. |
| 543 | /// | 543 | /// |
| 544 | /// @param spotter A pointer returned by CreateKeywordSpotter() | 544 | /// @param spotter A pointer returned by CreateKeywordSpotter() |
| 545 | /// @return Return a pointer to an OnlineStream. The user has to invoke | 545 | /// @return Return a pointer to an OnlineStream. The user has to invoke |
| 546 | /// DestroyOnlineStream() to free it to avoid memory leak. | 546 | /// DestroyOnlineStream() to free it to avoid memory leak. |
| 547 | -SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream( | ||
| 548 | - const SherpaOnnxKeywordSpotter* spotter); | 547 | +SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateKeywordStream( |
| 548 | + const SherpaOnnxKeywordSpotter *spotter); | ||
| 549 | 549 | ||
| 550 | /// Return 1 if there are enough number of feature frames for decoding. | 550 | /// Return 1 if there are enough number of feature frames for decoding. |
| 551 | /// Return 0 otherwise. | 551 | /// Return 0 otherwise. |
| 552 | /// | 552 | /// |
| 553 | /// @param spotter A pointer returned by CreateKeywordSpotter | 553 | /// @param spotter A pointer returned by CreateKeywordSpotter |
| 554 | /// @param stream A pointer returned by CreateKeywordStream | 554 | /// @param stream A pointer returned by CreateKeywordStream |
| 555 | -SHERPA_ONNX_API int32_t IsKeywordStreamReady( | ||
| 556 | - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream); | 555 | +SHERPA_ONNX_API int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter, |
| 556 | + SherpaOnnxOnlineStream *stream); | ||
| 557 | 557 | ||
| 558 | /// Call this function to run the neural network model and decoding. | 558 | /// Call this function to run the neural network model and decoding. |
| 559 | // | 559 | // |
| 560 | /// Precondition for this function: IsKeywordStreamReady() MUST return 1. | 560 | /// Precondition for this function: IsKeywordStreamReady() MUST return 1. |
| 561 | -SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, | ||
| 562 | - SherpaOnnxOnlineStream* stream); | 561 | +SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter, |
| 562 | + SherpaOnnxOnlineStream *stream); | ||
| 563 | 563 | ||
| 564 | /// This function is similar to DecodeKeywordStream(). It decodes multiple | 564 | /// This function is similar to DecodeKeywordStream(). It decodes multiple |
| 565 | /// OnlineStream in parallel. | 565 | /// OnlineStream in parallel. |
| @@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( | @@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( | ||
| 588 | /// Destroy the pointer returned by GetKeywordResult(). | 588 | /// Destroy the pointer returned by GetKeywordResult(). |
| 589 | /// | 589 | /// |
| 590 | /// @param r A pointer returned by GetKeywordResult() | 590 | /// @param r A pointer returned by GetKeywordResult() |
| 591 | -SHERPA_ONNX_API void DestroyKeywordResult( | ||
| 592 | - const SherpaOnnxKeywordResult *r); | 591 | +SHERPA_ONNX_API void DestroyKeywordResult(const SherpaOnnxKeywordResult *r); |
| 593 | 592 | ||
| 594 | // ============================================================ | 593 | // ============================================================ |
| 595 | // For VAD | 594 | // For VAD |
| @@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl { | @@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl { | ||
| 223 | inputs.push_back(std::move(length_scale_tensor)); | 223 | inputs.push_back(std::move(length_scale_tensor)); |
| 224 | inputs.push_back(std::move(noise_scale_w_tensor)); | 224 | inputs.push_back(std::move(noise_scale_w_tensor)); |
| 225 | 225 | ||
| 226 | - if (input_names_.size() == 6 && input_names_.back() == "sid") { | 226 | + if (input_names_.size() == 6 && |
| 227 | + (input_names_.back() == "sid" || input_names_.back() == "speaker")) { | ||
| 227 | inputs.push_back(std::move(sid_tensor)); | 228 | inputs.push_back(std::move(sid_tensor)); |
| 228 | } | 229 | } |
| 229 | 230 |
| @@ -2,14 +2,16 @@ | @@ -2,14 +2,16 @@ | ||
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2023-2024 Xiaomi Corporation | 3 | // Copyright (c) 2023-2024 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" | ||
| 6 | + | ||
| 5 | #include <algorithm> | 7 | #include <algorithm> |
| 6 | #include <cmath> | 8 | #include <cmath> |
| 9 | +#include <cstring> | ||
| 7 | #include <utility> | 10 | #include <utility> |
| 8 | #include <vector> | 11 | #include <vector> |
| 9 | 12 | ||
| 10 | #include "sherpa-onnx/csrc/log.h" | 13 | #include "sherpa-onnx/csrc/log.h" |
| 11 | #include "sherpa-onnx/csrc/onnx-utils.h" | 14 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 12 | -#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" | ||
| 13 | 15 | ||
| 14 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 15 | 17 |
| @@ -242,17 +242,17 @@ class SherpaOnnxRecognizer { | @@ -242,17 +242,17 @@ class SherpaOnnxRecognizer { | ||
| 242 | /// the given hotWords appended to the default hotwords. | 242 | /// the given hotWords appended to the default hotwords. |
| 243 | func reset(hotwords: String? = nil) { | 243 | func reset(hotwords: String? = nil) { |
| 244 | guard let words = hotwords, !words.isEmpty else { | 244 | guard let words = hotwords, !words.isEmpty else { |
| 245 | - Reset(recognizer, stream) | ||
| 246 | - return | 245 | + Reset(recognizer, stream) |
| 246 | + return | ||
| 247 | } | 247 | } |
| 248 | - | 248 | + |
| 249 | words.withCString { cString in | 249 | words.withCString { cString in |
| 250 | - let newStream = CreateOnlineStreamWithHotwords(recognizer, cString) | ||
| 251 | - // lock while release and replace stream | ||
| 252 | - objc_sync_enter(self) | ||
| 253 | - DestroyOnlineStream(stream) | ||
| 254 | - stream = newStream | ||
| 255 | - objc_sync_exit(self) | 250 | + let newStream = CreateOnlineStreamWithHotwords(recognizer, cString) |
| 251 | + // lock while release and replace stream | ||
| 252 | + objc_sync_enter(self) | ||
| 253 | + DestroyOnlineStream(stream) | ||
| 254 | + stream = newStream | ||
| 255 | + objc_sync_exit(self) | ||
| 256 | } | 256 | } |
| 257 | } | 257 | } |
| 258 | 258 | ||
| @@ -300,11 +300,15 @@ func sherpaOnnxOfflineNemoEncDecCtcModelConfig( | @@ -300,11 +300,15 @@ func sherpaOnnxOfflineNemoEncDecCtcModelConfig( | ||
| 300 | 300 | ||
| 301 | func sherpaOnnxOfflineWhisperModelConfig( | 301 | func sherpaOnnxOfflineWhisperModelConfig( |
| 302 | encoder: String = "", | 302 | encoder: String = "", |
| 303 | - decoder: String = "" | 303 | + decoder: String = "", |
| 304 | + language: String = "", | ||
| 305 | + task: String = "transcribe" | ||
| 304 | ) -> SherpaOnnxOfflineWhisperModelConfig { | 306 | ) -> SherpaOnnxOfflineWhisperModelConfig { |
| 305 | return SherpaOnnxOfflineWhisperModelConfig( | 307 | return SherpaOnnxOfflineWhisperModelConfig( |
| 306 | encoder: toCPointer(encoder), | 308 | encoder: toCPointer(encoder), |
| 307 | - decoder: toCPointer(decoder) | 309 | + decoder: toCPointer(decoder), |
| 310 | + language: toCPointer(language), | ||
| 311 | + task: toCPointer(task) | ||
| 308 | ) | 312 | ) |
| 309 | } | 313 | } |
| 310 | 314 |
| @@ -393,11 +393,13 @@ function initSherpaOnnxOfflineNemoEncDecCtcModelConfig(config, Module) { | @@ -393,11 +393,13 @@ function initSherpaOnnxOfflineNemoEncDecCtcModelConfig(config, Module) { | ||
| 393 | function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { | 393 | function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { |
| 394 | const encoderLen = Module.lengthBytesUTF8(config.encoder) + 1; | 394 | const encoderLen = Module.lengthBytesUTF8(config.encoder) + 1; |
| 395 | const decoderLen = Module.lengthBytesUTF8(config.decoder) + 1; | 395 | const decoderLen = Module.lengthBytesUTF8(config.decoder) + 1; |
| 396 | + const languageLen = Module.lengthBytesUTF8(config.language) + 1; | ||
| 397 | + const taskLen = Module.lengthBytesUTF8(config.task) + 1; | ||
| 396 | 398 | ||
| 397 | - const n = encoderLen + decoderLen; | 399 | + const n = encoderLen + decoderLen + languageLen + taskLen; |
| 398 | const buffer = Module._malloc(n); | 400 | const buffer = Module._malloc(n); |
| 399 | 401 | ||
| 400 | - const len = 2 * 4; // 2 pointers | 402 | + const len = 4 * 4; // 4 pointers |
| 401 | const ptr = Module._malloc(len); | 403 | const ptr = Module._malloc(len); |
| 402 | 404 | ||
| 403 | let offset = 0; | 405 | let offset = 0; |
| @@ -405,12 +407,25 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { | @@ -405,12 +407,25 @@ function initSherpaOnnxOfflineWhisperModelConfig(config, Module) { | ||
| 405 | offset += encoderLen; | 407 | offset += encoderLen; |
| 406 | 408 | ||
| 407 | Module.stringToUTF8(config.decoder, buffer + offset, decoderLen); | 409 | Module.stringToUTF8(config.decoder, buffer + offset, decoderLen); |
| 410 | + offset += decoderLen; | ||
| 411 | + | ||
| 412 | + Module.stringToUTF8(config.language, buffer + offset, languageLen); | ||
| 413 | + offset += languageLen; | ||
| 414 | + | ||
| 415 | + Module.stringToUTF8(config.task, buffer + offset, taskLen); | ||
| 408 | 416 | ||
| 409 | offset = 0; | 417 | offset = 0; |
| 410 | Module.setValue(ptr, buffer + offset, 'i8*'); | 418 | Module.setValue(ptr, buffer + offset, 'i8*'); |
| 411 | offset += encoderLen; | 419 | offset += encoderLen; |
| 412 | 420 | ||
| 413 | Module.setValue(ptr + 4, buffer + offset, 'i8*'); | 421 | Module.setValue(ptr + 4, buffer + offset, 'i8*'); |
| 422 | + offset += decoderLen; | ||
| 423 | + | ||
| 424 | + Module.setValue(ptr + 8, buffer + offset, 'i8*'); | ||
| 425 | + offset += languageLen; | ||
| 426 | + | ||
| 427 | + Module.setValue(ptr + 12, buffer + offset, 'i8*'); | ||
| 428 | + offset += taskLen; | ||
| 414 | 429 | ||
| 415 | return { | 430 | return { |
| 416 | buffer: buffer, ptr: ptr, len: len, | 431 | buffer: buffer, ptr: ptr, len: len, |
| @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); | @@ -14,7 +14,7 @@ static_assert(sizeof(SherpaOnnxOfflineTransducerModelConfig) == 3 * 4, ""); | ||
| 14 | static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); | 14 | static_assert(sizeof(SherpaOnnxOfflineParaformerModelConfig) == 4, ""); |
| 15 | 15 | ||
| 16 | static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); | 16 | static_assert(sizeof(SherpaOnnxOfflineNemoEncDecCtcModelConfig) == 4, ""); |
| 17 | -static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 2 * 4, ""); | 17 | +static_assert(sizeof(SherpaOnnxOfflineWhisperModelConfig) == 4 * 4, ""); |
| 18 | static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); | 18 | static_assert(sizeof(SherpaOnnxOfflineTdnnModelConfig) == 4, ""); |
| 19 | static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); | 19 | static_assert(sizeof(SherpaOnnxOfflineLMConfig) == 2 * 4, ""); |
| 20 | 20 | ||
| @@ -77,6 +77,8 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { | @@ -77,6 +77,8 @@ void PrintOfflineRecognizerConfig(SherpaOnnxOfflineRecognizerConfig *config) { | ||
| 77 | fprintf(stdout, "----------offline whisper model config----------\n"); | 77 | fprintf(stdout, "----------offline whisper model config----------\n"); |
| 78 | fprintf(stdout, "encoder: %s\n", whisper->encoder); | 78 | fprintf(stdout, "encoder: %s\n", whisper->encoder); |
| 79 | fprintf(stdout, "decoder: %s\n", whisper->decoder); | 79 | fprintf(stdout, "decoder: %s\n", whisper->decoder); |
| 80 | + fprintf(stdout, "language: %s\n", whisper->language); | ||
| 81 | + fprintf(stdout, "task: %s\n", whisper->task); | ||
| 80 | 82 | ||
| 81 | fprintf(stdout, "----------offline tdnn model config----------\n"); | 83 | fprintf(stdout, "----------offline tdnn model config----------\n"); |
| 82 | fprintf(stdout, "model: %s\n", tdnn->model); | 84 | fprintf(stdout, "model: %s\n", tdnn->model); |
-
请 注册 或 登录 后发表评论