Fangjun Kuang
Committed by GitHub

Support streaming conformer CTC models from wenet (#427)

  1 +#!/usr/bin/env bash
  2 +
  3 +set -e
  4 +
#######################################
# Print a timestamped log line tagged with the caller's location.
# This function is adapted from espnet.
# Arguments: $@ - the message; words are joined with single spaces.
# Outputs:   "<date> (<file>:<line>:<function>) <message>" on stdout.
#######################################
log() {
  local fname=${BASH_SOURCE[1]##*/}
  local stamp
  stamp=$(date '+%Y-%m-%d %H:%M:%S')
  # Use printf '%s' so the message is printed verbatim. `echo -e` would
  # expand backslash sequences that happen to appear in the message
  # (for example "\t" or "\n" inside a Windows-style path).
  printf '%s (%s:%s:%s) %s\n' \
    "${stamp}" "${fname}" "${BASH_LINENO[0]}" "${FUNCNAME[1]}" "$*"
}
  10 +
# Smoke test for streaming conformer CTC models exported from WeNet.
#
# Required environment:
#   EXE  - name of the sherpa-onnx executable; it must be found via PATH.
#
# For every model repository listed below the script clones the repository,
# pulls the *.onnx files with git-lfs, decodes the bundled test waves with the
# float32 model and with the int8 model, and finally removes the clone.

# Fail early with a clear message when the caller did not export EXE.
: "${EXE:?EXE must be set to the sherpa-onnx executable to test}"

echo "EXE is $EXE"
echo "PATH: $PATH"

# Print where EXE resolves to. The script runs under `set -e`, so it stops
# here if EXE cannot be found on PATH.
command -v "$EXE"

log "------------------------------------------------------------"
log "Run streaming Conformer CTC from WeNet"
log "------------------------------------------------------------"
wenet_models=(
  sherpa-onnx-zh-wenet-aishell
  sherpa-onnx-zh-wenet-aishell2
  sherpa-onnx-zh-wenet-wenetspeech
  sherpa-onnx-zh-wenet-multi-cn
  sherpa-onnx-en-wenet-librispeech
  sherpa-onnx-en-wenet-gigaspeech
)
for name in "${wenet_models[@]}"; do
  repo_url=https://huggingface.co/csukuangfj/$name
  log "Start testing ${repo_url}"
  repo=$(basename "$repo_url")
  log "Download pretrained model and test-data from $repo_url"
  # Skip the LFS smudge step on clone; only the *.onnx files are pulled below.
  GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
  pushd "$repo"
  git lfs pull --include "*.onnx"
  ls -lh -- *.onnx
  popd

  log "test float32 models"
  time "$EXE" \
    --tokens="$repo/tokens.txt" \
    --wenet-ctc-model="$repo/model-streaming.onnx" \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"

  log "test int8 models"
  time "$EXE" \
    --tokens="$repo/tokens.txt" \
    --wenet-ctc-model="$repo/model-streaming.int8.onnx" \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/8k.wav"

  # ${repo:?} aborts instead of running `rm -rf` on an empty path.
  rm -rf -- "${repo:?}"
done
@@ -12,6 +12,7 @@ on: @@ -12,6 +12,7 @@ on:
12 - '.github/scripts/test-online-paraformer.sh' 12 - '.github/scripts/test-online-paraformer.sh'
13 - '.github/scripts/test-offline-transducer.sh' 13 - '.github/scripts/test-offline-transducer.sh'
14 - '.github/scripts/test-offline-ctc.sh' 14 - '.github/scripts/test-offline-ctc.sh'
  15 + - '.github/scripts/test-online-ctc.sh'
15 - '.github/scripts/test-offline-tts.sh' 16 - '.github/scripts/test-offline-tts.sh'
16 - 'CMakeLists.txt' 17 - 'CMakeLists.txt'
17 - 'cmake/**' 18 - 'cmake/**'
@@ -27,6 +28,8 @@ on: @@ -27,6 +28,8 @@ on:
27 - '.github/scripts/test-online-paraformer.sh' 28 - '.github/scripts/test-online-paraformer.sh'
28 - '.github/scripts/test-offline-transducer.sh' 29 - '.github/scripts/test-offline-transducer.sh'
29 - '.github/scripts/test-offline-ctc.sh' 30 - '.github/scripts/test-offline-ctc.sh'
  31 + - '.github/scripts/test-online-ctc.sh'
  32 + - '.github/scripts/test-online-ctc.sh'
30 - '.github/scripts/test-offline-tts.sh' 33 - '.github/scripts/test-offline-tts.sh'
31 - 'CMakeLists.txt' 34 - 'CMakeLists.txt'
32 - 'cmake/**' 35 - 'cmake/**'
@@ -88,6 +91,14 @@ jobs: @@ -88,6 +91,14 @@ jobs:
88 file build/bin/sherpa-onnx 91 file build/bin/sherpa-onnx
89 readelf -d build/bin/sherpa-onnx 92 readelf -d build/bin/sherpa-onnx
90 93
  94 + - name: Test online CTC
  95 + shell: bash
  96 + run: |
  97 + export PATH=$PWD/build/bin:$PATH
  98 + export EXE=sherpa-onnx
  99 +
  100 + .github/scripts/test-online-ctc.sh
  101 +
91 - name: Test offline TTS 102 - name: Test offline TTS
92 shell: bash 103 shell: bash
93 run: | 104 run: |
@@ -12,6 +12,7 @@ on: @@ -12,6 +12,7 @@ on:
12 - '.github/scripts/test-online-paraformer.sh' 12 - '.github/scripts/test-online-paraformer.sh'
13 - '.github/scripts/test-offline-transducer.sh' 13 - '.github/scripts/test-offline-transducer.sh'
14 - '.github/scripts/test-offline-ctc.sh' 14 - '.github/scripts/test-offline-ctc.sh'
  15 + - '.github/scripts/test-online-ctc.sh'
15 - '.github/scripts/test-offline-tts.sh' 16 - '.github/scripts/test-offline-tts.sh'
16 - 'CMakeLists.txt' 17 - 'CMakeLists.txt'
17 - 'cmake/**' 18 - 'cmake/**'
@@ -27,6 +28,7 @@ on: @@ -27,6 +28,7 @@ on:
27 - '.github/scripts/test-online-paraformer.sh' 28 - '.github/scripts/test-online-paraformer.sh'
28 - '.github/scripts/test-offline-transducer.sh' 29 - '.github/scripts/test-offline-transducer.sh'
29 - '.github/scripts/test-offline-ctc.sh' 30 - '.github/scripts/test-offline-ctc.sh'
  31 + - '.github/scripts/test-online-ctc.sh'
30 - '.github/scripts/test-offline-tts.sh' 32 - '.github/scripts/test-offline-tts.sh'
31 - 'CMakeLists.txt' 33 - 'CMakeLists.txt'
32 - 'cmake/**' 34 - 'cmake/**'
@@ -89,6 +91,14 @@ jobs: @@ -89,6 +91,14 @@ jobs:
89 file build/bin/sherpa-onnx 91 file build/bin/sherpa-onnx
90 readelf -d build/bin/sherpa-onnx 92 readelf -d build/bin/sherpa-onnx
91 93
  94 + - name: Test online CTC
  95 + shell: bash
  96 + run: |
  97 + export PATH=$PWD/build/bin:$PATH
  98 + export EXE=sherpa-onnx
  99 +
  100 + .github/scripts/test-online-ctc.sh
  101 +
92 - name: Test offline CTC 102 - name: Test offline CTC
93 shell: bash 103 shell: bash
94 run: | 104 run: |
@@ -13,6 +13,7 @@ on: @@ -13,6 +13,7 @@ on:
13 - '.github/scripts/test-offline-transducer.sh' 13 - '.github/scripts/test-offline-transducer.sh'
14 - '.github/scripts/test-offline-ctc.sh' 14 - '.github/scripts/test-offline-ctc.sh'
15 - '.github/scripts/test-offline-tts.sh' 15 - '.github/scripts/test-offline-tts.sh'
  16 + - '.github/scripts/test-online-ctc.sh'
16 - 'CMakeLists.txt' 17 - 'CMakeLists.txt'
17 - 'cmake/**' 18 - 'cmake/**'
18 - 'sherpa-onnx/csrc/*' 19 - 'sherpa-onnx/csrc/*'
@@ -26,6 +27,7 @@ on: @@ -26,6 +27,7 @@ on:
26 - '.github/scripts/test-offline-transducer.sh' 27 - '.github/scripts/test-offline-transducer.sh'
27 - '.github/scripts/test-offline-ctc.sh' 28 - '.github/scripts/test-offline-ctc.sh'
28 - '.github/scripts/test-offline-tts.sh' 29 - '.github/scripts/test-offline-tts.sh'
  30 + - '.github/scripts/test-online-ctc.sh'
29 - 'CMakeLists.txt' 31 - 'CMakeLists.txt'
30 - 'cmake/**' 32 - 'cmake/**'
31 - 'sherpa-onnx/csrc/*' 33 - 'sherpa-onnx/csrc/*'
@@ -96,6 +98,15 @@ jobs: @@ -96,6 +98,15 @@ jobs:
96 otool -L build/bin/sherpa-onnx 98 otool -L build/bin/sherpa-onnx
97 otool -l build/bin/sherpa-onnx 99 otool -l build/bin/sherpa-onnx
98 100
  101 + - name: Test online CTC
  102 + shell: bash
  103 + run: |
  104 + export PATH=$PWD/build/bin:$PATH
  105 + export EXE=sherpa-onnx
  106 +
  107 + .github/scripts/test-online-ctc.sh
  108 +
  109 +
99 - name: Test offline TTS 110 - name: Test offline TTS
100 shell: bash 111 shell: bash
101 run: | 112 run: |
@@ -12,6 +12,7 @@ on: @@ -12,6 +12,7 @@ on:
12 - '.github/scripts/test-online-paraformer.sh' 12 - '.github/scripts/test-online-paraformer.sh'
13 - '.github/scripts/test-offline-transducer.sh' 13 - '.github/scripts/test-offline-transducer.sh'
14 - '.github/scripts/test-offline-ctc.sh' 14 - '.github/scripts/test-offline-ctc.sh'
  15 + - '.github/scripts/test-online-ctc.sh'
15 - '.github/scripts/test-offline-tts.sh' 16 - '.github/scripts/test-offline-tts.sh'
16 - 'CMakeLists.txt' 17 - 'CMakeLists.txt'
17 - 'cmake/**' 18 - 'cmake/**'
@@ -25,6 +26,7 @@ on: @@ -25,6 +26,7 @@ on:
25 - '.github/scripts/test-online-paraformer.sh' 26 - '.github/scripts/test-online-paraformer.sh'
26 - '.github/scripts/test-offline-transducer.sh' 27 - '.github/scripts/test-offline-transducer.sh'
27 - '.github/scripts/test-offline-ctc.sh' 28 - '.github/scripts/test-offline-ctc.sh'
  29 + - '.github/scripts/test-online-ctc.sh'
28 - '.github/scripts/test-offline-tts.sh' 30 - '.github/scripts/test-offline-tts.sh'
29 - 'CMakeLists.txt' 31 - 'CMakeLists.txt'
30 - 'cmake/**' 32 - 'cmake/**'
@@ -66,6 +68,14 @@ jobs: @@ -66,6 +68,14 @@ jobs:
66 68
67 ls -lh ./bin/Release/sherpa-onnx.exe 69 ls -lh ./bin/Release/sherpa-onnx.exe
68 70
  71 + - name: Test online CTC
  72 + shell: bash
  73 + run: |
  74 + export PATH=$PWD/build/bin/Release:$PATH
  75 + export EXE=sherpa-onnx.exe
  76 +
  77 + .github/scripts/test-online-ctc.sh
  78 +
69 - name: Test offline TTS 79 - name: Test offline TTS
70 shell: bash 80 shell: bash
71 run: | 81 run: |
@@ -12,6 +12,7 @@ on: @@ -12,6 +12,7 @@ on:
12 - '.github/scripts/test-online-paraformer.sh' 12 - '.github/scripts/test-online-paraformer.sh'
13 - '.github/scripts/test-offline-transducer.sh' 13 - '.github/scripts/test-offline-transducer.sh'
14 - '.github/scripts/test-offline-ctc.sh' 14 - '.github/scripts/test-offline-ctc.sh'
  15 + - '.github/scripts/test-online-ctc.sh'
15 - '.github/scripts/test-offline-tts.sh' 16 - '.github/scripts/test-offline-tts.sh'
16 - 'CMakeLists.txt' 17 - 'CMakeLists.txt'
17 - 'cmake/**' 18 - 'cmake/**'
@@ -25,6 +26,7 @@ on: @@ -25,6 +26,7 @@ on:
25 - '.github/scripts/test-online-paraformer.sh' 26 - '.github/scripts/test-online-paraformer.sh'
26 - '.github/scripts/test-offline-transducer.sh' 27 - '.github/scripts/test-offline-transducer.sh'
27 - '.github/scripts/test-offline-ctc.sh' 28 - '.github/scripts/test-offline-ctc.sh'
  29 + - '.github/scripts/test-online-ctc.sh'
28 - '.github/scripts/test-offline-tts.sh' 30 - '.github/scripts/test-offline-tts.sh'
29 - 'CMakeLists.txt' 31 - 'CMakeLists.txt'
30 - 'cmake/**' 32 - 'cmake/**'
@@ -67,6 +69,14 @@ jobs: @@ -67,6 +69,14 @@ jobs:
67 69
68 ls -lh ./bin/Release/sherpa-onnx.exe 70 ls -lh ./bin/Release/sherpa-onnx.exe
69 71
  72 + - name: Test online CTC
  73 + shell: bash
  74 + run: |
  75 + export PATH=$PWD/build/bin/Release:$PATH
  76 + export EXE=sherpa-onnx.exe
  77 +
  78 + .github/scripts/test-online-ctc.sh
  79 +
70 - name: Test offline TTS 80 - name: Test offline TTS
71 shell: bash 81 shell: bash
72 run: | 82 run: |
@@ -13,6 +13,7 @@ on: @@ -13,6 +13,7 @@ on:
13 - '.github/scripts/test-offline-transducer.sh' 13 - '.github/scripts/test-offline-transducer.sh'
14 - '.github/scripts/test-offline-ctc.sh' 14 - '.github/scripts/test-offline-ctc.sh'
15 - '.github/scripts/test-offline-tts.sh' 15 - '.github/scripts/test-offline-tts.sh'
  16 + - '.github/scripts/test-online-ctc.sh'
16 - 'CMakeLists.txt' 17 - 'CMakeLists.txt'
17 - 'cmake/**' 18 - 'cmake/**'
18 - 'sherpa-onnx/csrc/*' 19 - 'sherpa-onnx/csrc/*'
@@ -26,6 +27,7 @@ on: @@ -26,6 +27,7 @@ on:
26 - '.github/scripts/test-offline-transducer.sh' 27 - '.github/scripts/test-offline-transducer.sh'
27 - '.github/scripts/test-offline-ctc.sh' 28 - '.github/scripts/test-offline-ctc.sh'
28 - '.github/scripts/test-offline-tts.sh' 29 - '.github/scripts/test-offline-tts.sh'
  30 + - '.github/scripts/test-online-ctc.sh'
29 - 'CMakeLists.txt' 31 - 'CMakeLists.txt'
30 - 'cmake/**' 32 - 'cmake/**'
31 - 'sherpa-onnx/csrc/*' 33 - 'sherpa-onnx/csrc/*'
@@ -67,6 +69,14 @@ jobs: @@ -67,6 +69,14 @@ jobs:
67 69
68 ls -lh ./bin/Release/sherpa-onnx.exe 70 ls -lh ./bin/Release/sherpa-onnx.exe
69 71
  72 + - name: Test online CTC
  73 + shell: bash
  74 + run: |
  75 + export PATH=$PWD/build/bin/Release:$PATH
  76 + export EXE=sherpa-onnx.exe
  77 +
  78 + .github/scripts/test-online-ctc.sh
  79 +
70 - name: Test offline TTS 80 - name: Test offline TTS
71 shell: bash 81 shell: bash
72 run: | 82 run: |
@@ -164,6 +164,7 @@ def main(): @@ -164,6 +164,7 @@ def main():
164 dynamic_axes={ 164 dynamic_axes={
165 "x": {0: "N", 1: "T"}, 165 "x": {0: "N", 1: "T"},
166 "attn_cache": {2: "T"}, 166 "attn_cache": {2: "T"},
  167 + "attn_mask": {2: "T"},
167 "log_probs": {0: "N"}, 168 "log_probs": {0: "N"},
168 "new_attn_cache": {2: "T"}, 169 "new_attn_cache": {2: "T"},
169 }, 170 },
@@ -49,6 +49,8 @@ set(sources @@ -49,6 +49,8 @@ set(sources
49 offline-zipformer-ctc-model-config.cc 49 offline-zipformer-ctc-model-config.cc
50 offline-zipformer-ctc-model.cc 50 offline-zipformer-ctc-model.cc
51 online-conformer-transducer-model.cc 51 online-conformer-transducer-model.cc
  52 + online-ctc-greedy-search-decoder.cc
  53 + online-ctc-model.cc
52 online-lm-config.cc 54 online-lm-config.cc
53 online-lm.cc 55 online-lm.cc
54 online-lstm-transducer-model.cc 56 online-lstm-transducer-model.cc
@@ -64,6 +66,8 @@ set(sources @@ -64,6 +66,8 @@ set(sources
64 online-transducer-model-config.cc 66 online-transducer-model-config.cc
65 online-transducer-model.cc 67 online-transducer-model.cc
66 online-transducer-modified-beam-search-decoder.cc 68 online-transducer-modified-beam-search-decoder.cc
  69 + online-wenet-ctc-model-config.cc
  70 + online-wenet-ctc-model.cc
67 online-zipformer-transducer-model.cc 71 online-zipformer-transducer-model.cc
68 online-zipformer2-transducer-model.cc 72 online-zipformer2-transducer-model.cc
69 onnx-utils.cc 73 onnx-utils.cc
  1 +// sherpa-onnx/csrc/online-ctc-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
// Decoding state of a single stream. An OnlineCtcDecoder receives it as an
// in/out argument and appends to it on every Decode() call.
struct OnlineCtcDecoderResult {
  /// IDs of the tokens decoded so far.
  std::vector<int64_t> tokens;

  /// timestamps[i] is the index of the output frame at which tokens[i] was
  /// decoded. Note: the index counts frames after subsampling.
  std::vector<int32_t> timestamps;

  /// Number of consecutive blank frames at the end of the frames decoded so
  /// far; a non-blank frame resets it to 0.
  int32_t num_trailing_blanks{0};
};
  24 +
  25 +class OnlineCtcDecoder {
  26 + public:
  27 + virtual ~OnlineCtcDecoder() = default;
  28 +
  29 + /** Run streaming CTC decoding given the output from the encoder model.
  30 + *
  31 + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
  32 + * lob_probs.
  33 + *
  34 + * @param results Input & Output parameters..
  35 + */
  36 + virtual void Decode(Ort::Value log_probs,
  37 + std::vector<OnlineCtcDecoderResult> *results) = 0;
  38 +};
  39 +
  40 +} // namespace sherpa_onnx
  41 +
  42 +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
  1 +// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
  6 +
  7 +#include <algorithm>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +void OnlineCtcGreedySearchDecoder::Decode(
  16 + Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) {
  17 + std::vector<int64_t> log_probs_shape =
  18 + log_probs.GetTensorTypeAndShapeInfo().GetShape();
  19 +
  20 + if (log_probs_shape[0] != results->size()) {
  21 + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
  22 + static_cast<int32_t>(log_probs_shape[0]),
  23 + static_cast<int32_t>(results->size()));
  24 + exit(-1);
  25 + }
  26 +
  27 + int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
  28 + int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
  29 + int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
  30 +
  31 + const float *p = log_probs.GetTensorData<float>();
  32 +
  33 + for (int32_t b = 0; b != batch_size; ++b) {
  34 + auto &r = (*results)[b];
  35 +
  36 + int32_t prev_id = -1;
  37 +
  38 + for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
  39 + int32_t y = static_cast<int32_t>(std::distance(
  40 + static_cast<const float *>(p),
  41 + std::max_element(static_cast<const float *>(p),
  42 + static_cast<const float *>(p) + vocab_size)));
  43 +
  44 + if (y == blank_id_) {
  45 + r.num_trailing_blanks += 1;
  46 + } else {
  47 + r.num_trailing_blanks = 0;
  48 + }
  49 +
  50 + if (y != blank_id_ && y != prev_id) {
  51 + r.tokens.push_back(y);
  52 + r.timestamps.push_back(t);
  53 + }
  54 +
  55 + prev_id = y;
  56 + } // for (int32_t t = 0; t != num_frames; ++t) {
  57 + } // for (int32_t b = 0; b != batch_size; ++b)
  58 +}
  59 +
  60 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-ctc-decoder.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
  15 + public:
  16 + explicit OnlineCtcGreedySearchDecoder(int32_t blank_id)
  17 + : blank_id_(blank_id) {}
  18 +
  19 + void Decode(Ort::Value log_probs,
  20 + std::vector<OnlineCtcDecoderResult> *results) override;
  21 +
  22 + private:
  23 + int32_t blank_id_;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_
  1 +// sherpa-onnx/csrc/online-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-ctc-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <memory>
  9 +#include <sstream>
  10 +#include <string>
  11 +
  12 +#include "sherpa-onnx/csrc/macros.h"
  13 +#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
  14 +#include "sherpa-onnx/csrc/onnx-utils.h"
  15 +
namespace {

// Kind of streaming CTC model, as determined from the "model_type" entry of
// the ONNX model metadata.
enum class ModelType {
  kZipformerCtc,  // model_type is "zipformer2"
  kWenetCtc,      // model_type is "wenet_ctc"
  kUnkown,        // model_type is missing or has any other value
};

}  // namespace
  25 +
  26 +namespace sherpa_onnx {
  27 +
  28 +static ModelType GetModelType(char *model_data, size_t model_data_length,
  29 + bool debug) {
  30 + Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  31 + Ort::SessionOptions sess_opts;
  32 +
  33 + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
  34 + sess_opts);
  35 +
  36 + Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  37 + if (debug) {
  38 + std::ostringstream os;
  39 + PrintModelMetadata(os, meta_data);
  40 + SHERPA_ONNX_LOGE("%s", os.str().c_str());
  41 + }
  42 +
  43 + Ort::AllocatorWithDefaultOptions allocator;
  44 + auto model_type =
  45 + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  46 + if (!model_type) {
  47 + SHERPA_ONNX_LOGE(
  48 + "No model_type in the metadata!\n"
  49 + "If you are using models from WeNet, please refer to\n"
  50 + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/"
  51 + "run.sh\n"
  52 + "\n"
  53 + "for how to add metadta to model.onnx\n");
  54 + return ModelType::kUnkown;
  55 + }
  56 +
  57 + if (model_type.get() == std::string("zipformer2")) {
  58 + return ModelType::kZipformerCtc;
  59 + } else if (model_type.get() == std::string("wenet_ctc")) {
  60 + return ModelType::kWenetCtc;
  61 + } else {
  62 + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
  63 + return ModelType::kUnkown;
  64 + }
  65 +}
  66 +
  67 +std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
  68 + const OnlineModelConfig &config) {
  69 + ModelType model_type = ModelType::kUnkown;
  70 +
  71 + std::string filename;
  72 + if (!config.wenet_ctc.model.empty()) {
  73 + filename = config.wenet_ctc.model;
  74 + } else {
  75 + SHERPA_ONNX_LOGE("Please specify a CTC model");
  76 + exit(-1);
  77 + }
  78 +
  79 + {
  80 + auto buffer = ReadFile(filename);
  81 +
  82 + model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  83 + }
  84 +
  85 + switch (model_type) {
  86 + case ModelType::kZipformerCtc:
  87 + return nullptr;
  88 + // return std::make_unique<OnlineZipformerCtcModel>(config);
  89 + break;
  90 + case ModelType::kWenetCtc:
  91 + return std::make_unique<OnlineWenetCtcModel>(config);
  92 + break;
  93 + case ModelType::kUnkown:
  94 + SHERPA_ONNX_LOGE("Unknown model type in online CTC!");
  95 + return nullptr;
  96 + }
  97 +
  98 + return nullptr;
  99 +}
  100 +
#if __ANDROID_API__ >= 9

// Android variant of Create(): the model file is read through the given
// asset manager instead of from the file system.
//
// @return The model, or nullptr if the model type is unknown or not
//         supported yet; an error is logged in that case. The process is
//         terminated via exit(-1) if no CTC model is specified.
std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
    AAssetManager *mgr, const OnlineModelConfig &config) {
  if (config.wenet_ctc.model.empty()) {
    SHERPA_ONNX_LOGE("Please specify a CTC model");
    exit(-1);
  }

  const std::string &filename = config.wenet_ctc.model;

  ModelType model_type = ModelType::kUnkown;
  {
    // The buffer lives only inside this scope, so the bytes read here are
    // released before the actual model is constructed below.
    auto buffer = ReadFile(mgr, filename);

    model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  }

  switch (model_type) {
    case ModelType::kZipformerCtc:
      // TODO: return std::make_unique<OnlineZipformerCtcModel>(mgr, config);
      // Log a diagnostic instead of returning nullptr silently.
      SHERPA_ONNX_LOGE(
          "Streaming zipformer2 CTC models are not supported yet: %s",
          filename.c_str());
      return nullptr;
    case ModelType::kWenetCtc:
      return std::make_unique<OnlineWenetCtcModel>(mgr, config);
    case ModelType::kUnkown:
      SHERPA_ONNX_LOGE("Unknown model type in online CTC!");
      return nullptr;
  }

  return nullptr;
}
#endif
  137 +
  138 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-ctc-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#include "onnxruntime_cxx_api.h" // NOLINT
  17 +#include "sherpa-onnx/csrc/online-model-config.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +class OnlineCtcModel {
  22 + public:
  23 + virtual ~OnlineCtcModel() = default;
  24 +
  25 + static std::unique_ptr<OnlineCtcModel> Create(
  26 + const OnlineModelConfig &config);
  27 +
  28 +#if __ANDROID_API__ >= 9
  29 + static std::unique_ptr<OnlineCtcModel> Create(
  30 + AAssetManager *mgr, const OnlineModelConfig &config);
  31 +#endif
  32 +
  33 + // Return a list of tensors containing the initial states
  34 + virtual std::vector<Ort::Value> GetInitStates() const = 0;
  35 +
  36 + /**
  37 + *
  38 + * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
  39 + * @param states It is from GetInitStates() or returned from this method.
  40 + *
  41 + * @return Return a list of tensors
  42 + * - ans[0] contains log_probs, of shape (N, T, C)
  43 + * - ans[1:] contains next_states
  44 + */
  45 + virtual std::vector<Ort::Value> Forward(
  46 + Ort::Value x, std::vector<Ort::Value> states) const = 0;
  47 +
  48 + /** Return the vocabulary size of the model
  49 + */
  50 + virtual int32_t VocabSize() const = 0;
  51 +
  52 + /** Return an allocator for allocating memory
  53 + */
  54 + virtual OrtAllocator *Allocator() const = 0;
  55 +
  56 + // The model accepts this number of frames before subsampling as input
  57 + virtual int32_t ChunkLength() const = 0;
  58 +
  59 + // Similar to frame_shift in feature extractor, after processing
  60 + // ChunkLength() frames, we advance by ChunkShift() frames
  61 + // before we process the next chunk.
  62 + virtual int32_t ChunkShift() const = 0;
  63 +};
  64 +
  65 +} // namespace sherpa_onnx
  66 +
  67 +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_
@@ -13,6 +13,7 @@ namespace sherpa_onnx { @@ -13,6 +13,7 @@ namespace sherpa_onnx {
13 void OnlineModelConfig::Register(ParseOptions *po) { 13 void OnlineModelConfig::Register(ParseOptions *po) {
14 transducer.Register(po); 14 transducer.Register(po);
15 paraformer.Register(po); 15 paraformer.Register(po);
  16 + wenet_ctc.Register(po);
16 17
17 po->Register("tokens", &tokens, "Path to tokens.txt"); 18 po->Register("tokens", &tokens, "Path to tokens.txt");
18 19
@@ -46,6 +47,10 @@ bool OnlineModelConfig::Validate() const { @@ -46,6 +47,10 @@ bool OnlineModelConfig::Validate() const {
46 return paraformer.Validate(); 47 return paraformer.Validate();
47 } 48 }
48 49
  50 + if (!wenet_ctc.model.empty()) {
  51 + return wenet_ctc.Validate();
  52 + }
  53 +
49 return transducer.Validate(); 54 return transducer.Validate();
50 } 55 }
51 56
@@ -55,6 +60,7 @@ std::string OnlineModelConfig::ToString() const { @@ -55,6 +60,7 @@ std::string OnlineModelConfig::ToString() const {
55 os << "OnlineModelConfig("; 60 os << "OnlineModelConfig(";
56 os << "transducer=" << transducer.ToString() << ", "; 61 os << "transducer=" << transducer.ToString() << ", ";
57 os << "paraformer=" << paraformer.ToString() << ", "; 62 os << "paraformer=" << paraformer.ToString() << ", ";
  63 + os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
58 os << "tokens=\"" << tokens << "\", "; 64 os << "tokens=\"" << tokens << "\", ";
59 os << "num_threads=" << num_threads << ", "; 65 os << "num_threads=" << num_threads << ", ";
60 os << "debug=" << (debug ? "True" : "False") << ", "; 66 os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -8,12 +8,14 @@ @@ -8,12 +8,14 @@
8 8
9 #include "sherpa-onnx/csrc/online-paraformer-model-config.h" 9 #include "sherpa-onnx/csrc/online-paraformer-model-config.h"
10 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 10 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
  11 +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
11 12
12 namespace sherpa_onnx { 13 namespace sherpa_onnx {
13 14
14 struct OnlineModelConfig { 15 struct OnlineModelConfig {
15 OnlineTransducerModelConfig transducer; 16 OnlineTransducerModelConfig transducer;
16 OnlineParaformerModelConfig paraformer; 17 OnlineParaformerModelConfig paraformer;
  18 + OnlineWenetCtcModelConfig wenet_ctc;
17 std::string tokens; 19 std::string tokens;
18 int32_t num_threads = 1; 20 int32_t num_threads = 1;
19 bool debug = false; 21 bool debug = false;
@@ -31,10 +33,12 @@ struct OnlineModelConfig { @@ -31,10 +33,12 @@ struct OnlineModelConfig {
31 OnlineModelConfig() = default; 33 OnlineModelConfig() = default;
32 OnlineModelConfig(const OnlineTransducerModelConfig &transducer, 34 OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
33 const OnlineParaformerModelConfig &paraformer, 35 const OnlineParaformerModelConfig &paraformer,
  36 + const OnlineWenetCtcModelConfig &wenet_ctc,
34 const std::string &tokens, int32_t num_threads, bool debug, 37 const std::string &tokens, int32_t num_threads, bool debug,
35 const std::string &provider, const std::string &model_type) 38 const std::string &provider, const std::string &model_type)
36 : transducer(transducer), 39 : transducer(transducer),
37 paraformer(paraformer), 40 paraformer(paraformer),
  41 + wenet_ctc(wenet_ctc),
38 tokens(tokens), 42 tokens(tokens),
39 num_threads(num_threads), 43 num_threads(num_threads),
40 debug(debug), 44 debug(debug),
  1 +// sherpa-onnx/csrc/online-recognizer-ctc-impl.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_
  7 +
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>
  13 +
  14 +#include "sherpa-onnx/csrc/file-utils.h"
  15 +#include "sherpa-onnx/csrc/macros.h"
  16 +#include "sherpa-onnx/csrc/online-ctc-decoder.h"
  17 +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
  18 +#include "sherpa-onnx/csrc/online-ctc-model.h"
  19 +#include "sherpa-onnx/csrc/online-recognizer-impl.h"
  20 +#include "sherpa-onnx/csrc/symbol-table.h"
  21 +
  22 +namespace sherpa_onnx {
  23 +
  24 +static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src,
  25 + const SymbolTable &sym_table,
  26 + float frame_shift_ms,
  27 + int32_t subsampling_factor,
  28 + int32_t segment,
  29 + int32_t frames_since_start) {
  30 + OnlineRecognizerResult r;
  31 + r.tokens.reserve(src.tokens.size());
  32 + r.timestamps.reserve(src.tokens.size());
  33 +
  34 + for (auto i : src.tokens) {
  35 + auto sym = sym_table[i];
  36 +
  37 + r.text.append(sym);
  38 + r.tokens.push_back(std::move(sym));
  39 + }
  40 +
  41 + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  42 + for (auto t : src.timestamps) {
  43 + float time = frame_shift_s * t;
  44 + r.timestamps.push_back(time);
  45 + }
  46 +
  47 + r.segment = segment;
  48 + r.start_time = frames_since_start * frame_shift_ms / 1000.;
  49 +
  50 + return r;
  51 +}
  52 +
  53 +class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
  54 + public:
  55 + explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config)
  56 + : config_(config),
  57 + model_(OnlineCtcModel::Create(config.model_config)),
  58 + sym_(config.model_config.tokens),
  59 + endpoint_(config_.endpoint_config) {
  60 + if (!config.model_config.wenet_ctc.model.empty()) {
  61 + // WeNet CTC models assume input samples are in the range
  62 + // [-32768, 32767], so we set normalize_samples to false
  63 + config_.feat_config.normalize_samples = false;
  64 + }
  65 +
  66 + InitDecoder();
  67 + }
  68 +
  69 +#if __ANDROID_API__ >= 9
  70 + explicit OnlineRecognizerCtcImpl(AAssetManager *mgr,
  71 + const OnlineRecognizerConfig &config)
  72 + : config_(config),
  73 + model_(OnlineCtcModel::Create(mgr, config.model_config)),
  74 + sym_(mgr, config.model_config.tokens),
  75 + endpoint_(config_.endpoint_config) {
  76 + if (!config.model_config.wenet_ctc.model.empty()) {
  77 + // WeNet CTC models assume input samples are in the range
  78 + // [-32768, 32767], so we set normalize_samples to false
  79 + config_.feat_config.normalize_samples = false;
  80 + }
  81 +
  82 + InitDecoder();
  83 + }
  84 +#endif
  85 +
  86 + std::unique_ptr<OnlineStream> CreateStream() const override {
  87 + auto stream = std::make_unique<OnlineStream>(config_.feat_config);
  88 + stream->SetStates(model_->GetInitStates());
  89 +
  90 + return stream;
  91 + }
  92 +
  93 + bool IsReady(OnlineStream *s) const override {
  94 + return s->GetNumProcessedFrames() + model_->ChunkLength() <
  95 + s->NumFramesReady();
  96 + }
  97 +
  98 + void DecodeStreams(OnlineStream **ss, int32_t n) const override {
  99 + for (int32_t i = 0; i != n; ++i) {
  100 + DecodeStream(ss[i]);
  101 + }
  102 + }
  103 +
  104 + OnlineRecognizerResult GetResult(OnlineStream *s) const override {
  105 + OnlineCtcDecoderResult decoder_result = s->GetCtcResult();
  106 +
  107 + // TODO(fangjun): Remember to change these constants if needed
  108 + int32_t frame_shift_ms = 10;
  109 + int32_t subsampling_factor = 4;
  110 + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
  111 + s->GetCurrentSegment(), s->GetNumFramesSinceStart());
  112 + }
  113 +
  114 + bool IsEndpoint(OnlineStream *s) const override {
  115 + if (!config_.enable_endpoint) {
  116 + return false;
  117 + }
  118 +
  119 + int32_t num_processed_frames = s->GetNumProcessedFrames();
  120 +
  121 + // frame shift is 10 milliseconds
  122 + float frame_shift_in_seconds = 0.01;
  123 +
  124 + // subsampling factor is 4
  125 + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4;
  126 +
  127 + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
  128 + frame_shift_in_seconds);
  129 + }
  130 +
  131 + void Reset(OnlineStream *s) const override {
  132 + // segment is incremented only when the last
  133 + // result is not empty
  134 + const auto &r = s->GetCtcResult();
  135 + if (!r.tokens.empty()) {
  136 + s->GetCurrentSegment() += 1;
  137 + }
  138 +
  139 + // clear result
  140 + s->SetCtcResult({});
  141 +
  142 + // clear states
  143 + s->SetStates(model_->GetInitStates());
  144 +
  145 + // Note: We only update counters. The underlying audio samples
  146 + // are not discarded.
  147 + s->Reset();
  148 + }
  149 +
  150 + private:
  151 + void InitDecoder() {
  152 + if (config_.decoding_method == "greedy_search") {
  153 + if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
  154 + !sym_.contains("<blank>")) {
  155 + SHERPA_ONNX_LOGE(
  156 + "We expect that tokens.txt contains "
  157 + "the symbol <blk> or <eps> or <blank> and its ID.");
  158 + exit(-1);
  159 + }
  160 +
  161 + int32_t blank_id = 0;
  162 + if (sym_.contains("<blk>")) {
  163 + blank_id = sym_["<blk>"];
  164 + } else if (sym_.contains("<eps>")) {
  165 + // for tdnn models of the yesno recipe from icefall
  166 + blank_id = sym_["<eps>"];
  167 + } else if (sym_.contains("<blank>")) {
  168 + // for WeNet CTC models
  169 + blank_id = sym_["<blank>"];
  170 + }
  171 +
  172 + decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
  173 + } else {
  174 + SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
  175 + config_.decoding_method.c_str());
  176 + exit(-1);
  177 + }
  178 + }
  179 +
  180 + void DecodeStream(OnlineStream *s) const {
  181 + int32_t chunk_length = model_->ChunkLength();
  182 + int32_t chunk_shift = model_->ChunkShift();
  183 +
  184 + int32_t feat_dim = s->FeatureDim();
  185 +
  186 + const auto num_processed_frames = s->GetNumProcessedFrames();
  187 + std::vector<float> frames =
  188 + s->GetFrames(num_processed_frames, chunk_length);
  189 + s->GetNumProcessedFrames() += chunk_shift;
  190 +
  191 + auto memory_info =
  192 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  193 +
  194 + std::array<int64_t, 3> x_shape{1, chunk_length, feat_dim};
  195 + Ort::Value x =
  196 + Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
  197 + x_shape.data(), x_shape.size());
  198 + auto out = model_->Forward(std::move(x), std::move(s->GetStates()));
  199 + int32_t num_states = static_cast<int32_t>(out.size()) - 1;
  200 +
  201 + std::vector<Ort::Value> states;
  202 + states.reserve(num_states);
  203 +
  204 + for (int32_t i = 0; i != num_states; ++i) {
  205 + states.push_back(std::move(out[i + 1]));
  206 + }
  207 + s->SetStates(std::move(states));
  208 +
  209 + std::vector<OnlineCtcDecoderResult> results(1);
  210 + results[0] = std::move(s->GetCtcResult());
  211 +
  212 + decoder_->Decode(std::move(out[0]), &results);
  213 + s->SetCtcResult(results[0]);
  214 + }
  215 +
  216 + private:
  217 + OnlineRecognizerConfig config_;
  218 + std::unique_ptr<OnlineCtcModel> model_;
  219 + std::unique_ptr<OnlineCtcDecoder> decoder_;
  220 + SymbolTable sym_;
  221 + Endpoint endpoint_;
  222 +};
  223 +
  224 +} // namespace sherpa_onnx
  225 +
  226 +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_
@@ -4,6 +4,7 @@ @@ -4,6 +4,7 @@
4 4
5 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 5 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
6 6
  7 +#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h"
7 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" 8 #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
8 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" 9 #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
9 10
@@ -19,6 +20,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( @@ -19,6 +20,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
19 return std::make_unique<OnlineRecognizerParaformerImpl>(config); 20 return std::make_unique<OnlineRecognizerParaformerImpl>(config);
20 } 21 }
21 22
  23 + if (!config.model_config.wenet_ctc.model.empty()) {
  24 + return std::make_unique<OnlineRecognizerCtcImpl>(config);
  25 + }
  26 +
22 SHERPA_ONNX_LOGE("Please specify a model"); 27 SHERPA_ONNX_LOGE("Please specify a model");
23 exit(-1); 28 exit(-1);
24 } 29 }
@@ -34,6 +39,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( @@ -34,6 +39,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
34 return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); 39 return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
35 } 40 }
36 41
  42 + if (!config.model_config.wenet_ctc.model.empty()) {
  43 + return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
  44 + }
  45 +
37 SHERPA_ONNX_LOGE("Please specify a model"); 46 SHERPA_ONNX_LOGE("Please specify a model");
38 exit(-1); 47 exit(-1);
39 } 48 }
@@ -120,11 +120,7 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { @@ -120,11 +120,7 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
120 model_(mgr, config.model_config), 120 model_(mgr, config.model_config),
121 sym_(mgr, config.model_config.tokens), 121 sym_(mgr, config.model_config.tokens),
122 endpoint_(config_.endpoint_config) { 122 endpoint_(config_.endpoint_config) {
123 - if (config.decoding_method == "greedy_search") {  
124 - // add greedy search decoder  
125 - // SHERPA_ONNX_LOGE("to be implemented");  
126 - // exit(-1);  
127 - } else { 123 + if (config.decoding_method != "greedy_search") {
128 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 124 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
129 config.decoding_method.c_str()); 125 config.decoding_method.c_str());
130 exit(-1); 126 exit(-1);
@@ -51,6 +51,10 @@ class OnlineStream::Impl { @@ -51,6 +51,10 @@ class OnlineStream::Impl {
51 51
52 OnlineTransducerDecoderResult &GetResult() { return result_; } 52 OnlineTransducerDecoderResult &GetResult() { return result_; }
53 53
  // Returns a mutable reference to the CTC decoding result stored in this
  // object, so callers can read it or move from it.
  54 + OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
  55 +
  // Replaces the stored CTC decoding result with a copy of `r`.
  56 + void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
  57 +
54 void SetParaformerResult(const OnlineParaformerDecoderResult &r) { 58 void SetParaformerResult(const OnlineParaformerDecoderResult &r) {
55 paraformer_result_ = r; 59 paraformer_result_ = r;
56 } 60 }
@@ -89,7 +93,8 @@ class OnlineStream::Impl { @@ -89,7 +93,8 @@ class OnlineStream::Impl {
89 int32_t start_frame_index_ = 0; // never reset 93 int32_t start_frame_index_ = 0; // never reset
90 int32_t segment_ = 0; 94 int32_t segment_ = 0;
91 OnlineTransducerDecoderResult result_; 95 OnlineTransducerDecoderResult result_;
92 - std::vector<Ort::Value> states_; 96 + OnlineCtcDecoderResult ctc_result_;
  97 + std::vector<Ort::Value> states_; // states for transducer or ctc models
93 std::vector<float> paraformer_feat_cache_; 98 std::vector<float> paraformer_feat_cache_;
94 std::vector<float> paraformer_encoder_out_cache_; 99 std::vector<float> paraformer_encoder_out_cache_;
95 std::vector<float> paraformer_alpha_cache_; 100 std::vector<float> paraformer_alpha_cache_;
@@ -144,6 +149,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { @@ -144,6 +149,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
144 return impl_->GetResult(); 149 return impl_->GetResult();
145 } 150 }
146 151
  // Returns a mutable reference to the CTC decoding result of this stream.
  // The call is forwarded to the implementation object.
  152 +OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
  153 + return impl_->GetCtcResult();
  154 +}
  155 +
  // Sets the CTC decoding result of this stream to `r`.
  // The call is forwarded to the implementation object.
  156 +void OnlineStream::SetCtcResult(const OnlineCtcDecoderResult &r) {
  157 + impl_->SetCtcResult(r);
  158 +}
  159 +
147 void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { 160 void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) {
148 impl_->SetParaformerResult(r); 161 impl_->SetParaformerResult(r);
149 } 162 }
@@ -11,6 +11,7 @@ @@ -11,6 +11,7 @@
11 #include "onnxruntime_cxx_api.h" // NOLINT 11 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/context-graph.h" 12 #include "sherpa-onnx/csrc/context-graph.h"
13 #include "sherpa-onnx/csrc/features.h" 13 #include "sherpa-onnx/csrc/features.h"
  14 +#include "sherpa-onnx/csrc/online-ctc-decoder.h"
14 #include "sherpa-onnx/csrc/online-paraformer-decoder.h" 15 #include "sherpa-onnx/csrc/online-paraformer-decoder.h"
15 #include "sherpa-onnx/csrc/online-transducer-decoder.h" 16 #include "sherpa-onnx/csrc/online-transducer-decoder.h"
16 17
@@ -75,6 +76,9 @@ class OnlineStream { @@ -75,6 +76,9 @@ class OnlineStream {
75 void SetResult(const OnlineTransducerDecoderResult &r); 76 void SetResult(const OnlineTransducerDecoderResult &r);
76 OnlineTransducerDecoderResult &GetResult(); 77 OnlineTransducerDecoderResult &GetResult();
77 78
  79 + void SetCtcResult(const OnlineCtcDecoderResult &r);
  80 + OnlineCtcDecoderResult &GetCtcResult();
  81 +
78 void SetParaformerResult(const OnlineParaformerDecoderResult &r); 82 void SetParaformerResult(const OnlineParaformerDecoderResult &r);
79 OnlineParaformerDecoderResult &GetParaformerResult(); 83 OnlineParaformerDecoderResult &GetParaformerResult();
80 84
  1 +// sherpa-onnx/csrc/online-wenet-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OnlineWenetCtcModelConfig::Register(ParseOptions *po) {
  13 + po->Register("wenet-ctc-model", &model,
  14 + "Path to CTC model.onnx from WeNet. Please see "
  15 + "https://github.com/k2-fsa/sherpa-onnx/pull/425");
  16 + po->Register("wenet-ctc-chunk-size", &chunk_size,
  17 + "Chunk size after subsampling used for decoding.");
  18 + po->Register("wenet-ctc-num-left-chunks", &num_left_chunks,
  19 + "Number of left chunks after subsampling used for decoding.");
  20 +}
  21 +
  22 +bool OnlineWenetCtcModelConfig::Validate() const {
  23 + if (!FileExists(model)) {
  24 + SHERPA_ONNX_LOGE("WeNet CTC model %s does not exist", model.c_str());
  25 + return false;
  26 + }
  27 +
  28 + if (chunk_size <= 0) {
  29 + SHERPA_ONNX_LOGE(
  30 + "Please specify a positive value for --wenet-ctc-chunk-size. Currently "
  31 + "given: %d",
  32 + chunk_size);
  33 + return false;
  34 + }
  35 +
  36 + if (num_left_chunks <= 0) {
  37 + SHERPA_ONNX_LOGE(
  38 + "Please specify a positive value for --wenet-ctc-num-left-chunks. "
  39 + "Currently given: %d. Note that if you want to use -1, please consider "
  40 + "using a non-streaming model.",
  41 + num_left_chunks);
  42 + return false;
  43 + }
  44 +
  45 + return true;
  46 +}
  47 +
  48 +std::string OnlineWenetCtcModelConfig::ToString() const {
  49 + std::ostringstream os;
  50 +
  51 + os << "OnlineWenetCtcModelConfig(";
  52 + os << "model=\"" << model << "\", ";
  53 + os << "chunk_size=" << chunk_size << ", ";
  54 + os << "num_left_chunks=" << num_left_chunks << ")";
  55 +
  56 + return os.str();
  57 +}
  58 +
  59 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-wenet-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OnlineWenetCtcModelConfig {
  14 + std::string model;
  15 +
  16 + // --chunk_size from wenet
  17 + int32_t chunk_size = 16;
  18 +
  19 + // --num_left_chunks from wenet
  20 + int32_t num_left_chunks = 4;
  21 +
  22 + OnlineWenetCtcModelConfig() = default;
  23 +
  24 + OnlineWenetCtcModelConfig(const std::string &model, int32_t chunk_size,
  25 + int32_t num_left_chunks)
  26 + : model(model),
  27 + chunk_size(chunk_size),
  28 + num_left_chunks(num_left_chunks) {}
  29 +
  30 + void Register(ParseOptions *po);
  31 + bool Validate() const;
  32 +
  33 + std::string ToString() const;
  34 +};
  35 +
  36 +} // namespace sherpa_onnx
  37 +
  38 +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/online-wenet-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <cmath>
  9 +#include <string>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#include "sherpa-onnx/csrc/macros.h"
  17 +#include "sherpa-onnx/csrc/onnx-utils.h"
  18 +#include "sherpa-onnx/csrc/session.h"
  19 +#include "sherpa-onnx/csrc/text-utils.h"
  20 +
  21 +namespace sherpa_onnx {
  22 +
  23 +class OnlineWenetCtcModel::Impl {
  24 + public:
  25 + explicit Impl(const OnlineModelConfig &config)
  26 + : config_(config),
  27 + env_(ORT_LOGGING_LEVEL_ERROR),
  28 + sess_opts_(GetSessionOptions(config)),
  29 + allocator_{} {
  30 + {
  31 + auto buf = ReadFile(config.wenet_ctc.model);
  32 + Init(buf.data(), buf.size());
  33 + }
  34 + }
  35 +
  36 +#if __ANDROID_API__ >= 9
  37 + Impl(AAssetManager *mgr, const OnlineModelConfig &config)
  38 + : config_(config),
  39 + env_(ORT_LOGGING_LEVEL_WARNING),
  40 + sess_opts_(GetSessionOptions(config)),
  41 + allocator_{} {
  42 + {
  43 + auto buf = ReadFile(mgr, config.wenet_ctc.model);
  44 + Init(buf.data(), buf.size());
  45 + }
  46 + }
  47 +#endif
  48 +
  49 + std::vector<Ort::Value> Forward(Ort::Value x,
  50 + std::vector<Ort::Value> states) {
  51 + Ort::Value &attn_cache = states[0];
  52 + Ort::Value &conv_cache = states[1];
  53 + Ort::Value &offset = states[2];
  54 +
  55 + int32_t chunk_size = config_.wenet_ctc.chunk_size;
  56 + int32_t left_chunks = config_.wenet_ctc.num_left_chunks;
  57 + // build attn_mask
  58 + std::array<int64_t, 3> attn_mask_shape{1, 1,
  59 + required_cache_size_ + chunk_size};
  60 + Ort::Value attn_mask = Ort::Value::CreateTensor<bool>(
  61 + allocator_, attn_mask_shape.data(), attn_mask_shape.size());
  62 + bool *p = attn_mask.GetTensorMutableData<bool>();
  63 + int32_t chunk_idx =
  64 + offset.GetTensorData<int64_t>()[0] / chunk_size - left_chunks;
  65 + if (chunk_idx < left_chunks) {
  66 + std::fill(p, p + required_cache_size_ - chunk_idx * chunk_size, 0);
  67 + std::fill(p + required_cache_size_ - chunk_idx * chunk_size,
  68 + p + attn_mask_shape[2], 1);
  69 + } else {
  70 + std::fill(p, p + attn_mask_shape[2], 1);
  71 + }
  72 +
  73 + std::array<Ort::Value, 6> inputs = {std::move(x),
  74 + View(&offset),
  75 + View(&required_cache_size_tensor_),
  76 + std::move(attn_cache),
  77 + std::move(conv_cache),
  78 + std::move(attn_mask)};
  79 +
  80 + auto out =
  81 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  82 + output_names_ptr_.data(), output_names_ptr_.size());
  83 +
  84 + offset.GetTensorMutableData<int64_t>()[0] +=
  85 + out[0].GetTensorTypeAndShapeInfo().GetShape()[1];
  86 + out.push_back(std::move(offset));
  87 +
  88 + return out;
  89 + }
  90 +
  91 + int32_t VocabSize() const { return vocab_size_; }
  92 +
  93 + int32_t ChunkLength() const {
  94 + // When chunk_size is 16, subsampling_factor_ is 4, right_context_ is 6,
  95 + // the returned value is (16 - 1)*4 + 6 + 1 = 67
  96 + return (config_.wenet_ctc.chunk_size - 1) * subsampling_factor_ +
  97 + right_context_ + 1;
  98 + }
  99 +
  100 + int32_t ChunkShift() const { return required_cache_size_; }
  101 +
  102 + OrtAllocator *Allocator() const { return allocator_; }
  103 +
  104 + // Return a vector containing 3 tensors
  105 + // - attn_cache
  106 + // - conv_cache
  107 + // - offset
  108 + std::vector<Ort::Value> GetInitStates() const {
  109 + std::vector<Ort::Value> ans;
  110 + ans.reserve(3);
  111 + ans.push_back(Clone(Allocator(), &attn_cache_));
  112 + ans.push_back(Clone(Allocator(), &conv_cache_));
  113 +
  114 + int64_t offset_shape = 1;
  115 +
  116 + Ort::Value offset =
  117 + Ort::Value::CreateTensor<int64_t>(allocator_, &offset_shape, 1);
  118 +
  119 + offset.GetTensorMutableData<int64_t>()[0] = required_cache_size_;
  120 +
  121 + ans.push_back(std::move(offset));
  122 +
  123 + return ans;
  124 + }
  125 +
  126 + private:
  127 + void Init(void *model_data, size_t model_data_length) {
  128 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  129 + sess_opts_);
  130 +
  131 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  132 +
  133 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  134 +
  135 + // get meta data
  136 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  137 + if (config_.debug) {
  138 + std::ostringstream os;
  139 + PrintModelMetadata(os, meta_data);
  140 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  141 + }
  142 +
  143 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  144 + SHERPA_ONNX_READ_META_DATA(head_, "head");
  145 + SHERPA_ONNX_READ_META_DATA(num_blocks_, "num_blocks");
  146 + SHERPA_ONNX_READ_META_DATA(output_size_, "output_size");
  147 + SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel");
  148 + SHERPA_ONNX_READ_META_DATA(right_context_, "right_context");
  149 + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
  150 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  151 +
  152 + required_cache_size_ =
  153 + config_.wenet_ctc.chunk_size * config_.wenet_ctc.num_left_chunks;
  154 +
  155 + InitStates();
  156 + }
  157 +
  158 + void InitStates() {
  159 + std::array<int64_t, 4> attn_cache_shape{
  160 + num_blocks_, head_, required_cache_size_, output_size_ / head_ * 2};
  161 + attn_cache_ = Ort::Value::CreateTensor<float>(
  162 + allocator_, attn_cache_shape.data(), attn_cache_shape.size());
  163 +
  164 + Fill<float>(&attn_cache_, 0);
  165 +
  166 + std::array<int64_t, 4> conv_cache_shape{num_blocks_, 1, output_size_,
  167 + cnn_module_kernel_ - 1};
  168 + conv_cache_ = Ort::Value::CreateTensor<float>(
  169 + allocator_, conv_cache_shape.data(), conv_cache_shape.size());
  170 +
  171 + Fill<float>(&conv_cache_, 0);
  172 +
  173 + int64_t shape = 1;
  174 + required_cache_size_tensor_ =
  175 + Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);
  176 +
  177 + required_cache_size_tensor_.GetTensorMutableData<int64_t>()[0] =
  178 + required_cache_size_;
  179 + }
  180 +
  181 + private:
  182 + OnlineModelConfig config_;
  183 + Ort::Env env_;
  184 + Ort::SessionOptions sess_opts_;
  185 + Ort::AllocatorWithDefaultOptions allocator_;
  186 +
  187 + std::unique_ptr<Ort::Session> sess_;
  188 +
  189 + std::vector<std::string> input_names_;
  190 + std::vector<const char *> input_names_ptr_;
  191 +
  192 + std::vector<std::string> output_names_;
  193 + std::vector<const char *> output_names_ptr_;
  194 +
  195 + int32_t head_;
  196 + int32_t num_blocks_;
  197 + int32_t output_size_;
  198 + int32_t cnn_module_kernel_;
  199 + int32_t right_context_;
  200 + int32_t subsampling_factor_;
  201 + int32_t vocab_size_;
  202 +
  203 + int32_t required_cache_size_;
  204 +
  205 + Ort::Value attn_cache_{nullptr};
  206 + Ort::Value conv_cache_{nullptr};
  207 + Ort::Value required_cache_size_tensor_{nullptr};
  208 +};
  209 +
  210 +OnlineWenetCtcModel::OnlineWenetCtcModel(const OnlineModelConfig &config)
  211 + : impl_(std::make_unique<Impl>(config)) {}
  212 +
  213 +#if __ANDROID_API__ >= 9
  214 +OnlineWenetCtcModel::OnlineWenetCtcModel(AAssetManager *mgr,
  215 + const OnlineModelConfig &config)
  216 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  217 +#endif
  218 +
  219 +OnlineWenetCtcModel::~OnlineWenetCtcModel() = default;
  220 +
  221 +std::vector<Ort::Value> OnlineWenetCtcModel::Forward(
  222 + Ort::Value x, std::vector<Ort::Value> states) const {
  223 + return impl_->Forward(std::move(x), std::move(states));
  224 +}
  225 +
  226 +int32_t OnlineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); }
  227 +
  228 +int32_t OnlineWenetCtcModel::ChunkLength() const {
  229 + return impl_->ChunkLength();
  230 +}
  231 +
  232 +int32_t OnlineWenetCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
  233 +
  234 +OrtAllocator *OnlineWenetCtcModel::Allocator() const {
  235 + return impl_->Allocator();
  236 +}
  237 +
  238 +std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const {
  239 + return impl_->GetInitStates();
  240 +}
  241 +
  242 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-wenet-ctc-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#include "onnxruntime_cxx_api.h" // NOLINT
  17 +#include "sherpa-onnx/csrc/online-ctc-model.h"
  18 +#include "sherpa-onnx/csrc/online-model-config.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +class OnlineWenetCtcModel : public OnlineCtcModel {
  23 + public:
  24 + explicit OnlineWenetCtcModel(const OnlineModelConfig &config);
  25 +
  26 +#if __ANDROID_API__ >= 9
  27 + OnlineWenetCtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
  28 +#endif
  29 +
  30 + ~OnlineWenetCtcModel() override;
  31 +
  32 + // A list of 3 tensors:
  33 + // - attn_cache
  34 + // - conv_cache
  35 + // - offset
  36 + std::vector<Ort::Value> GetInitStates() const override;
  37 +
  38 + /**
  39 + *
  40 + * @param x A 3-D tensor of shape (N, T, C). N has to be 1.
  41 + * @param states It is from GetInitStates() or returned from this method.
  42 + *
  43 + * @return Return a list of tensors
  44 + * - ans[0] contains log_probs, of shape (N, T, C)
  45 + * - ans[1:] contains next_states
  46 + */
  47 + std::vector<Ort::Value> Forward(
  48 + Ort::Value x, std::vector<Ort::Value> states) const override;
  49 +
  50 + /** Return the vocabulary size of the model
  51 + */
  52 + int32_t VocabSize() const override;
  53 +
  54 + /** Return an allocator for allocating memory
  55 + */
  56 + OrtAllocator *Allocator() const override;
  57 +
  58 + // The model accepts this number of frames before subsampling as input
  59 + int32_t ChunkLength() const override;
  60 +
  61 + // Similar to frame_shift in feature extractor, after processing
  62 + // ChunkLength() frames, we advance by ChunkShift() frames
  63 + // before we process the next chunk.
  64 + int32_t ChunkShift() const override;
  65 +
  66 + private:
  67 + class Impl;
  68 + std::unique_ptr<Impl> impl_;
  69 +};
  70 +
  71 +} // namespace sherpa_onnx
  72 +
  73 +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_
@@ -125,6 +125,34 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { @@ -125,6 +125,34 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
125 } 125 }
126 } 126 }
127 127
  128 +Ort::Value View(Ort::Value *v) {
  129 + auto type_and_shape = v->GetTensorTypeAndShapeInfo();
  130 + std::vector<int64_t> shape = type_and_shape.GetShape();
  131 +
  132 + auto memory_info =
  133 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  134 + switch (type_and_shape.GetElementType()) {
  135 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
  136 + return Ort::Value::CreateTensor(
  137 + memory_info, v->GetTensorMutableData<int32_t>(),
  138 + type_and_shape.GetElementCount(), shape.data(), shape.size());
  139 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
  140 + return Ort::Value::CreateTensor(
  141 + memory_info, v->GetTensorMutableData<int64_t>(),
  142 + type_and_shape.GetElementCount(), shape.data(), shape.size());
  143 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
  144 + return Ort::Value::CreateTensor(
  145 + memory_info, v->GetTensorMutableData<float>(),
  146 + type_and_shape.GetElementCount(), shape.data(), shape.size());
  147 + default:
  148 + fprintf(stderr, "Unsupported type: %d\n",
  149 + static_cast<int32_t>(type_and_shape.GetElementType()));
  150 + exit(-1);
  151 + // unreachable code
  152 + return Ort::Value{nullptr};
  153 + }
  154 +}
  155 +
128 void Print1D(Ort::Value *v) { 156 void Print1D(Ort::Value *v) {
129 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); 157 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
130 const float *d = v->GetTensorData<float>(); 158 const float *d = v->GetTensorData<float>();
@@ -65,6 +65,9 @@ void PrintModelMetadata(std::ostream &os, @@ -65,6 +65,9 @@ void PrintModelMetadata(std::ostream &os,
65 // Return a deep copy of v 65 // Return a deep copy of v
66 Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); 66 Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
67 67
  68 +// Return a shallow copy
  69 +Ort::Value View(Ort::Value *v);
  70 +
68 // Print a 1-D tensor to stderr 71 // Print a 1-D tensor to stderr
69 void Print1D(Ort::Value *v); 72 void Print1D(Ort::Value *v);
70 73
@@ -26,6 +26,7 @@ pybind11_add_module(_sherpa_onnx @@ -26,6 +26,7 @@ pybind11_add_module(_sherpa_onnx
26 online-recognizer.cc 26 online-recognizer.cc
27 online-stream.cc 27 online-stream.cc
28 online-transducer-model-config.cc 28 online-transducer-model-config.cc
  29 + online-wenet-ctc-model-config.cc
29 sherpa-onnx.cc 30 sherpa-onnx.cc
30 silero-vad-model-config.cc 31 silero-vad-model-config.cc
31 vad-model-config.cc 32 vad-model-config.cc
@@ -11,24 +11,29 @@ @@ -11,24 +11,29 @@
11 #include "sherpa-onnx/csrc/online-transducer-model-config.h" 11 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
12 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" 12 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
13 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" 13 #include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
  14 +#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
14 15
15 namespace sherpa_onnx { 16 namespace sherpa_onnx {
16 17
17 void PybindOnlineModelConfig(py::module *m) { 18 void PybindOnlineModelConfig(py::module *m) {
18 PybindOnlineTransducerModelConfig(m); 19 PybindOnlineTransducerModelConfig(m);
19 PybindOnlineParaformerModelConfig(m); 20 PybindOnlineParaformerModelConfig(m);
  21 + PybindOnlineWenetCtcModelConfig(m);
20 22
21 using PyClass = OnlineModelConfig; 23 using PyClass = OnlineModelConfig;
22 py::class_<PyClass>(*m, "OnlineModelConfig") 24 py::class_<PyClass>(*m, "OnlineModelConfig")
23 .def(py::init<const OnlineTransducerModelConfig &, 25 .def(py::init<const OnlineTransducerModelConfig &,
24 - const OnlineParaformerModelConfig &, const std::string &, 26 + const OnlineParaformerModelConfig &,
  27 + const OnlineWenetCtcModelConfig &, const std::string &,
25 int32_t, bool, const std::string &, const std::string &>(), 28 int32_t, bool, const std::string &, const std::string &>(),
26 py::arg("transducer") = OnlineTransducerModelConfig(), 29 py::arg("transducer") = OnlineTransducerModelConfig(),
27 py::arg("paraformer") = OnlineParaformerModelConfig(), 30 py::arg("paraformer") = OnlineParaformerModelConfig(),
  31 + py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
28 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, 32 py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
29 py::arg("provider") = "cpu", py::arg("model_type") = "") 33 py::arg("provider") = "cpu", py::arg("model_type") = "")
30 .def_readwrite("transducer", &PyClass::transducer) 34 .def_readwrite("transducer", &PyClass::transducer)
31 .def_readwrite("paraformer", &PyClass::paraformer) 35 .def_readwrite("paraformer", &PyClass::paraformer)
  36 + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
32 .def_readwrite("tokens", &PyClass::tokens) 37 .def_readwrite("tokens", &PyClass::tokens)
33 .def_readwrite("num_threads", &PyClass::num_threads) 38 .def_readwrite("num_threads", &PyClass::num_threads)
34 .def_readwrite("debug", &PyClass::debug) 39 .def_readwrite("debug", &PyClass::debug)
  1 +// sherpa-onnx/python/csrc/online-wenet-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
  6 +
  7 +#include <string>
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +void PybindOnlineWenetCtcModelConfig(py::module *m) {
  15 + using PyClass = OnlineWenetCtcModelConfig;
  16 + py::class_<PyClass>(*m, "OnlineWenetCtcModelConfig")
  17 + .def(py::init<const std::string &, int32_t, int32_t>(), py::arg("model"),
  18 + py::arg("chunk_size") = 16, py::arg("num_left_chunks") = 4)
  19 + .def_readwrite("model", &PyClass::model)
  20 + .def_readwrite("chunk_size", &PyClass::chunk_size)
  21 + .def_readwrite("num_left_chunks", &PyClass::num_left_chunks)
  22 + .def("__str__", &PyClass::ToString);
  23 +}
  24 +
  25 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineWenetCtcModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_