Aruxxxi
Committed by GitHub

feat: add punctuation C++ API (#2510)

Co-authored-by: Aruxxxi <xiangcl@zhisuan.com>
@@ -30,6 +30,9 @@ target_link_libraries(sense-voice-cxx-api sherpa-onnx-cxx-api) @@ -30,6 +30,9 @@ target_link_libraries(sense-voice-cxx-api sherpa-onnx-cxx-api)
30 add_executable(nemo-canary-cxx-api ./nemo-canary-cxx-api.cc) 30 add_executable(nemo-canary-cxx-api ./nemo-canary-cxx-api.cc)
31 target_link_libraries(nemo-canary-cxx-api sherpa-onnx-cxx-api) 31 target_link_libraries(nemo-canary-cxx-api sherpa-onnx-cxx-api)
32 32
  33 +add_executable(punctuation-cxx-api ./punctuation-cxx-api.cc)
  34 +target_link_libraries(punctuation-cxx-api sherpa-onnx-cxx-api)
  35 +
33 if(SHERPA_ONNX_ENABLE_PORTAUDIO) 36 if(SHERPA_ONNX_ENABLE_PORTAUDIO)
34 add_executable(sense-voice-simulate-streaming-microphone-cxx-api 37 add_executable(sense-voice-simulate-streaming-microphone-cxx-api
35 ./sense-voice-simulate-streaming-microphone-cxx-api.cc 38 ./sense-voice-simulate-streaming-microphone-cxx-api.cc
  1 +// cxx-api-examples/punctuation-cxx-api.cc
  2 +// Copyright (c) 2025 Xiaomi Corporation
  3 +
  4 +// To use punctuation model:
  5 +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  6 +// tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  7 +// rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  8 +
  9 +#include <iostream>
  10 +#include <string>
  11 +
  12 +#include "sherpa-onnx/c-api/cxx-api.h"
  13 +
  14 +int32_t main() {
  15 + using namespace sherpa_onnx::cxx; // NOLINT
  16 +
  17 + OfflinePunctuationConfig punctuation_config;
  18 + punctuation_config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx";
  19 + punctuation_config.model.num_threads = 1;
  20 + punctuation_config.model.debug = false;
  21 + punctuation_config.model.provider = "cpu";
  22 +
  23 + OfflinePunctuation punct = OfflinePunctuation::Create(punctuation_config);
  24 + if (!punct.Get()) {
  25 + std::cerr << "Failed to create punctuation model. Please check your config\n";
  26 + return -1;
  27 + }
  28 +
  29 + std::string text = "你好吗how are you Fantasitic 谢谢我很好你怎么样呢";
  30 + std::string text_with_punct = punct.AddPunctuation(text);
  31 + std::cout << "Original text: " << text << std::endl;
  32 + std::cout << "With punctuation: " << text_with_punct << std::endl;
  33 +
  34 + return 0;
  35 +}
@@ -821,4 +821,33 @@ bool FileExists(const std::string &filename) { @@ -821,4 +821,33 @@ bool FileExists(const std::string &filename) {
821 return SherpaOnnxFileExists(filename.c_str()); 821 return SherpaOnnxFileExists(filename.c_str());
822 } 822 }
823 823
  824 +// ============================================================
  825 +// For Offline Punctuation
  826 +// ============================================================
  827 +OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &config) {
  828 + struct SherpaOnnxOfflinePunctuationConfig c;
  829 + memset(&c, 0, sizeof(c));
  830 + c.model.ct_transformer = config.model.ct_transformer.c_str();
  831 + c.model.num_threads = config.model.num_threads;
  832 + c.model.debug = config.model.debug;
  833 + c.model.provider = config.model.provider.c_str();
  834 +
  835 + const SherpaOnnxOfflinePunctuation *punct = SherpaOnnxCreateOfflinePunctuation(&c);
  836 + return OfflinePunctuation(punct);
  837 +}
  838 +
  839 +OfflinePunctuation::OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p)
  840 + : MoveOnly<OfflinePunctuation, SherpaOnnxOfflinePunctuation>(p) {}
  841 +
  842 +void OfflinePunctuation::Destroy(const SherpaOnnxOfflinePunctuation *p) const {
  843 + SherpaOnnxDestroyOfflinePunctuation(p);
  844 +}
  845 +
  846 +std::string OfflinePunctuation::AddPunctuation(const std::string &text) const {
  847 + const char *result = SherpaOfflinePunctuationAddPunct(p_, text.c_str());
  848 + std::string ans(result);
  849 + SherpaOfflinePunctuationFreeText(result);
  850 + return ans;
  851 +}
  852 +
824 } // namespace sherpa_onnx::cxx 853 } // namespace sherpa_onnx::cxx
@@ -673,6 +673,34 @@ SHERPA_ONNX_API std::string GetGitSha1(); @@ -673,6 +673,34 @@ SHERPA_ONNX_API std::string GetGitSha1();
673 SHERPA_ONNX_API std::string GetGitDate(); 673 SHERPA_ONNX_API std::string GetGitDate();
674 SHERPA_ONNX_API bool FileExists(const std::string &filename); 674 SHERPA_ONNX_API bool FileExists(const std::string &filename);
675 675
  676 +// ============================================================================
  677 +// Offline Punctuation
  678 +// ============================================================================
  679 +struct OfflinePunctuationModelConfig {
  680 + std::string ct_transformer;
  681 + int32_t num_threads = 1;
  682 + bool debug = false;
  683 + std::string provider = "cpu";
  684 +};
  685 +
  686 +struct OfflinePunctuationConfig {
  687 + OfflinePunctuationModelConfig model;
  688 +};
  689 +
  690 +class SHERPA_ONNX_API OfflinePunctuation
  691 + : public MoveOnly<OfflinePunctuation, SherpaOnnxOfflinePunctuation> {
  692 + public:
  693 + static OfflinePunctuation Create(const OfflinePunctuationConfig &config);
  694 +
  695 + void Destroy(const SherpaOnnxOfflinePunctuation *p) const;
  696 +
  697 + // Add punctuations to the input text and return it.
  698 + std::string AddPunctuation(const std::string &text) const;
  699 +
  700 + private:
  701 + explicit OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p);
  702 +};
  703 +
676 } // namespace sherpa_onnx::cxx 704 } // namespace sherpa_onnx::cxx
677 705
678 #endif // SHERPA_ONNX_C_API_CXX_API_H_ 706 #endif // SHERPA_ONNX_C_API_CXX_API_H_