yujinqiu
Committed by GitHub

Add swift online punctuation (#1661)

@@ -24,6 +24,7 @@ @@ -24,6 +24,7 @@
24 #include "sherpa-onnx/csrc/macros.h" 24 #include "sherpa-onnx/csrc/macros.h"
25 #include "sherpa-onnx/csrc/offline-punctuation.h" 25 #include "sherpa-onnx/csrc/offline-punctuation.h"
26 #include "sherpa-onnx/csrc/offline-recognizer.h" 26 #include "sherpa-onnx/csrc/offline-recognizer.h"
  27 +#include "sherpa-onnx/csrc/online-punctuation.h"
27 #include "sherpa-onnx/csrc/online-recognizer.h" 28 #include "sherpa-onnx/csrc/online-recognizer.h"
28 #include "sherpa-onnx/csrc/resample.h" 29 #include "sherpa-onnx/csrc/resample.h"
29 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" 30 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
@@ -1717,6 +1718,53 @@ const char *SherpaOfflinePunctuationAddPunct( @@ -1717,6 +1718,53 @@ const char *SherpaOfflinePunctuationAddPunct(
1717 1718
1718 void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; } 1719 void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; }
1719 1720
  1721 +struct SherpaOnnxOnlinePunctuation {
  1722 + std::unique_ptr<sherpa_onnx::OnlinePunctuation> impl;
  1723 +};
  1724 +
  1725 +const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
  1726 + const SherpaOnnxOnlinePunctuationConfig *config) {
  1727 + auto p = new SherpaOnnxOnlinePunctuation;
  1728 + try {
  1729 + sherpa_onnx::OnlinePunctuationConfig punctuation_config;
  1730 + punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
  1731 + punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
  1732 + punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
  1733 + punctuation_config.model.debug = config->model.debug;
  1734 + punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
  1735 +
  1736 + p->impl =
  1737 + std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
  1738 + } catch (const std::exception &e) {
  1739 + SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what());
  1740 + delete p;
  1741 + return nullptr;
  1742 + }
  1743 + return p;
  1744 +}
  1745 +
  1746 +void SherpaOnnxDestroyOnlinePunctuation(const SherpaOnnxOnlinePunctuation *p) {
  1747 + delete p;
  1748 +}
  1749 +
  1750 +const char *SherpaOnnxOnlinePunctuationAddPunct(
  1751 + const SherpaOnnxOnlinePunctuation *punctuation, const char *text) {
  1752 + if (!punctuation || !text) return nullptr;
  1753 +
  1754 + try {
  1755 + std::string s = punctuation->impl->AddPunctuationWithCase(text);
  1756 + char *p = new char[s.size() + 1];
  1757 + std::copy(s.begin(), s.end(), p);
  1758 + p[s.size()] = '\0';
  1759 + return p;
  1760 + } catch (const std::exception &e) {
  1761 + SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what());
  1762 + return nullptr;
  1763 + }
  1764 +}
  1765 +
  1766 +void SherpaOnnxOnlinePunctuationFreeText(const char *text) { delete[] text; }
  1767 +
1720 struct SherpaOnnxLinearResampler { 1768 struct SherpaOnnxLinearResampler {
1721 std::unique_ptr<sherpa_onnx::LinearResample> impl; 1769 std::unique_ptr<sherpa_onnx::LinearResample> impl;
1722 }; 1770 };
@@ -1369,6 +1369,39 @@ SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct( @@ -1369,6 +1369,39 @@ SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct(
1369 1369
1370 SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text); 1370 SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text);
1371 1371
  1372 +SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationModelConfig {
  1373 + const char *cnn_bilstm;
  1374 + const char *bpe_vocab;
  1375 + int32_t num_threads;
  1376 + int32_t debug;
  1377 + const char *provider;
  1378 +} SherpaOnnxOnlinePunctuationModelConfig;
  1379 +
  1380 +SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
  1381 + SherpaOnnxOnlinePunctuationModelConfig model;
  1382 +} SherpaOnnxOnlinePunctuationConfig;
  1383 +
  1384 +SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;
  1385 +
  1386 +// Create an online punctuation processor. The user has to invoke
  1387 +// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
  1388 +// to avoid memory leak
  1389 +SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
  1390 + const SherpaOnnxOnlinePunctuationConfig *config);
  1391 +
  1392 +// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
  1393 +SHERPA_ONNX_API void SherpaOnnxDestroyOnlinePunctuation(
  1394 + const SherpaOnnxOnlinePunctuation *punctuation);
  1395 +
  1396 +// Add punctuations to the input text. The user has to invoke
  1397 +// SherpaOnnxOnlinePunctuationFreeText() to free the returned pointer
  1398 +// to avoid memory leak
  1399 +SHERPA_ONNX_API const char *SherpaOnnxOnlinePunctuationAddPunct(
  1400 + const SherpaOnnxOnlinePunctuation *punctuation, const char *text);
  1401 +
  1402 +// Free a pointer returned by SherpaOnnxOnlinePunctuationAddPunct()
  1403 +SHERPA_ONNX_API void SherpaOnnxOnlinePunctuationFreeText(const char *text);
  1404 +
1372 // for resampling 1405 // for resampling
1373 SHERPA_ONNX_API typedef struct SherpaOnnxLinearResampler 1406 SHERPA_ONNX_API typedef struct SherpaOnnxLinearResampler
1374 SherpaOnnxLinearResampler; 1407 SherpaOnnxLinearResampler;
@@ -1095,6 +1095,52 @@ class SherpaOnnxOfflinePunctuationWrapper { @@ -1095,6 +1095,52 @@ class SherpaOnnxOfflinePunctuationWrapper {
1095 } 1095 }
1096 } 1096 }
1097 1097
  1098 +func sherpaOnnxOnlinePunctuationModelConfig(
  1099 + cnnBiLstm: String,
  1100 + bpeVocab: String,
  1101 + numThreads: Int = 1,
  1102 + debug: Int = 0,
  1103 + provider: String = "cpu"
  1104 +) -> SherpaOnnxOnlinePunctuationModelConfig {
  1105 + return SherpaOnnxOnlinePunctuationModelConfig(
  1106 + cnn_bilstm: toCPointer(cnnBiLstm),
  1107 + bpe_vocab: toCPointer(bpeVocab),
  1108 + num_threads: Int32(numThreads),
  1109 + debug: Int32(debug),
  1110 + provider: toCPointer(provider))
  1111 +}
  1112 +
  1113 +func sherpaOnnxOnlinePunctuationConfig(
  1114 + model: SherpaOnnxOnlinePunctuationModelConfig
  1115 +) -> SherpaOnnxOnlinePunctuationConfig {
  1116 + return SherpaOnnxOnlinePunctuationConfig(model: model)
  1117 +}
  1118 +
  1119 +class SherpaOnnxOnlinePunctuationWrapper {
  1120 + /// A pointer to the underlying counterpart in C
  1121 + let ptr: OpaquePointer!
  1122 +
  1123 + /// Constructor taking a model config
  1124 + init(
  1125 + config: UnsafePointer<SherpaOnnxOnlinePunctuationConfig>!
  1126 + ) {
  1127 + ptr = SherpaOnnxCreateOnlinePunctuation(config)
  1128 + }
  1129 +
  1130 + deinit {
  1131 + if let ptr {
  1132 + SherpaOnnxDestroyOnlinePunctuation(ptr)
  1133 + }
  1134 + }
  1135 +
  1136 + func addPunct(text: String) -> String {
  1137 + let cText = SherpaOnnxOnlinePunctuationAddPunct(ptr, toCPointer(text))
  1138 + let ans = String(cString: cText!)
  1139 + SherpaOnnxOnlinePunctuationFreeText(cText)
  1140 + return ans
  1141 + }
  1142 +}
  1143 +
1098 func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String) 1144 func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String)
1099 -> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig 1145 -> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig
1100 { 1146 {
  1 +func run() {
  2 + let model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx"
  3 + let bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"
  4 +
  5 + // Create model config
  6 + let modelConfig = sherpaOnnxOnlinePunctuationModelConfig(
  7 + cnnBiLstm: model,
  8 + bpeVocab: bpe
  9 + )
  10 +
  11 + // Create punctuation config
  12 + var config = sherpaOnnxOnlinePunctuationConfig(model: modelConfig)
  13 +
  14 + // Create punctuation instance
  15 + let punct = SherpaOnnxOnlinePunctuationWrapper(config: &config)
  16 +
  17 + // Test texts
  18 + let textList = [
  19 + "how are you i am fine thank you",
  20 + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
  21 + ]
  22 +
  23 + // Process each text
  24 + for i in 0..<textList.count {
  25 + let t = punct.addPunct(text: textList[i])
  26 + print("\nresult is:\n\(t)")
  27 + }
  28 +}
  29 +
  30 +@main
  31 +struct App {
  32 + static func main() {
  33 + run()
  34 + }
  35 +}
  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +if [ ! -d ../build-swift-macos ]; then
  6 + echo "Please run ../build-swift-macos.sh first!"
  7 + exit 1
  8 +fi
  9 +
  10 +# Download and extract the online punctuation model if not exists
  11 +if [ ! -d ./sherpa-onnx-online-punct-en-2024-08-06 ]; then
  12 + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  13 + tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  14 + rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
  15 +fi
  16 +
  17 +if [ ! -e ./add-punctuation-online ]; then
  18 + # Note: We use -lc++ to link against libc++ instead of libstdc++
  19 + swiftc \
  20 + -lc++ \
  21 + -I ../build-swift-macos/install/include \
  22 + -import-objc-header ./SherpaOnnx-Bridging-Header.h \
  23 + ./add-punctuation-online.swift ./SherpaOnnx.swift \
  24 + -L ../build-swift-macos/install/lib/ \
  25 + -l sherpa-onnx \
  26 + -l onnxruntime \
  27 + -o ./add-punctuation-online
  28 +
  29 + strip ./add-punctuation-online
  30 +else
  31 + echo "./add-punctuation-online exists - skip building"
  32 +fi
  33 +
  34 +# Set library path and run the executable
  35 +export DYLD_LIBRARY_PATH=$PWD/../build-swift-macos/install/lib:$DYLD_LIBRARY_PATH
  36 +./add-punctuation-online