xinhecuican
Committed by GitHub

c++ api for keyword spotter (#642)

@@ -17,6 +17,7 @@ @@ -17,6 +17,7 @@
17 #include "sherpa-onnx/csrc/online-recognizer.h" 17 #include "sherpa-onnx/csrc/online-recognizer.h"
18 #include "sherpa-onnx/csrc/voice-activity-detector.h" 18 #include "sherpa-onnx/csrc/voice-activity-detector.h"
19 #include "sherpa-onnx/csrc/wave-writer.h" 19 #include "sherpa-onnx/csrc/wave-writer.h"
  20 +#include "sherpa-onnx/csrc/keyword-spotter.h"
20 21
21 struct SherpaOnnxOnlineRecognizer { 22 struct SherpaOnnxOnlineRecognizer {
22 std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl; 23 std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
@@ -411,6 +412,189 @@ void DestroyOfflineRecognizerResult( @@ -411,6 +412,189 @@ void DestroyOfflineRecognizerResult(
411 } 412 }
412 413
413 // ============================================================ 414 // ============================================================
  415 +// For keyword spotting
  416 +// ============================================================
  417 +
  418 +struct SherpaOnnxKeywordSpotter {
  419 + std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
  420 +};
  421 +
  422 +SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
  423 + const SherpaOnnxKeywordSpotterConfig* config) {
  424 + sherpa_onnx::KeywordSpotterConfig spotter_config;
  425 +
  426 + spotter_config.feat_config.sampling_rate =
  427 + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000);
  428 + spotter_config.feat_config.feature_dim =
  429 + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80);
  430 +
  431 + spotter_config.model_config.transducer.encoder =
  432 + SHERPA_ONNX_OR(config->model_config.transducer.encoder, "");
  433 + spotter_config.model_config.transducer.decoder =
  434 + SHERPA_ONNX_OR(config->model_config.transducer.decoder, "");
  435 + spotter_config.model_config.transducer.joiner =
  436 + SHERPA_ONNX_OR(config->model_config.transducer.joiner, "");
  437 +
  438 + spotter_config.model_config.paraformer.encoder =
  439 + SHERPA_ONNX_OR(config->model_config.paraformer.encoder, "");
  440 + spotter_config.model_config.paraformer.decoder =
  441 + SHERPA_ONNX_OR(config->model_config.paraformer.decoder, "");
  442 +
  443 + spotter_config.model_config.zipformer2_ctc.model =
  444 + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, "");
  445 +
  446 + spotter_config.model_config.tokens =
  447 + SHERPA_ONNX_OR(config->model_config.tokens, "");
  448 + spotter_config.model_config.num_threads =
  449 + SHERPA_ONNX_OR(config->model_config.num_threads, 1);
  450 + spotter_config.model_config.provider =
  451 + SHERPA_ONNX_OR(config->model_config.provider, "cpu");
  452 + spotter_config.model_config.model_type =
  453 + SHERPA_ONNX_OR(config->model_config.model_type, "");
  454 + spotter_config.model_config.debug =
  455 + SHERPA_ONNX_OR(config->model_config.debug, 0);
  456 +
  457 + spotter_config.max_active_paths =
  458 + SHERPA_ONNX_OR(config->max_active_paths, 4);
  459 +
  460 + spotter_config.num_trailing_blanks =
  461 + SHERPA_ONNX_OR(config->num_trailing_blanks , 1);
  462 +
  463 + spotter_config.keywords_score =
  464 + SHERPA_ONNX_OR(config->keywords_score, 1.0);
  465 +
  466 + spotter_config.keywords_threshold =
  467 + SHERPA_ONNX_OR(config->keywords_threshold, 0.25);
  468 +
  469 + spotter_config.keywords_file =
  470 + SHERPA_ONNX_OR(config->keywords_file, "");
  471 +
  472 + if (config->model_config.debug) {
  473 + SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
  474 + }
  475 +
  476 + if (!spotter_config.Validate()) {
  477 + SHERPA_ONNX_LOGE("Errors in config!");
  478 + return nullptr;
  479 + }
  480 +
  481 + SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter;
  482 +
  483 + spotter->impl =
  484 + std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);
  485 +
  486 + return spotter;
  487 +}
  488 +
  489 +void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) {
  490 + delete spotter;
  491 +}
  492 +
  493 +SherpaOnnxOnlineStream* CreateKeywordStream(
  494 + const SherpaOnnxKeywordSpotter* spotter) {
  495 + SherpaOnnxOnlineStream* stream =
  496 + new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
  497 + return stream;
  498 +}
  499 +
  500 +int32_t IsKeywordStreamReady(
  501 + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) {
  502 + return spotter->impl->IsReady(stream->impl.get());
  503 +}
  504 +
  505 +void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
  506 + SherpaOnnxOnlineStream* stream) {
  507 + return spotter->impl->DecodeStream(stream->impl.get());
  508 +}
  509 +
  510 +void DecodeMultipleKeywordStreams(
  511 + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
  512 + int32_t n) {
  513 + std::vector<sherpa_onnx::OnlineStream*> ss(n);
  514 + for (int32_t i = 0; i != n; ++i) {
  515 + ss[i] = streams[i]->impl.get();
  516 + }
  517 + spotter->impl->DecodeStreams(ss.data(), n);
  518 +}
  519 +
  520 +const SherpaOnnxKeywordResult *GetKeywordResult(
  521 + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
  522 + const sherpa_onnx::KeywordResult& result =
  523 + spotter->impl->GetResult(stream->impl.get());
  524 + const auto &keyword = result.keyword;
  525 +
  526 + auto r = new SherpaOnnxKeywordResult;
  527 + memset(r, 0, sizeof(SherpaOnnxKeywordResult));
  528 +
  529 + r->start_time = result.start_time;
  530 +
  531 + // copy keyword
  532 + r->keyword = new char[keyword.size() + 1];
  533 + std::copy(keyword.begin(), keyword.end(), const_cast<char *>(r->keyword));
  534 + const_cast<char *>(r->keyword)[keyword.size()] = 0;
  535 +
  536 + // copy json
  537 + const auto &json = result.AsJsonString();
  538 + r->json = new char[json.size() + 1];
  539 + std::copy(json.begin(), json.end(), const_cast<char *>(r->json));
  540 + const_cast<char *>(r->json)[json.size()] = 0;
  541 +
  542 + // copy tokens
  543 + auto count = result.tokens.size();
  544 + if (count > 0) {
  545 + size_t total_length = 0;
  546 + for (const auto &token : result.tokens) {
  547 + // +1 for the null character at the end of each token
  548 + total_length += token.size() + 1;
  549 + }
  550 +
  551 + r->count = count;
  552 + // Each token is terminated by a null character ('\0')
  553 + r->tokens = new char[total_length];
  554 + memset(reinterpret_cast<void *>(const_cast<char *>(r->tokens)), 0,
  555 + total_length);
  556 + char **tokens_temp = new char *[r->count];
  557 + int32_t pos = 0;
  558 + for (int32_t i = 0; i < r->count; ++i) {
  559 + tokens_temp[i] = const_cast<char *>(r->tokens) + pos;
  560 + memcpy(reinterpret_cast<void *>(const_cast<char *>(r->tokens + pos)),
  561 + result.tokens[i].c_str(), result.tokens[i].size());
  562 + // +1 to move past the null character
  563 + pos += result.tokens[i].size() + 1;
  564 + }
  565 + r->tokens_arr = tokens_temp;
  566 +
  567 + if (!result.timestamps.empty()) {
  568 + r->timestamps = new float[result.timestamps.size()];
  569 + std::copy(result.timestamps.begin(), result.timestamps.end(),
  570 + r->timestamps);
  571 + } else {
  572 + r->timestamps = nullptr;
  573 + }
  574 +
  575 + } else {
  576 + r->count = 0;
  577 + r->timestamps = nullptr;
  578 + r->tokens = nullptr;
  579 + r->tokens_arr = nullptr;
  580 + }
  581 +
  582 + return r;
  583 +}
  584 +
  585 +void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) {
  586 + if (r) {
  587 + delete[] r->keyword;
  588 + delete[] r->json;
  589 + delete[] r->tokens;
  590 + delete[] r->tokens_arr;
  591 + delete[] r->timestamps;
  592 + delete r;
  593 + }
  594 +}
  595 +
  596 +
  597 +// ============================================================
414 // For VAD 598 // For VAD
415 // ============================================================ 599 // ============================================================
416 // 600 //
@@ -474,6 +474,123 @@ SHERPA_ONNX_API void DestroyOfflineRecognizerResult( @@ -474,6 +474,123 @@ SHERPA_ONNX_API void DestroyOfflineRecognizerResult(
474 const SherpaOnnxOfflineRecognizerResult *r); 474 const SherpaOnnxOfflineRecognizerResult *r);
475 475
476 // ============================================================ 476 // ============================================================
  477 +// For keyword spotting
  478 +// ============================================================
  479 +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
  480 + /// The triggered keyword.
  481 + /// For English, it consists of space separated words.
  482 + /// For Chinese, it consists of Chinese words without spaces.
  483 + /// Example 1: "hello world"
  484 + /// Example 2: "你好世界"
  485 + const char* keyword;
  486 +
  487 + /// Decoded results at the token level.
  488 + /// For instance, for BPE-based models it consists of a list of BPE tokens.
  489 + const char* tokens;
  490 +
  491 + const char* const* tokens_arr;
  492 +
  493 + int32_t count;
  494 +
  495 + /// timestamps.size() == tokens.size()
  496 + /// timestamps[i] records the time in seconds when tokens[i] is decoded.
  497 + float* timestamps;
  498 +
  499 + /// Starting time of this segment.
  500 + /// When an endpoint is detected, it will change
  501 + float start_time;
  502 +
  503 + /** Return a json string.
  504 + *
  505 + * The returned string contains:
  506 + * {
  507 + * "keyword": "The triggered keyword",
  508 + * "tokens": [x, x, x],
  509 + * "timestamps": [x, x, x],
  510 + * "start_time": x,
  511 + * }
  512 + */
  513 + const char* json;
  514 +} SherpaOnnxKeywordResult;
  515 +
  516 +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
  517 + SherpaOnnxFeatureConfig feat_config;
  518 + SherpaOnnxOnlineModelConfig model_config;
  519 + int32_t max_active_paths;
  520 + int32_t num_trailing_blanks;
  521 + float keywords_score;
  522 + float keywords_threshold;
  523 + const char* keywords_file;
  524 +} SherpaOnnxKeywordSpotterConfig;
  525 +
  526 +SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
  527 + SherpaOnnxKeywordSpotter;
  528 +
  529 +/// @param config Config for the keyword spotter.
  530 +/// @return Return a pointer to the spotter, or NULL if the config is invalid.
  531 +/// The user has to invoke DestroyKeywordSpotter() to free it to avoid memory leak.
  532 +SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
  533 + const SherpaOnnxKeywordSpotterConfig* config);
  534 +
  535 +/// Free a pointer returned by CreateKeywordSpotter()
  536 +///
  537 +/// @param spotter A pointer returned by CreateKeywordSpotter()
  538 +SHERPA_ONNX_API void DestroyKeywordSpotter(
  539 + SherpaOnnxKeywordSpotter* spotter);
  540 +
  541 +/// Create an online stream for accepting wave samples.
  542 +///
  543 +/// @param spotter A pointer returned by CreateKeywordSpotter()
  544 +/// @return Return a pointer to an OnlineStream. The user has to invoke
  545 +/// DestroyOnlineStream() to free it to avoid memory leak.
  546 +SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream(
  547 + const SherpaOnnxKeywordSpotter* spotter);
  548 +
  549 +/// Return 1 if there are enough feature frames for decoding.
  550 +/// Return 0 otherwise.
  551 +///
  552 +/// @param spotter A pointer returned by CreateKeywordSpotter
  553 +/// @param stream A pointer returned by CreateKeywordStream
  554 +SHERPA_ONNX_API int32_t IsKeywordStreamReady(
  555 + SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream);
  556 +
  557 +/// Call this function to run the neural network model and decoding.
  558 +///
  559 +/// Precondition for this function: IsKeywordStreamReady() MUST return 1.
  560 +SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
  561 + SherpaOnnxOnlineStream* stream);
  562 +
  563 +/// This function is similar to DecodeKeywordStream(). It decodes multiple
  564 +/// OnlineStreams in parallel.
  565 +///
  566 +/// Caution: The caller has to ensure each OnlineStream is ready, i.e.,
  567 +/// IsKeywordStreamReady() for that stream should return 1.
  568 +///
  569 +/// @param spotter A pointer returned by CreateKeywordSpotter()
  570 +/// @param streams A pointer array containing pointers returned by
  571 +/// CreateKeywordStream()
  572 +/// @param n Number of elements in the given streams array.
  573 +SHERPA_ONNX_API void DecodeMultipleKeywordStreams(
  574 + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
  575 + int32_t n);
  576 +
  577 +/// Get the decoding results so far for an OnlineStream.
  578 +///
  579 +/// @param spotter A pointer returned by CreateKeywordSpotter().
  580 +/// @param stream A pointer returned by CreateKeywordStream().
  581 +/// @return A pointer containing the result. The user has to invoke
  582 +/// DestroyKeywordResult() to free the returned pointer to
  583 +/// avoid memory leak.
  584 +SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult(
  585 + SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream);
  586 +
  587 +/// Destroy the pointer returned by GetKeywordResult().
  588 +///
  589 +/// @param r A pointer returned by GetKeywordResult()
  590 +SHERPA_ONNX_API void DestroyKeywordResult(
  591 + const SherpaOnnxKeywordResult *r);
  592 +
  593 +// ============================================================
477 // For VAD 594 // For VAD
478 // ============================================================ 595 // ============================================================
479 596