正在显示
2 个修改的文件
包含
301 行增加
和
0 行删除
| @@ -17,6 +17,7 @@ | @@ -17,6 +17,7 @@ | ||
| 17 | #include "sherpa-onnx/csrc/online-recognizer.h" | 17 | #include "sherpa-onnx/csrc/online-recognizer.h" |
| 18 | #include "sherpa-onnx/csrc/voice-activity-detector.h" | 18 | #include "sherpa-onnx/csrc/voice-activity-detector.h" |
| 19 | #include "sherpa-onnx/csrc/wave-writer.h" | 19 | #include "sherpa-onnx/csrc/wave-writer.h" |
| 20 | +#include "sherpa-onnx/csrc/keyword-spotter.h" | ||
| 20 | 21 | ||
| 21 | struct SherpaOnnxOnlineRecognizer { | 22 | struct SherpaOnnxOnlineRecognizer { |
| 22 | std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl; | 23 | std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl; |
| @@ -411,6 +412,189 @@ void DestroyOfflineRecognizerResult( | @@ -411,6 +412,189 @@ void DestroyOfflineRecognizerResult( | ||
| 411 | } | 412 | } |
| 412 | 413 | ||
| 413 | // ============================================================ | 414 | // ============================================================ |
| 415 | +// For Keyword Spotting | ||
| 416 | +// ============================================================ | ||
| 417 | + | ||
| 418 | +struct SherpaOnnxKeywordSpotter { | ||
| 419 | + std::unique_ptr<sherpa_onnx::KeywordSpotter> impl; | ||
| 420 | +}; | ||
| 421 | + | ||
| 422 | +SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | ||
| 423 | + const SherpaOnnxKeywordSpotterConfig* config) { | ||
| 424 | + sherpa_onnx::KeywordSpotterConfig spotter_config; | ||
| 425 | + | ||
| 426 | + spotter_config.feat_config.sampling_rate = | ||
| 427 | + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); | ||
| 428 | + spotter_config.feat_config.feature_dim = | ||
| 429 | + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); | ||
| 430 | + | ||
| 431 | + spotter_config.model_config.transducer.encoder = | ||
| 432 | + SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); | ||
| 433 | + spotter_config.model_config.transducer.decoder = | ||
| 434 | + SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); | ||
| 435 | + spotter_config.model_config.transducer.joiner = | ||
| 436 | + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); | ||
| 437 | + | ||
| 438 | + spotter_config.model_config.paraformer.encoder = | ||
| 439 | + SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); | ||
| 440 | + spotter_config.model_config.paraformer.decoder = | ||
| 441 | + SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); | ||
| 442 | + | ||
| 443 | + spotter_config.model_config.zipformer2_ctc.model = | ||
| 444 | + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); | ||
| 445 | + | ||
| 446 | + spotter_config.model_config.tokens = | ||
| 447 | + SHERPA_ONNX_OR(config->model_config.tokens, ""); | ||
| 448 | + spotter_config.model_config.num_threads = | ||
| 449 | + SHERPA_ONNX_OR(config->model_config.num_threads, 1); | ||
| 450 | + spotter_config.model_config.provider = | ||
| 451 | + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); | ||
| 452 | + spotter_config.model_config.model_type = | ||
| 453 | + SHERPA_ONNX_OR(config->model_config.model_type, ""); | ||
| 454 | + spotter_config.model_config.debug = | ||
| 455 | + SHERPA_ONNX_OR(config->model_config.debug, 0); | ||
| 456 | + | ||
| 457 | + spotter_config.max_active_paths = | ||
| 458 | + SHERPA_ONNX_OR(config->max_active_paths, 4); | ||
| 459 | + | ||
| 460 | + spotter_config.num_trailing_blanks = | ||
| 461 | + SHERPA_ONNX_OR(config->num_trailing_blanks , 1); | ||
| 462 | + | ||
| 463 | + spotter_config.keywords_score = | ||
| 464 | + SHERPA_ONNX_OR(config->keywords_score, 1.0); | ||
| 465 | + | ||
| 466 | + spotter_config.keywords_threshold = | ||
| 467 | + SHERPA_ONNX_OR(config->keywords_threshold, 0.25); | ||
| 468 | + | ||
| 469 | + spotter_config.keywords_file = | ||
| 470 | + SHERPA_ONNX_OR(config->keywords_file, ""); | ||
| 471 | + | ||
| 472 | + if (config->model_config.debug) { | ||
| 473 | + SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); | ||
| 474 | + } | ||
| 475 | + | ||
| 476 | + if (!spotter_config.Validate()) { | ||
| 477 | + SHERPA_ONNX_LOGE("Errors in config!"); | ||
| 478 | + return nullptr; | ||
| 479 | + } | ||
| 480 | + | ||
| 481 | + SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; | ||
| 482 | + | ||
| 483 | + spotter->impl = | ||
| 484 | + std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config); | ||
| 485 | + | ||
| 486 | + return spotter; | ||
| 487 | +} | ||
| 488 | + | ||
| 489 | +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { | ||
| 490 | + delete spotter; | ||
| 491 | +} | ||
| 492 | + | ||
| 493 | +SherpaOnnxOnlineStream* CreateKeywordStream( | ||
| 494 | + const SherpaOnnxKeywordSpotter* spotter) { | ||
| 495 | + SherpaOnnxOnlineStream* stream = | ||
| 496 | + new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); | ||
| 497 | + return stream; | ||
| 498 | +} | ||
| 499 | + | ||
| 500 | +int32_t IsKeywordStreamReady( | ||
| 501 | + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { | ||
| 502 | + return spotter->impl->IsReady(stream->impl.get()); | ||
| 503 | +} | ||
| 504 | + | ||
| 505 | +void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, | ||
| 506 | + SherpaOnnxOnlineStream* stream) { | ||
| 507 | + return spotter->impl->DecodeStream(stream->impl.get()); | ||
| 508 | +} | ||
| 509 | + | ||
| 510 | +void DecodeMultipleKeywordStreams( | ||
| 511 | + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, | ||
| 512 | + int32_t n) { | ||
| 513 | + std::vector<sherpa_onnx::OnlineStream*> ss(n); | ||
| 514 | + for (int32_t i = 0; i != n; ++i) { | ||
| 515 | + ss[i] = streams[i]->impl.get(); | ||
| 516 | + } | ||
| 517 | + spotter->impl->DecodeStreams(ss.data(), n); | ||
| 518 | +} | ||
| 519 | + | ||
| 520 | +const SherpaOnnxKeywordResult *GetKeywordResult( | ||
| 521 | + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { | ||
| 522 | + const sherpa_onnx::KeywordResult& result = | ||
| 523 | + spotter->impl->GetResult(stream->impl.get()); | ||
| 524 | + const auto &keyword = result.keyword; | ||
| 525 | + | ||
| 526 | + auto r = new SherpaOnnxKeywordResult; | ||
| 527 | + memset(r, 0, sizeof(SherpaOnnxKeywordResult)); | ||
| 528 | + | ||
| 529 | + r->start_time = result.start_time; | ||
| 530 | + | ||
| 531 | + // copy keyword | ||
| 532 | + r->keyword = new char[keyword.size() + 1]; | ||
| 533 | + std::copy(keyword.begin(), keyword.end(), const_cast<char *>(r->keyword)); | ||
| 534 | + const_cast<char *>(r->keyword)[keyword.size()] = 0; | ||
| 535 | + | ||
| 536 | + // copy json | ||
| 537 | + const auto &json = result.AsJsonString(); | ||
| 538 | + r->json = new char[json.size() + 1]; | ||
| 539 | + std::copy(json.begin(), json.end(), const_cast<char *>(r->json)); | ||
| 540 | + const_cast<char *>(r->json)[json.size()] = 0; | ||
| 541 | + | ||
| 542 | + // copy tokens | ||
| 543 | + auto count = result.tokens.size(); | ||
| 544 | + if (count > 0) { | ||
| 545 | + size_t total_length = 0; | ||
| 546 | + for (const auto &token : result.tokens) { | ||
| 547 | + // +1 for the null character at the end of each token | ||
| 548 | + total_length += token.size() + 1; | ||
| 549 | + } | ||
| 550 | + | ||
| 551 | + r->count = count; | ||
| 552 | + // Each word ends with nullptr | ||
| 553 | + r->tokens = new char[total_length]; | ||
| 554 | + memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0, | ||
| 555 | + total_length); | ||
| 556 | + char **tokens_temp = new char *[r->count]; | ||
| 557 | + int32_t pos = 0; | ||
| 558 | + for (int32_t i = 0; i < r->count; ++i) { | ||
| 559 | + tokens_temp[i] = const_cast<char *>(r->tokens) + pos; | ||
| 560 | + memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)), | ||
| 561 | + result.tokens[i].c_str(), result.tokens[i].size()); | ||
| 562 | + // +1 to move past the null character | ||
| 563 | + pos += result.tokens[i].size() + 1; | ||
| 564 | + } | ||
| 565 | + r->tokens_arr = tokens_temp; | ||
| 566 | + | ||
| 567 | + if (!result.timestamps.empty()) { | ||
| 568 | + r->timestamps = new float[result.timestamps.size()]; | ||
| 569 | + std::copy(result.timestamps.begin(), result.timestamps.end(), | ||
| 570 | + r->timestamps); | ||
| 571 | + } else { | ||
| 572 | + r->timestamps = nullptr; | ||
| 573 | + } | ||
| 574 | + | ||
| 575 | + } else { | ||
| 576 | + r->count = 0; | ||
| 577 | + r->timestamps = nullptr; | ||
| 578 | + r->tokens = nullptr; | ||
| 579 | + r->tokens_arr = nullptr; | ||
| 580 | + } | ||
| 581 | + | ||
| 582 | + return r; | ||
| 583 | +} | ||
| 584 | + | ||
| 585 | +void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) { | ||
| 586 | + if (r) { | ||
| 587 | + delete[] r->keyword; | ||
| 588 | + delete[] r->json; | ||
| 589 | + delete[] r->tokens; | ||
| 590 | + delete[] r->tokens_arr; | ||
| 591 | + delete[] r->timestamps; | ||
| 592 | + delete r; | ||
| 593 | + } | ||
| 594 | +} | ||
| 595 | + | ||
| 596 | + | ||
| 597 | +// ============================================================ | ||
| 414 | // For VAD | 598 | // For VAD |
| 415 | // ============================================================ | 599 | // ============================================================ |
| 416 | // | 600 | // |
| @@ -474,6 +474,123 @@ SHERPA_ONNX_API void DestroyOfflineRecognizerResult( | @@ -474,6 +474,123 @@ SHERPA_ONNX_API void DestroyOfflineRecognizerResult( | ||
| 474 | const SherpaOnnxOfflineRecognizerResult *r); | 474 | const SherpaOnnxOfflineRecognizerResult *r); |
| 475 | 475 | ||
| 476 | // ============================================================ | 476 | // ============================================================ |
| 477 | +// For Keyword Spotting | ||
| 478 | +// ============================================================ | ||
| 479 | +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult { | ||
| 480 | + /// The triggered keyword. | ||
| 481 | + /// For English, it consists of space separated words. | ||
| 482 | + /// For Chinese, it consists of Chinese words without spaces. | ||
| 483 | + /// Example 1: "hello world" | ||
| 484 | + /// Example 2: "你好世界" | ||
| 485 | + const char* keyword; | ||
| 486 | + | ||
| 487 | + /// Decoded results at the token level. | ||
| 488 | + /// For instance, for BPE-based models it consists of a list of BPE tokens. | ||
| 489 | + const char* tokens; | ||
| 490 | + | ||
| 491 | + const char* const* tokens_arr; | ||
| 492 | + | ||
| 493 | + int32_t count; | ||
| 494 | + | ||
| 495 | + /// timestamps.size() == tokens.size() | ||
| 496 | + /// timestamps[i] records the time in seconds when tokens[i] is decoded. | ||
| 497 | + float* timestamps; | ||
| 498 | + | ||
| 499 | + /// Starting time of this segment. | ||
| 500 | + /// When an endpoint is detected, it will change | ||
| 501 | + float start_time; | ||
| 502 | + | ||
| 503 | + /** Return a json string. | ||
| 504 | + * | ||
| 505 | + * The returned string contains: | ||
| 506 | + * { | ||
| 507 | + * "keyword": "The triggered keyword", | ||
| 508 | + * "tokens": [x, x, x], | ||
| 509 | + * "timestamps": [x, x, x], | ||
| 510 | + * "start_time": x, | ||
| 511 | + * } | ||
| 512 | + */ | ||
| 513 | + const char* json; | ||
| 514 | +} SherpaOnnxKeywordResult; | ||
| 515 | + | ||
| 516 | +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig { | ||
| 517 | + SherpaOnnxFeatureConfig feat_config; | ||
| 518 | + SherpaOnnxOnlineModelConfig model_config; | ||
| 519 | + int32_t max_active_paths; | ||
| 520 | + int32_t num_trailing_blanks; | ||
| 521 | + float keywords_score; | ||
| 522 | + float keywords_threshold; | ||
| 523 | + const char* keywords_file; | ||
| 524 | +} SherpaOnnxKeywordSpotterConfig; | ||
| 525 | + | ||
| 526 | +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter | ||
| 527 | + SherpaOnnxKeywordSpotter; | ||
| 528 | + | ||
| 529 | +/// @param config Config for the keyword spotter. | ||
| 530 | +/// @return Return a pointer to the spotter. The user has to invoke | ||
| 531 | +/// DestroyKeywordSpotter() to free it to avoid memory leak. | ||
| 532 | +SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter( | ||
| 533 | + const SherpaOnnxKeywordSpotterConfig* config); | ||
| 534 | + | ||
| 535 | +/// Free a pointer returned by CreateKeywordSpotter() | ||
| 536 | +/// | ||
| 537 | +/// @param spotter A pointer returned by CreateKeywordSpotter() | ||
| 538 | +SHERPA_ONNX_API void DestroyKeywordSpotter( | ||
| 539 | + SherpaOnnxKeywordSpotter* spotter); | ||
| 540 | + | ||
| 541 | +/// Create an online stream for accepting wave samples. | ||
| 542 | +/// | ||
| 543 | +/// @param spotter A pointer returned by CreateKeywordSpotter() | ||
| 544 | +/// @return Return a pointer to an OnlineStream. The user has to invoke | ||
| 545 | +/// DestroyOnlineStream() to free it to avoid memory leak. | ||
| 546 | +SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream( | ||
| 547 | + const SherpaOnnxKeywordSpotter* spotter); | ||
| 548 | + | ||
| 549 | +/// Return 1 if there are enough number of feature frames for decoding. | ||
| 550 | +/// Return 0 otherwise. | ||
| 551 | +/// | ||
| 552 | +/// @param spotter A pointer returned by CreateKeywordSpotter | ||
| 553 | +/// @param stream A pointer returned by CreateKeywordStream | ||
| 554 | +SHERPA_ONNX_API int32_t IsKeywordStreamReady( | ||
| 555 | + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream); | ||
| 556 | + | ||
| 557 | +/// Call this function to run the neural network model and decoding. | ||
| 558 | +/// | ||
| 559 | +/// Precondition for this function: IsKeywordStreamReady() MUST return 1. | ||
| 560 | +SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, | ||
| 561 | + SherpaOnnxOnlineStream* stream); | ||
| 562 | + | ||
| 563 | +/// This function is similar to DecodeKeywordStream(). It decodes multiple | ||
| 564 | +/// OnlineStream in parallel. | ||
| 565 | +/// | ||
| 566 | +/// Caution: The caller has to ensure each OnlineStream is ready, i.e., | ||
| 567 | +/// IsKeywordStreamReady() for that stream should return 1. | ||
| 568 | +/// | ||
| 569 | +/// @param spotter A pointer returned by CreateKeywordSpotter() | ||
| 570 | +/// @param streams A pointer array containing pointers returned by | ||
| 571 | +/// CreateKeywordStream() | ||
| 572 | +/// @param n Number of elements in the given streams array. | ||
| 573 | +SHERPA_ONNX_API void DecodeMultipleKeywordStreams( | ||
| 574 | + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, | ||
| 575 | + int32_t n); | ||
| 576 | + | ||
| 577 | +/// Get the decoding results so far for an OnlineStream. | ||
| 578 | +/// | ||
| 579 | +/// @param spotter A pointer returned by CreateKeywordSpotter(). | ||
| 580 | +/// @param stream A pointer returned by CreateKeywordStream(). | ||
| 581 | +/// @return A pointer containing the result. The user has to invoke | ||
| 582 | +/// DestroyKeywordResult() to free the returned pointer to | ||
| 583 | +/// avoid memory leak. | ||
| 584 | +SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult( | ||
| 585 | + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream); | ||
| 586 | + | ||
| 587 | +/// Destroy the pointer returned by GetKeywordResult(). | ||
| 588 | +/// | ||
| 589 | +/// @param r A pointer returned by GetKeywordResult() | ||
| 590 | +SHERPA_ONNX_API void DestroyKeywordResult( | ||
| 591 | + const SherpaOnnxKeywordResult *r); | ||
| 592 | + | ||
| 593 | +// ============================================================ | ||
| 477 | // For VAD | 594 | // For VAD |
| 478 | // ============================================================ | 595 | // ============================================================ |
| 479 | 596 |
-
请 注册 或 登录 后发表评论