Fangjun Kuang
Committed by GitHub

Support adding punctuations to the speech recogntion result (#761)

  1 +#!/usr/bin/env bash
  2 +
  3 +set -ex
  4 +
  5 +log() {
  6 + # This function is from espnet
  7 + local fname=${BASH_SOURCE[1]##*/}
  8 + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
  9 +}
  10 +
  11 +echo "EXE is $EXE"
  12 +echo "PATH: $PATH"
  13 +
  14 +which $EXE
  15 +
  16 +log "------------------------------------------------------------"
  17 +log "Download model "
  18 +log "------------------------------------------------------------"
  19 +
  20 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  21 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  22 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  23 +repo=sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
  24 +ls -lh $repo
  25 +
  26 +$EXE \
  27 + --debug=1 \
  28 + --ct-transformer=$repo/model.onnx \
  29 + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你"
  30 +
  31 +$EXE \
  32 + --debug=1 \
  33 + --ct-transformer=$repo/model.onnx \
  34 + "我们都是木头人不会说话不会动"
  35 +
  36 +$EXE \
  37 + --debug=1 \
  38 + --ct-transformer=$repo/model.onnx \
  39 + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
  40 +
  41 +rm -rf $repo
@@ -16,6 +16,7 @@ on: @@ -16,6 +16,7 @@ on:
16 - '.github/scripts/test-online-ctc.sh' 16 - '.github/scripts/test-online-ctc.sh'
17 - '.github/scripts/test-offline-tts.sh' 17 - '.github/scripts/test-offline-tts.sh'
18 - '.github/scripts/test-audio-tagging.sh' 18 - '.github/scripts/test-audio-tagging.sh'
  19 + - '.github/scripts/test-offline-punctuation.sh'
19 - 'CMakeLists.txt' 20 - 'CMakeLists.txt'
20 - 'cmake/**' 21 - 'cmake/**'
21 - 'sherpa-onnx/csrc/*' 22 - 'sherpa-onnx/csrc/*'
@@ -34,6 +35,7 @@ on: @@ -34,6 +35,7 @@ on:
34 - '.github/scripts/test-online-ctc.sh' 35 - '.github/scripts/test-online-ctc.sh'
35 - '.github/scripts/test-offline-tts.sh' 36 - '.github/scripts/test-offline-tts.sh'
36 - '.github/scripts/test-audio-tagging.sh' 37 - '.github/scripts/test-audio-tagging.sh'
  38 + - '.github/scripts/test-offline-punctuation.sh'
37 - 'CMakeLists.txt' 39 - 'CMakeLists.txt'
38 - 'cmake/**' 40 - 'cmake/**'
39 - 'sherpa-onnx/csrc/*' 41 - 'sherpa-onnx/csrc/*'
@@ -126,6 +128,14 @@ jobs: @@ -126,6 +128,14 @@ jobs:
126 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 128 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
127 path: build/bin/* 129 path: build/bin/*
128 130
  131 + - name: Test offline punctuation
  132 + shell: bash
  133 + run: |
  134 + export PATH=$PWD/build/bin:$PATH
  135 + export EXE=sherpa-onnx-offline-punctuation
  136 +
  137 + .github/scripts/test-offline-punctuation.sh
  138 +
129 - name: Test C API 139 - name: Test C API
130 shell: bash 140 shell: bash
131 run: | 141 run: |
@@ -16,6 +16,7 @@ on: @@ -16,6 +16,7 @@ on:
16 - '.github/scripts/test-offline-tts.sh' 16 - '.github/scripts/test-offline-tts.sh'
17 - '.github/scripts/test-online-ctc.sh' 17 - '.github/scripts/test-online-ctc.sh'
18 - '.github/scripts/test-audio-tagging.sh' 18 - '.github/scripts/test-audio-tagging.sh'
  19 + - '.github/scripts/test-offline-punctuation.sh'
19 - 'CMakeLists.txt' 20 - 'CMakeLists.txt'
20 - 'cmake/**' 21 - 'cmake/**'
21 - 'sherpa-onnx/csrc/*' 22 - 'sherpa-onnx/csrc/*'
@@ -33,6 +34,7 @@ on: @@ -33,6 +34,7 @@ on:
33 - '.github/scripts/test-offline-tts.sh' 34 - '.github/scripts/test-offline-tts.sh'
34 - '.github/scripts/test-online-ctc.sh' 35 - '.github/scripts/test-online-ctc.sh'
35 - '.github/scripts/test-audio-tagging.sh' 36 - '.github/scripts/test-audio-tagging.sh'
  37 + - '.github/scripts/test-offline-punctuation.sh'
36 - 'CMakeLists.txt' 38 - 'CMakeLists.txt'
37 - 'cmake/**' 39 - 'cmake/**'
38 - 'sherpa-onnx/csrc/*' 40 - 'sherpa-onnx/csrc/*'
@@ -105,6 +107,14 @@ jobs: @@ -105,6 +107,14 @@ jobs:
105 otool -L build/bin/sherpa-onnx 107 otool -L build/bin/sherpa-onnx
106 otool -l build/bin/sherpa-onnx 108 otool -l build/bin/sherpa-onnx
107 109
  110 + - name: Test offline punctuation
  111 + shell: bash
  112 + run: |
  113 + export PATH=$PWD/build/bin:$PATH
  114 + export EXE=sherpa-onnx-offline-punctuation
  115 +
  116 + .github/scripts/test-offline-punctuation.sh
  117 +
108 - name: Test C API 118 - name: Test C API
109 shell: bash 119 shell: bash
110 run: | 120 run: |
@@ -15,6 +15,7 @@ on: @@ -15,6 +15,7 @@ on:
15 - '.github/scripts/test-online-ctc.sh' 15 - '.github/scripts/test-online-ctc.sh'
16 - '.github/scripts/test-offline-tts.sh' 16 - '.github/scripts/test-offline-tts.sh'
17 - '.github/scripts/test-audio-tagging.sh' 17 - '.github/scripts/test-audio-tagging.sh'
  18 + - '.github/scripts/test-offline-punctuation.sh'
18 - 'CMakeLists.txt' 19 - 'CMakeLists.txt'
19 - 'cmake/**' 20 - 'cmake/**'
20 - 'sherpa-onnx/csrc/*' 21 - 'sherpa-onnx/csrc/*'
@@ -30,6 +31,7 @@ on: @@ -30,6 +31,7 @@ on:
30 - '.github/scripts/test-online-ctc.sh' 31 - '.github/scripts/test-online-ctc.sh'
31 - '.github/scripts/test-offline-tts.sh' 32 - '.github/scripts/test-offline-tts.sh'
32 - '.github/scripts/test-audio-tagging.sh' 33 - '.github/scripts/test-audio-tagging.sh'
  34 + - '.github/scripts/test-offline-punctuation.sh'
33 - 'CMakeLists.txt' 35 - 'CMakeLists.txt'
34 - 'cmake/**' 36 - 'cmake/**'
35 - 'sherpa-onnx/csrc/*' 37 - 'sherpa-onnx/csrc/*'
@@ -72,6 +74,14 @@ jobs: @@ -72,6 +74,14 @@ jobs:
72 74
73 ls -lh ./bin/Release/sherpa-onnx.exe 75 ls -lh ./bin/Release/sherpa-onnx.exe
74 76
  77 + - name: Test offline punctuation
  78 + shell: bash
  79 + run: |
  80 + export PATH=$PWD/build/bin/Release:$PATH
  81 + export EXE=sherpa-onnx-offline-punctuation.exe
  82 +
  83 + .github/scripts/test-offline-punctuation.sh
  84 +
75 - name: Test C API 85 - name: Test C API
76 shell: bash 86 shell: bash
77 run: | 87 run: |
@@ -82,7 +92,6 @@ jobs: @@ -82,7 +92,6 @@ jobs:
82 92
83 .github/scripts/test-c-api.sh 93 .github/scripts/test-c-api.sh
84 94
85 -  
86 - name: Test Audio tagging 95 - name: Test Audio tagging
87 shell: bash 96 shell: bash
88 run: | 97 run: |
@@ -15,6 +15,7 @@ on: @@ -15,6 +15,7 @@ on:
15 - '.github/scripts/test-offline-tts.sh' 15 - '.github/scripts/test-offline-tts.sh'
16 - '.github/scripts/test-online-ctc.sh' 16 - '.github/scripts/test-online-ctc.sh'
17 - '.github/scripts/test-audio-tagging.sh' 17 - '.github/scripts/test-audio-tagging.sh'
  18 + - '.github/scripts/test-offline-punctuation.sh'
18 - 'CMakeLists.txt' 19 - 'CMakeLists.txt'
19 - 'cmake/**' 20 - 'cmake/**'
20 - 'sherpa-onnx/csrc/*' 21 - 'sherpa-onnx/csrc/*'
@@ -30,6 +31,7 @@ on: @@ -30,6 +31,7 @@ on:
30 - '.github/scripts/test-offline-tts.sh' 31 - '.github/scripts/test-offline-tts.sh'
31 - '.github/scripts/test-online-ctc.sh' 32 - '.github/scripts/test-online-ctc.sh'
32 - '.github/scripts/test-audio-tagging.sh' 33 - '.github/scripts/test-audio-tagging.sh'
  34 + - '.github/scripts/test-offline-punctuation.sh'
33 - 'CMakeLists.txt' 35 - 'CMakeLists.txt'
34 - 'cmake/**' 36 - 'cmake/**'
35 - 'sherpa-onnx/csrc/*' 37 - 'sherpa-onnx/csrc/*'
@@ -72,6 +74,14 @@ jobs: @@ -72,6 +74,14 @@ jobs:
72 74
73 ls -lh ./bin/Release/sherpa-onnx.exe 75 ls -lh ./bin/Release/sherpa-onnx.exe
74 76
  77 + - name: Test offline punctuation
  78 + shell: bash
  79 + run: |
  80 + export PATH=$PWD/build/bin/Release:$PATH
  81 + export EXE=sherpa-onnx-offline-punctuation.exe
  82 +
  83 + .github/scripts/test-offline-punctuation.sh
  84 +
75 - name: Test spoken language identification (C API) 85 - name: Test spoken language identification (C API)
76 shell: bash 86 shell: bash
77 run: | 87 run: |
@@ -46,14 +46,15 @@ def enable_alsa(): @@ -46,14 +46,15 @@ def enable_alsa():
46 def get_binaries(): 46 def get_binaries():
47 binaries = [ 47 binaries = [
48 "sherpa-onnx", 48 "sherpa-onnx",
49 - "sherpa-onnx-offline-audio-tagging",  
50 "sherpa-onnx-keyword-spotter", 49 "sherpa-onnx-keyword-spotter",
51 "sherpa-onnx-microphone", 50 "sherpa-onnx-microphone",
52 "sherpa-onnx-microphone-offline", 51 "sherpa-onnx-microphone-offline",
53 "sherpa-onnx-microphone-offline-audio-tagging", 52 "sherpa-onnx-microphone-offline-audio-tagging",
54 "sherpa-onnx-microphone-offline-speaker-identification", 53 "sherpa-onnx-microphone-offline-speaker-identification",
55 "sherpa-onnx-offline", 54 "sherpa-onnx-offline",
  55 + "sherpa-onnx-offline-audio-tagging",
56 "sherpa-onnx-offline-language-identification", 56 "sherpa-onnx-offline-language-identification",
  57 + "sherpa-onnx-offline-punctuation",
57 "sherpa-onnx-offline-tts", 58 "sherpa-onnx-offline-tts",
58 "sherpa-onnx-offline-tts-play", 59 "sherpa-onnx-offline-tts-play",
59 "sherpa-onnx-offline-websocket-server", 60 "sherpa-onnx-offline-websocket-server",
@@ -408,8 +408,11 @@ def main(): @@ -408,8 +408,11 @@ def main():
408 vad_config.silero_vad.min_silence_duration = 0.25 408 vad_config.silero_vad.min_silence_duration = 0.25
409 vad_config.silero_vad.min_speech_duration = 0.25 409 vad_config.silero_vad.min_speech_duration = 0.25
410 vad_config.sample_rate = g_sample_rate 410 vad_config.sample_rate = g_sample_rate
  411 + if not vad_config.validate():
  412 + raise ValueError("Errors in vad config")
411 413
412 window_size = vad_config.silero_vad.window_size 414 window_size = vad_config.silero_vad.window_size
  415 +
413 vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100) 416 vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100)
414 417
415 samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms 418 samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms
@@ -121,6 +121,14 @@ list(APPEND sources @@ -121,6 +121,14 @@ list(APPEND sources
121 offline-zipformer-audio-tagging-model.cc 121 offline-zipformer-audio-tagging-model.cc
122 ) 122 )
123 123
  124 +# punctuation
  125 +list(APPEND sources
  126 + offline-ct-transformer-model.cc
  127 + offline-punctuation-impl.cc
  128 + offline-punctuation-model-config.cc
  129 + offline-punctuation.cc
  130 +)
  131 +
124 if(SHERPA_ONNX_ENABLE_TTS) 132 if(SHERPA_ONNX_ENABLE_TTS)
125 list(APPEND sources 133 list(APPEND sources
126 lexicon.cc 134 lexicon.cc
@@ -201,9 +209,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -201,9 +209,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
201 add_executable(sherpa-onnx sherpa-onnx.cc) 209 add_executable(sherpa-onnx sherpa-onnx.cc)
202 add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc) 210 add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
203 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) 211 add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
204 - add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)  
205 - add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)  
206 add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc) 212 add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc)
  213 + add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
  214 + add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
  215 + add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
207 216
208 if(SHERPA_ONNX_ENABLE_TTS) 217 if(SHERPA_ONNX_ENABLE_TTS)
209 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) 218 add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
@@ -213,9 +222,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) @@ -213,9 +222,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
213 sherpa-onnx 222 sherpa-onnx
214 sherpa-onnx-keyword-spotter 223 sherpa-onnx-keyword-spotter
215 sherpa-onnx-offline 224 sherpa-onnx-offline
216 - sherpa-onnx-offline-parallel  
217 - sherpa-onnx-offline-language-identification  
218 sherpa-onnx-offline-audio-tagging 225 sherpa-onnx-offline-audio-tagging
  226 + sherpa-onnx-offline-language-identification
  227 + sherpa-onnx-offline-parallel
  228 + sherpa-onnx-offline-punctuation
219 ) 229 )
220 if(SHERPA_ONNX_ENABLE_TTS) 230 if(SHERPA_ONNX_ENABLE_TTS)
221 list(APPEND main_exes 231 list(APPEND main_exes
@@ -260,11 +270,11 @@ endif() @@ -260,11 +270,11 @@ endif()
260 270
261 if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY) 271 if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY)
262 add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) 272 add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
263 - add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)  
264 add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc) 273 add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc)
  274 + add_executable(sherpa-onnx-alsa-offline-audio-tagging sherpa-onnx-alsa-offline-audio-tagging.cc alsa.cc)
265 add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) 275 add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc)
  276 + add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc)
266 add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc) 277 add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc)
267 - add_executable(sherpa-onnx-alsa-offline-audio-tagging sherpa-onnx-alsa-offline-audio-tagging.cc alsa.cc)  
268 278
269 279
270 if(SHERPA_ONNX_ENABLE_TTS) 280 if(SHERPA_ONNX_ENABLE_TTS)
@@ -74,11 +74,6 @@ static std::vector<std::string> ProcessHeteronyms( @@ -74,11 +74,6 @@ static std::vector<std::string> ProcessHeteronyms(
74 return ans; 74 return ans;
75 } 75 }
76 76
77 -static void ToLowerCase(std::string *in_out) {  
78 - std::transform(in_out->begin(), in_out->end(), in_out->begin(),  
79 - [](unsigned char c) { return std::tolower(c); });  
80 -}  
81 -  
82 // Note: We don't use SymbolTable here since tokens may contain a blank 77 // Note: We don't use SymbolTable here since tokens may contain a blank
83 // in the first column 78 // in the first column
84 static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) { 79 static std::unordered_map<std::string, int32_t> ReadTokens(std::istream &is) {
@@ -118,6 +118,24 @@ @@ -118,6 +118,24 @@
118 } \ 118 } \
119 } while (0) 119 } while (0)
120 120
  121 +// read a vector of strings separated by sep
  122 +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \
  123 + do { \
  124 + auto value = \
  125 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  126 + if (!value) { \
  127 + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
  128 + exit(-1); \
  129 + } \
  130 + SplitStringToVector(value.get(), sep, false, &dst); \
  131 + \
  132 + if (dst.empty()) { \
  133 + SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
  134 + src_key); \
  135 + exit(-1); \
  136 + } \
  137 + } while (0)
  138 +
121 // Read a string 139 // Read a string
122 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ 140 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
123 do { \ 141 do { \
  1 +// sherpa-onnx/csrc/offline-ct-transformer-model-meta_data.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
  6 +
  7 +#include <string>
  8 +#include <unordered_map>
  9 +#include <vector>
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineCtTransformerModelMetaData {
  14 + std::unordered_map<std::string, int32_t> token2id;
  15 + std::unordered_map<std::string, int32_t> punct2id;
  16 + std::vector<std::string> id2punct;
  17 +
  18 + int32_t unk_id;
  19 + int32_t dot_id;
  20 + int32_t comma_id;
  21 + int32_t quest_id;
  22 + int32_t pause_id;
  23 + int32_t underline_id;
  24 + int32_t num_punctuations;
  25 +};
  26 +
  27 +} // namespace sherpa_onnx
  28 +
  29 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
  1 +// sherpa-onnx/csrc/offline-ct-transformer-model.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/onnx-utils.h"
  11 +#include "sherpa-onnx/csrc/session.h"
  12 +#include "sherpa-onnx/csrc/text-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OfflineCtTransformerModel::Impl {
  17 + public:
  18 + explicit Impl(const OfflinePunctuationModelConfig &config)
  19 + : config_(config),
  20 + env_(ORT_LOGGING_LEVEL_ERROR),
  21 + sess_opts_(GetSessionOptions(config)),
  22 + allocator_{} {
  23 + auto buf = ReadFile(config_.ct_transformer);
  24 + Init(buf.data(), buf.size());
  25 + }
  26 +
  27 +#if __ANDROID_API__ >= 9
  28 + Impl(AAssetManager *mgr, const OfflinePunctuationModelConfig &config)
  29 + : config_(config),
  30 + env_(ORT_LOGGING_LEVEL_ERROR),
  31 + sess_opts_(GetSessionOptions(config)),
  32 + allocator_{} {
  33 + auto buf = ReadFile(mgr, config_.ct_transformer);
  34 + Init(buf.data(), buf.size());
  35 + }
  36 +#endif
  37 +
  38 + Ort::Value Forward(Ort::Value text, Ort::Value text_len) {
  39 + std::array<Ort::Value, 2> inputs = {std::move(text), std::move(text_len)};
  40 +
  41 + auto ans =
  42 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  43 + output_names_ptr_.data(), output_names_ptr_.size());
  44 + return std::move(ans[0]);
  45 + }
  46 +
  47 + OrtAllocator *Allocator() const { return allocator_; }
  48 +
  49 + const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
  50 + return meta_data_;
  51 + }
  52 +
  53 + private:
  54 + void Init(void *model_data, size_t model_data_length) {
  55 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  56 + sess_opts_);
  57 +
  58 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  59 +
  60 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  61 +
  62 + // get meta data
  63 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  64 +
  65 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  66 +
  67 + std::vector<std::string> tokens;
  68 + SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(tokens, "tokens", "|");
  69 +
  70 + int32_t vocab_size;
  71 + SHERPA_ONNX_READ_META_DATA(vocab_size, "vocab_size");
  72 + if (tokens.size() != vocab_size) {
  73 + SHERPA_ONNX_LOGE("tokens.size() %d != vocab_size %d",
  74 + static_cast<int32_t>(tokens.size()), vocab_size);
  75 + exit(-1);
  76 + }
  77 +
  78 + SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(meta_data_.id2punct,
  79 + "punctuations", "|");
  80 +
  81 + std::string unk_symbol;
  82 + SHERPA_ONNX_READ_META_DATA_STR(unk_symbol, "unk_symbol");
  83 +
  84 + // output shape is (N, T, num_punctuations)
  85 + meta_data_.num_punctuations =
  86 + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2];
  87 +
  88 + int32_t i = 0;
  89 + for (const auto &t : tokens) {
  90 + meta_data_.token2id[t] = i;
  91 + i += 1;
  92 + }
  93 +
  94 + i = 0;
  95 + for (const auto &p : meta_data_.id2punct) {
  96 + meta_data_.punct2id[p] = i;
  97 + i += 1;
  98 + }
  99 +
  100 + meta_data_.unk_id = meta_data_.token2id.at(unk_symbol);
  101 +
  102 + meta_data_.dot_id = meta_data_.punct2id.at("。");
  103 + meta_data_.comma_id = meta_data_.punct2id.at(",");
  104 + meta_data_.quest_id = meta_data_.punct2id.at("?");
  105 + meta_data_.pause_id = meta_data_.punct2id.at("、");
  106 + meta_data_.underline_id = meta_data_.punct2id.at("_");
  107 +
  108 + if (config_.debug) {
  109 + std::ostringstream os;
  110 + os << "vocab_size: " << meta_data_.token2id.size() << "\n";
  111 + os << "num_punctuations: " << meta_data_.num_punctuations << "\n";
  112 + os << "punctuations: ";
  113 + for (const auto &s : meta_data_.id2punct) {
  114 + os << s << " ";
  115 + }
  116 + os << "\n";
  117 + SHERPA_ONNX_LOGE("\n%s\n", os.str().c_str());
  118 + }
  119 + }
  120 +
  121 + private:
  122 + OfflinePunctuationModelConfig config_;
  123 + Ort::Env env_;
  124 + Ort::SessionOptions sess_opts_;
  125 + Ort::AllocatorWithDefaultOptions allocator_;
  126 +
  127 + std::unique_ptr<Ort::Session> sess_;
  128 +
  129 + std::vector<std::string> input_names_;
  130 + std::vector<const char *> input_names_ptr_;
  131 +
  132 + std::vector<std::string> output_names_;
  133 + std::vector<const char *> output_names_ptr_;
  134 +
  135 + OfflineCtTransformerModelMetaData meta_data_;
  136 +};
  137 +
  138 +OfflineCtTransformerModel::OfflineCtTransformerModel(
  139 + const OfflinePunctuationModelConfig &config)
  140 + : impl_(std::make_unique<Impl>(config)) {}
  141 +
  142 +#if __ANDROID_API__ >= 9
  143 +OfflineCtTransformerModel::OfflineCtTransformerModel(
  144 + AAssetManager *mgr, const OfflinePunctuationModelConfig &config)
  145 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  146 +#endif
  147 +
  148 +OfflineCtTransformerModel::~OfflineCtTransformerModel() = default;
  149 +
  150 +Ort::Value OfflineCtTransformerModel::Forward(Ort::Value text,
  151 + Ort::Value text_len) const {
  152 + return impl_->Forward(std::move(text), std::move(text_len));
  153 +}
  154 +
  155 +OrtAllocator *OfflineCtTransformerModel::Allocator() const {
  156 + return impl_->Allocator();
  157 +}
  158 +
  159 +const OfflineCtTransformerModelMetaData &
  160 +OfflineCtTransformerModel::GetModelMetadata() const {
  161 + return impl_->GetModelMetadata();
  162 +}
  163 +
  164 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-ct-transformer-model.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
  6 +#include <memory>
  7 +#include <utility>
  8 +
  9 +#if __ANDROID_API__ >= 9
  10 +#include "android/asset_manager.h"
  11 +#include "android/asset_manager_jni.h"
  12 +#endif
  13 +
  14 +#include "onnxruntime_cxx_api.h" // NOLINT
  15 +#include "sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h"
  16 +#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
  17 +
  18 +namespace sherpa_onnx {
  19 +
  20 +/** This class implements
  21 + * https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx/punc_bin.py#L17
  22 + * from FunASR
  23 + */
  24 +class OfflineCtTransformerModel {
  25 + public:
  26 + explicit OfflineCtTransformerModel(
  27 + const OfflinePunctuationModelConfig &config);
  28 +
  29 +#if __ANDROID_API__ >= 9
  30 + OfflineCtTransformerModel(AAssetManager *mgr,
  31 + const OfflinePunctuationModelConfig &config);
  32 +#endif
  33 +
  34 + ~OfflineCtTransformerModel();
  35 +
  36 + /** Run the forward method of the model.
  37 + *
  38 + * @param text A tensor of shape (N, T) of dtype int32.
  39 + * @param text A tensor of shape (N) of dtype int32.
  40 + *
  41 + * @return Return a tensor
  42 + * - punctuation_ids: A 2-D tensor of shape (N, T).
  43 + */
  44 + Ort::Value Forward(Ort::Value text, Ort::Value text_len) const;
  45 +
  46 + /** Return an allocator for allocating memory
  47 + */
  48 + OrtAllocator *Allocator() const;
  49 +
  50 + const OfflineCtTransformerModelMetaData &GetModelMetadata() const;
  51 +
  52 + private:
  53 + class Impl;
  54 + std::unique_ptr<Impl> impl_;
  55 +};
  56 +
  57 +} // namespace sherpa_onnx
  58 +
  59 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
  1 +// sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <utility>
  10 +#include <vector>
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/math.h"
  14 +#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
  15 +#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
  16 +#include "sherpa-onnx/csrc/offline-punctuation.h"
  17 +#include "sherpa-onnx/csrc/text-utils.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
  22 + public:
  23 + explicit OfflinePunctuationCtTransformerImpl(
  24 + const OfflinePunctuationConfig &config)
  25 + : config_(config), model_(config.model) {}
  26 +
  27 + std::string AddPunctuation(const std::string &text) const override {
  28 + if (text.empty()) {
  29 + return {};
  30 + }
  31 +
  32 + std::vector<std::string> tokens = SplitUtf8(text);
  33 + std::vector<int32_t> token_ids;
  34 + token_ids.reserve(tokens.size());
  35 +
  36 + const auto &meta_data = model_.GetModelMetadata();
  37 +
  38 + for (const auto &t : tokens) {
  39 + std::string token = ToLowerCase(t);
  40 + if (meta_data.token2id.count(token)) {
  41 + token_ids.push_back(meta_data.token2id.at(token));
  42 + } else {
  43 + token_ids.push_back(meta_data.unk_id);
  44 + }
  45 + }
  46 +
  47 + auto memory_info =
  48 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  49 +
  50 + int32_t segment_size = 20;
  51 + int32_t max_len = 200;
  52 + int32_t num_segments = (token_ids.size() + segment_size - 1) / segment_size;
  53 +
  54 + std::vector<int32_t> punctuations;
  55 + int32_t last = -1;
  56 + for (int32_t i = 0; i != num_segments; ++i) {
  57 + int32_t this_start = i * segment_size; // inclusive
  58 + int32_t this_end = this_start + segment_size; // exclusive
  59 + if (this_end > token_ids.size()) {
  60 + this_end = token_ids.size();
  61 + }
  62 +
  63 + if (last != -1) {
  64 + this_start = last;
  65 + }
  66 + // token_ids[this_start:this_end] is sent to the model
  67 +
  68 + std::array<int64_t, 2> x_shape = {1, this_end - this_start};
  69 + Ort::Value x =
  70 + Ort::Value::CreateTensor(memory_info, token_ids.data() + this_start,
  71 + x_shape[1], x_shape.data(), x_shape.size());
  72 +
  73 + int64_t len_shape = 1;
  74 + int32_t len = x_shape[1];
  75 + Ort::Value x_len =
  76 + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
  77 +
  78 + Ort::Value out = model_.Forward(std::move(x), std::move(x_len));
  79 +
  80 + // [N, T, num_punctuations]
  81 + std::vector<int64_t> out_shape =
  82 + out.GetTensorTypeAndShapeInfo().GetShape();
  83 +
  84 + assert(out_shape[0] == 1);
  85 + assert(out_shape[1] == len);
  86 + assert(out_shape[2] == meta_data.num_punctuations);
  87 +
  88 + std::vector<int32_t> this_punctuations;
  89 + this_punctuations.reserve(len);
  90 +
  91 + const float *p = out.GetTensorData<float>();
  92 + for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) {
  93 + auto index = static_cast<int32_t>(std::distance(
  94 + p, std::max_element(p, p + meta_data.num_punctuations)));
  95 + this_punctuations.push_back(index);
  96 + } // for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations)
  97 +
  98 + int32_t dot_index = -1;
  99 + int32_t comma_index = -1;
  100 +
  101 + for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) {
  102 + int32_t punct_id = this_punctuations[m];
  103 +
  104 + if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
  105 + dot_index = m;
  106 + break;
  107 + }
  108 +
  109 + if (comma_index == -1 && punct_id == meta_data.comma_id) {
  110 + comma_index = m;
  111 + }
  112 + } // for (int32_t k = this_punctuations.size() - 1; k >= 1; --k)
  113 +
  114 + if (dot_index == -1 && len >= max_len && comma_index != -1) {
  115 + dot_index = comma_index;
  116 + this_punctuations[dot_index] = meta_data.dot_id;
  117 + }
  118 +
  119 + if (dot_index == -1) {
  120 + if (last == -1) {
  121 + last = this_start;
  122 + }
  123 +
  124 + if (i == num_segments - 1) {
  125 + dot_index = token_ids.size() - 1;
  126 + }
  127 + } else {
  128 + last = this_start + dot_index + 1;
  129 +
  130 + punctuations.insert(punctuations.end(), this_punctuations.begin(),
  131 + this_punctuations.begin() + (dot_index + 1));
  132 + }
  133 + } // for (int32_t i = 0; i != num_segments; ++i)
  134 +
  135 + if (punctuations.size() != token_ids.size() &&
  136 + punctuations.size() + 1 == token_ids.size()) {
  137 + punctuations.push_back(meta_data.dot_id);
  138 + }
  139 +
  140 + if (punctuations.size() != token_ids.size()) {
  141 + SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
  142 + text.c_str(), static_cast<int32_t>(punctuations.size()),
  143 + static_cast<int32_t>(token_ids.size()));
  144 + return text;
  145 + }
  146 +
  147 + std::string ans;
  148 +
  149 + for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
  150 + const std::string &w = tokens[i];
  151 + if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
  152 + ans.push_back(' ');
  153 + }
  154 + ans.append(w);
  155 + if (punctuations[i] != meta_data.underline_id) {
  156 + ans.append(meta_data.id2punct[punctuations[i]]);
  157 + }
  158 + }
  159 +
  160 + return ans;
  161 + }
  162 +
  163 + private:
  164 + OfflinePunctuationConfig config_;
  165 + OfflineCtTransformerModel model_;
  166 +};
  167 +
  168 +} // namespace sherpa_onnx
  169 +
  170 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-punctuation-impl.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create(
  13 + const OfflinePunctuationConfig &config) {
  14 + if (!config.model.ct_transformer.empty()) {
  15 + return std::make_unique<OfflinePunctuationCtTransformerImpl>(config);
  16 + }
  17 +
  18 + SHERPA_ONNX_LOGE("Please specify a punctuation model! Return a null pointer");
  19 + return nullptr;
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-punctuation-impl.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-punctuation.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OfflinePunctuationImpl {
  16 + public:
  17 + virtual ~OfflinePunctuationImpl() = default;
  18 +
  19 + static std::unique_ptr<OfflinePunctuationImpl> Create(
  20 + const OfflinePunctuationConfig &config);
  21 +
  22 + virtual std::string AddPunctuation(const std::string &text) const = 0;
  23 +};
  24 +
  25 +} // namespace sherpa_onnx
  26 +
  27 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
  1 +// sherpa-onnx/csrc/offline-punctuation-model-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflinePunctuationModelConfig::Register(ParseOptions *po) {
  13 + po->Register("ct-transformer", &ct_transformer,
  14 + "Path to the controllable time-delay (CT) transformer model");
  15 +
  16 + po->Register("num-threads", &num_threads,
  17 + "Number of threads to run the neural network");
  18 +
  19 + po->Register("debug", &debug,
  20 + "true to print model information while loading it.");
  21 +
  22 + po->Register("provider", &provider,
  23 + "Specify a provider to use: cpu, cuda, coreml");
  24 +}
  25 +
  26 +bool OfflinePunctuationModelConfig::Validate() const {
  27 + if (ct_transformer.empty()) {
  28 + SHERPA_ONNX_LOGE("Please provide --ct-transformer");
  29 + return false;
  30 + }
  31 +
  32 + if (!FileExists(ct_transformer)) {
  33 + SHERPA_ONNX_LOGE("--ct-transformer %s does not exist",
  34 + ct_transformer.c_str());
  35 + return false;
  36 + }
  37 +
  38 + return true;
  39 +}
  40 +
  41 +std::string OfflinePunctuationModelConfig::ToString() const {
  42 + std::ostringstream os;
  43 +
  44 + os << "OfflinePunctuationModelConfig(";
  45 + os << "ct_transformer=\"" << ct_transformer << "\", ";
  46 + os << "num_threads=" << num_threads << ", ";
  47 + os << "debug=" << (debug ? "True" : "False") << ", ";
  48 + os << "provider=\"" << provider << "\")";
  49 +
  50 + return os.str();
  51 +}
  52 +
  53 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-punctuation-model-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflinePunctuationModelConfig {
  14 + std::string ct_transformer;
  15 +
  16 + int32_t num_threads = 1;
  17 + bool debug = false;
  18 + std::string provider = "cpu";
  19 +
  20 + OfflinePunctuationModelConfig() = default;
  21 +
  22 + OfflinePunctuationModelConfig(const std::string &ct_transformer,
  23 + int32_t num_threads, bool debug,
  24 + const std::string &provider)
  25 + : ct_transformer(ct_transformer),
  26 + num_threads(num_threads),
  27 + debug(debug),
  28 + provider(provider) {}
  29 +
  30 + void Register(ParseOptions *po);
  31 + bool Validate() const;
  32 +
  33 + std::string ToString() const;
  34 +};
  35 +
  36 +} // namespace sherpa_onnx
  37 +
  38 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-punctuation.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-punctuation.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflinePunctuationConfig::Register(ParseOptions *po) {
  13 + model.Register(po);
  14 +}
  15 +
  16 +bool OfflinePunctuationConfig::Validate() const {
  17 + if (!model.Validate()) {
  18 + return false;
  19 + }
  20 +
  21 + return true;
  22 +}
  23 +
  24 +std::string OfflinePunctuationConfig::ToString() const {
  25 + std::ostringstream os;
  26 +
  27 + os << "OfflinePunctuationConfig(";
  28 + os << "model=" << model.ToString() << ")";
  29 +
  30 + return os.str();
  31 +}
  32 +
  33 +OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config)
  34 + : impl_(OfflinePunctuationImpl::Create(config)) {}
  35 +
  36 +OfflinePunctuation::~OfflinePunctuation() = default;
  37 +
  38 +std::string OfflinePunctuation::AddPunctuation(const std::string &text) const {
  39 + return impl_->AddPunctuation(text);
  40 +}
  41 +
  42 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-punctuation.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
  6 +
  7 +#include <memory>
  8 +#include <string>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
  12 +#include "sherpa-onnx/csrc/parse-options.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +struct OfflinePunctuationConfig {
  17 + OfflinePunctuationModelConfig model;
  18 +
  19 + OfflinePunctuationConfig() = default;
  20 +
  21 + explicit OfflinePunctuationConfig(const OfflinePunctuationModelConfig &model)
  22 + : model(model) {}
  23 +
  24 + void Register(ParseOptions *po);
  25 + bool Validate() const;
  26 +
  27 + std::string ToString() const;
  28 +};
  29 +
  30 +class OfflinePunctuationImpl;
  31 +
  32 +class OfflinePunctuation {
  33 + public:
  34 + explicit OfflinePunctuation(const OfflinePunctuationConfig &config);
  35 +
  36 + ~OfflinePunctuation();
  37 +
  38 + // Add punctuation to the input text and return it.
  39 + std::string AddPunctuation(const std::string &text) const;
  40 +
  41 + private:
  42 + std::unique_ptr<OfflinePunctuationImpl> impl_;
  43 +};
  44 +
  45 +} // namespace sherpa_onnx
  46 +
  47 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
@@ -29,7 +29,6 @@ void OnlineWebsocketDecoderConfig::Validate() const { @@ -29,7 +29,6 @@ void OnlineWebsocketDecoderConfig::Validate() const {
29 SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0); 29 SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0);
30 SHERPA_ONNX_CHECK_GT(max_batch_size, 0); 30 SHERPA_ONNX_CHECK_GT(max_batch_size, 0);
31 SHERPA_ONNX_CHECK_GT(end_tail_padding, 0); 31 SHERPA_ONNX_CHECK_GT(end_tail_padding, 0);
32 -  
33 } 32 }
34 33
35 void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) { 34 void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) {
@@ -87,7 +86,8 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) { @@ -87,7 +86,8 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr<Connection> c) {
87 c->samples.pop_front(); 86 c->samples.pop_front();
88 } 87 }
89 88
90 - std::vector<float> tail_padding(static_cast<int64_t>(config_.end_tail_padding * sample_rate)); 89 + std::vector<float> tail_padding(
  90 + static_cast<int64_t>(config_.end_tail_padding * sample_rate));
91 91
92 c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size()); 92 c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size());
93 93
@@ -160,4 +160,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) { @@ -160,4 +160,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
160 return GetSessionOptionsImpl(config.num_threads, config.provider); 160 return GetSessionOptionsImpl(config.num_threads, config.provider);
161 } 161 }
162 162
  163 +Ort::SessionOptions GetSessionOptions(
  164 + const OfflinePunctuationModelConfig &config) {
  165 + return GetSessionOptionsImpl(config.num_threads, config.provider);
  166 +}
  167 +
163 } // namespace sherpa_onnx 168 } // namespace sherpa_onnx
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 #include "sherpa-onnx/csrc/audio-tagging-model-config.h" 9 #include "sherpa-onnx/csrc/audio-tagging-model-config.h"
10 #include "sherpa-onnx/csrc/offline-lm-config.h" 10 #include "sherpa-onnx/csrc/offline-lm-config.h"
11 #include "sherpa-onnx/csrc/offline-model-config.h" 11 #include "sherpa-onnx/csrc/offline-model-config.h"
  12 +#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
12 #include "sherpa-onnx/csrc/online-lm-config.h" 13 #include "sherpa-onnx/csrc/online-lm-config.h"
13 #include "sherpa-onnx/csrc/online-model-config.h" 14 #include "sherpa-onnx/csrc/online-model-config.h"
14 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" 15 #include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
@@ -43,6 +44,9 @@ Ort::SessionOptions GetSessionOptions( @@ -43,6 +44,9 @@ Ort::SessionOptions GetSessionOptions(
43 44
44 Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); 45 Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
45 46
  47 +Ort::SessionOptions GetSessionOptions(
  48 + const OfflinePunctuationModelConfig &config);
  49 +
46 } // namespace sherpa_onnx 50 } // namespace sherpa_onnx
47 51
48 #endif // SHERPA_ONNX_CSRC_SESSION_H_ 52 #endif // SHERPA_ONNX_CSRC_SESSION_H_
  1 +// sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc
  2 +//
  3 +// Copyright (c) 2022-2024 Xiaomi Corporation
  4 +#include <stdio.h>
  5 +
  6 +#include <chrono> // NOLINT
  7 +
  8 +#include "sherpa-onnx/csrc/offline-punctuation.h"
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +int main(int32_t argc, char *argv[]) {
  12 + const char *kUsageMessage = R"usage(
  13 +Add punctuations to the input text.
  14 +
  15 +The input text can contain both Chinese and English words.
  16 +
  17 +Usage:
  18 +
  19 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  20 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  21 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
  22 +
  23 +./bin/sherpa-onnx-offline-punctuation \
  24 + --ct-transformer=./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx
  25 + "你好吗how are you Fantasitic 谢谢我很好你怎么样呢"
  26 +
  27 +The output text should look like below:
  28 +)usage";
  29 +
  30 + sherpa_onnx::ParseOptions po(kUsageMessage);
  31 + sherpa_onnx::OfflinePunctuationConfig config;
  32 + config.Register(&po);
  33 + po.Read(argc, argv);
  34 + if (po.NumArgs() != 1) {
  35 + fprintf(stderr,
  36 + "Error: Please provide only 1 position argument containing the "
  37 + "input text.\n\n");
  38 + po.PrintUsage();
  39 + exit(EXIT_FAILURE);
  40 + }
  41 +
  42 + fprintf(stderr, "%s\n", config.ToString().c_str());
  43 +
  44 + if (!config.Validate()) {
  45 + fprintf(stderr, "Errors in config!\n");
  46 + return -1;
  47 + }
  48 +
  49 + fprintf(stderr, "Creating OfflinePunctuation ...\n");
  50 + sherpa_onnx::OfflinePunctuation punct(config);
  51 + fprintf(stderr, "Started\n");
  52 + const auto begin = std::chrono::steady_clock::now();
  53 +
  54 + std::string text = po.GetArg(1);
  55 + std::string text_with_punct = punct.AddPunctuation(text);
  56 + fprintf(stderr, "Done\n");
  57 + const auto end = std::chrono::steady_clock::now();
  58 +
  59 + float elapsed_seconds =
  60 + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
  61 + .count() /
  62 + 1000.;
  63 +
  64 + fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
  65 + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  66 + fprintf(stderr, "Input text: %s\n", text.c_str());
  67 + fprintf(stderr, "Output text: %s\n", text_with_punct.c_str());
  68 +}
@@ -111,8 +111,8 @@ for a list of pre-trained models to download. @@ -111,8 +111,8 @@ for a list of pre-trained models to download.
111 fprintf(stderr, "Creating recognizer ...\n"); 111 fprintf(stderr, "Creating recognizer ...\n");
112 sherpa_onnx::OfflineRecognizer recognizer(config); 112 sherpa_onnx::OfflineRecognizer recognizer(config);
113 113
114 - const auto begin = std::chrono::steady_clock::now();  
115 fprintf(stderr, "Started\n"); 114 fprintf(stderr, "Started\n");
  115 + const auto begin = std::chrono::steady_clock::now();
116 116
117 std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; 117 std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
118 std::vector<sherpa_onnx::OfflineStream *> ss_pointers; 118 std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
@@ -385,4 +385,16 @@ std::vector<std::string> SplitUtf8(const std::string &text) { @@ -385,4 +385,16 @@ std::vector<std::string> SplitUtf8(const std::string &text) {
385 return MergeCharactersIntoWords(ans); 385 return MergeCharactersIntoWords(ans);
386 } 386 }
387 387
  388 +std::string ToLowerCase(const std::string &s) {
  389 + std::string ans(s.size(), 0);
  390 + std::transform(s.begin(), s.end(), ans.begin(),
  391 + [](unsigned char c) { return std::tolower(c); });
  392 + return ans;
  393 +}
  394 +
  395 +void ToLowerCase(std::string *in_out) {
  396 + std::transform(in_out->begin(), in_out->end(), in_out->begin(),
  397 + [](unsigned char c) { return std::tolower(c); });
  398 +}
  399 +
388 } // namespace sherpa_onnx 400 } // namespace sherpa_onnx
@@ -121,6 +121,9 @@ bool ConvertStringToReal(const std::string &str, T *out); @@ -121,6 +121,9 @@ bool ConvertStringToReal(const std::string &str, T *out);
121 121
122 std::vector<std::string> SplitUtf8(const std::string &text); 122 std::vector<std::string> SplitUtf8(const std::string &text);
123 123
  124 +std::string ToLowerCase(const std::string &s);
  125 +void ToLowerCase(std::string *in_out);
  126 +
124 } // namespace sherpa_onnx 127 } // namespace sherpa_onnx
125 128
126 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ 129 #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_