Committed by
GitHub
Support streaming conformer CTC models from wenet (#427)
正在显示
31 个修改的文件
包含
1212 行增加
和
7 行删除
.github/scripts/test-online-ctc.sh
0 → 100755
| 1 | +#!/usr/bin/env bash | ||
| 2 | + | ||
| 3 | +set -e | ||
| 4 | + | ||
| 5 | +log() { | ||
| 6 | + # This function is from espnet | ||
| 7 | + local fname=${BASH_SOURCE[1]##*/} | ||
| 8 | + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
| 9 | +} | ||
| 10 | + | ||
| 11 | +echo "EXE is $EXE" | ||
| 12 | +echo "PATH: $PATH" | ||
| 13 | + | ||
| 14 | +which $EXE | ||
| 15 | + | ||
| 16 | +log "------------------------------------------------------------" | ||
| 17 | +log "Run streaming Conformer CTC from WeNet" | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | +wenet_models=( | ||
| 20 | +sherpa-onnx-zh-wenet-aishell | ||
| 21 | +sherpa-onnx-zh-wenet-aishell2 | ||
| 22 | +sherpa-onnx-zh-wenet-wenetspeech | ||
| 23 | +sherpa-onnx-zh-wenet-multi-cn | ||
| 24 | +sherpa-onnx-en-wenet-librispeech | ||
| 25 | +sherpa-onnx-en-wenet-gigaspeech | ||
| 26 | +) | ||
| 27 | +for name in ${wenet_models[@]}; do | ||
| 28 | + repo_url=https://huggingface.co/csukuangfj/$name | ||
| 29 | + log "Start testing ${repo_url}" | ||
| 30 | + repo=$(basename $repo_url) | ||
| 31 | + log "Download pretrained model and test-data from $repo_url" | ||
| 32 | + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 33 | + pushd $repo | ||
| 34 | + git lfs pull --include "*.onnx" | ||
| 35 | + ls -lh *.onnx | ||
| 36 | + popd | ||
| 37 | + | ||
| 38 | + log "test float32 models" | ||
| 39 | + time $EXE \ | ||
| 40 | + --tokens=$repo/tokens.txt \ | ||
| 41 | + --wenet-ctc-model=$repo/model-streaming.onnx \ | ||
| 42 | + $repo/test_wavs/0.wav \ | ||
| 43 | + $repo/test_wavs/1.wav \ | ||
| 44 | + $repo/test_wavs/8k.wav | ||
| 45 | + | ||
| 46 | + log "test int8 models" | ||
| 47 | + time $EXE \ | ||
| 48 | + --tokens=$repo/tokens.txt \ | ||
| 49 | + --wenet-ctc-model=$repo/model-streaming.int8.onnx \ | ||
| 50 | + $repo/test_wavs/0.wav \ | ||
| 51 | + $repo/test_wavs/1.wav \ | ||
| 52 | + $repo/test_wavs/8k.wav | ||
| 53 | + | ||
| 54 | + rm -rf $repo | ||
| 55 | +done |
| @@ -12,6 +12,7 @@ on: | @@ -12,6 +12,7 @@ on: | ||
| 12 | - '.github/scripts/test-online-paraformer.sh' | 12 | - '.github/scripts/test-online-paraformer.sh' |
| 13 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | + - '.github/scripts/test-online-ctc.sh' | ||
| 15 | - '.github/scripts/test-offline-tts.sh' | 16 | - '.github/scripts/test-offline-tts.sh' |
| 16 | - 'CMakeLists.txt' | 17 | - 'CMakeLists.txt' |
| 17 | - 'cmake/**' | 18 | - 'cmake/**' |
| @@ -27,6 +28,8 @@ on: | @@ -27,6 +28,8 @@ on: | ||
| 27 | - '.github/scripts/test-online-paraformer.sh' | 28 | - '.github/scripts/test-online-paraformer.sh' |
| 28 | - '.github/scripts/test-offline-transducer.sh' | 29 | - '.github/scripts/test-offline-transducer.sh' |
| 29 | - '.github/scripts/test-offline-ctc.sh' | 30 | - '.github/scripts/test-offline-ctc.sh' |
| 31 | + - '.github/scripts/test-online-ctc.sh' | ||
| 32 | + - '.github/scripts/test-online-ctc.sh' | ||
| 30 | - '.github/scripts/test-offline-tts.sh' | 33 | - '.github/scripts/test-offline-tts.sh' |
| 31 | - 'CMakeLists.txt' | 34 | - 'CMakeLists.txt' |
| 32 | - 'cmake/**' | 35 | - 'cmake/**' |
| @@ -88,6 +91,14 @@ jobs: | @@ -88,6 +91,14 @@ jobs: | ||
| 88 | file build/bin/sherpa-onnx | 91 | file build/bin/sherpa-onnx |
| 89 | readelf -d build/bin/sherpa-onnx | 92 | readelf -d build/bin/sherpa-onnx |
| 90 | 93 | ||
| 94 | + - name: Test online CTC | ||
| 95 | + shell: bash | ||
| 96 | + run: | | ||
| 97 | + export PATH=$PWD/build/bin:$PATH | ||
| 98 | + export EXE=sherpa-onnx | ||
| 99 | + | ||
| 100 | + .github/scripts/test-online-ctc.sh | ||
| 101 | + | ||
| 91 | - name: Test offline TTS | 102 | - name: Test offline TTS |
| 92 | shell: bash | 103 | shell: bash |
| 93 | run: | | 104 | run: | |
| @@ -12,6 +12,7 @@ on: | @@ -12,6 +12,7 @@ on: | ||
| 12 | - '.github/scripts/test-online-paraformer.sh' | 12 | - '.github/scripts/test-online-paraformer.sh' |
| 13 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | + - '.github/scripts/test-online-ctc.sh' | ||
| 15 | - '.github/scripts/test-offline-tts.sh' | 16 | - '.github/scripts/test-offline-tts.sh' |
| 16 | - 'CMakeLists.txt' | 17 | - 'CMakeLists.txt' |
| 17 | - 'cmake/**' | 18 | - 'cmake/**' |
| @@ -27,6 +28,7 @@ on: | @@ -27,6 +28,7 @@ on: | ||
| 27 | - '.github/scripts/test-online-paraformer.sh' | 28 | - '.github/scripts/test-online-paraformer.sh' |
| 28 | - '.github/scripts/test-offline-transducer.sh' | 29 | - '.github/scripts/test-offline-transducer.sh' |
| 29 | - '.github/scripts/test-offline-ctc.sh' | 30 | - '.github/scripts/test-offline-ctc.sh' |
| 31 | + - '.github/scripts/test-online-ctc.sh' | ||
| 30 | - '.github/scripts/test-offline-tts.sh' | 32 | - '.github/scripts/test-offline-tts.sh' |
| 31 | - 'CMakeLists.txt' | 33 | - 'CMakeLists.txt' |
| 32 | - 'cmake/**' | 34 | - 'cmake/**' |
| @@ -89,6 +91,14 @@ jobs: | @@ -89,6 +91,14 @@ jobs: | ||
| 89 | file build/bin/sherpa-onnx | 91 | file build/bin/sherpa-onnx |
| 90 | readelf -d build/bin/sherpa-onnx | 92 | readelf -d build/bin/sherpa-onnx |
| 91 | 93 | ||
| 94 | + - name: Test online CTC | ||
| 95 | + shell: bash | ||
| 96 | + run: | | ||
| 97 | + export PATH=$PWD/build/bin:$PATH | ||
| 98 | + export EXE=sherpa-onnx | ||
| 99 | + | ||
| 100 | + .github/scripts/test-online-ctc.sh | ||
| 101 | + | ||
| 92 | - name: Test offline CTC | 102 | - name: Test offline CTC |
| 93 | shell: bash | 103 | shell: bash |
| 94 | run: | | 104 | run: | |
| @@ -13,6 +13,7 @@ on: | @@ -13,6 +13,7 @@ on: | ||
| 13 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | - '.github/scripts/test-offline-tts.sh' | 15 | - '.github/scripts/test-offline-tts.sh' |
| 16 | + - '.github/scripts/test-online-ctc.sh' | ||
| 16 | - 'CMakeLists.txt' | 17 | - 'CMakeLists.txt' |
| 17 | - 'cmake/**' | 18 | - 'cmake/**' |
| 18 | - 'sherpa-onnx/csrc/*' | 19 | - 'sherpa-onnx/csrc/*' |
| @@ -26,6 +27,7 @@ on: | @@ -26,6 +27,7 @@ on: | ||
| 26 | - '.github/scripts/test-offline-transducer.sh' | 27 | - '.github/scripts/test-offline-transducer.sh' |
| 27 | - '.github/scripts/test-offline-ctc.sh' | 28 | - '.github/scripts/test-offline-ctc.sh' |
| 28 | - '.github/scripts/test-offline-tts.sh' | 29 | - '.github/scripts/test-offline-tts.sh' |
| 30 | + - '.github/scripts/test-online-ctc.sh' | ||
| 29 | - 'CMakeLists.txt' | 31 | - 'CMakeLists.txt' |
| 30 | - 'cmake/**' | 32 | - 'cmake/**' |
| 31 | - 'sherpa-onnx/csrc/*' | 33 | - 'sherpa-onnx/csrc/*' |
| @@ -96,6 +98,15 @@ jobs: | @@ -96,6 +98,15 @@ jobs: | ||
| 96 | otool -L build/bin/sherpa-onnx | 98 | otool -L build/bin/sherpa-onnx |
| 97 | otool -l build/bin/sherpa-onnx | 99 | otool -l build/bin/sherpa-onnx |
| 98 | 100 | ||
| 101 | + - name: Test online CTC | ||
| 102 | + shell: bash | ||
| 103 | + run: | | ||
| 104 | + export PATH=$PWD/build/bin:$PATH | ||
| 105 | + export EXE=sherpa-onnx | ||
| 106 | + | ||
| 107 | + .github/scripts/test-online-ctc.sh | ||
| 108 | + | ||
| 109 | + | ||
| 99 | - name: Test offline TTS | 110 | - name: Test offline TTS |
| 100 | shell: bash | 111 | shell: bash |
| 101 | run: | | 112 | run: | |
| @@ -12,6 +12,7 @@ on: | @@ -12,6 +12,7 @@ on: | ||
| 12 | - '.github/scripts/test-online-paraformer.sh' | 12 | - '.github/scripts/test-online-paraformer.sh' |
| 13 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | + - '.github/scripts/test-online-ctc.sh' | ||
| 15 | - '.github/scripts/test-offline-tts.sh' | 16 | - '.github/scripts/test-offline-tts.sh' |
| 16 | - 'CMakeLists.txt' | 17 | - 'CMakeLists.txt' |
| 17 | - 'cmake/**' | 18 | - 'cmake/**' |
| @@ -25,6 +26,7 @@ on: | @@ -25,6 +26,7 @@ on: | ||
| 25 | - '.github/scripts/test-online-paraformer.sh' | 26 | - '.github/scripts/test-online-paraformer.sh' |
| 26 | - '.github/scripts/test-offline-transducer.sh' | 27 | - '.github/scripts/test-offline-transducer.sh' |
| 27 | - '.github/scripts/test-offline-ctc.sh' | 28 | - '.github/scripts/test-offline-ctc.sh' |
| 29 | + - '.github/scripts/test-online-ctc.sh' | ||
| 28 | - '.github/scripts/test-offline-tts.sh' | 30 | - '.github/scripts/test-offline-tts.sh' |
| 29 | - 'CMakeLists.txt' | 31 | - 'CMakeLists.txt' |
| 30 | - 'cmake/**' | 32 | - 'cmake/**' |
| @@ -66,6 +68,14 @@ jobs: | @@ -66,6 +68,14 @@ jobs: | ||
| 66 | 68 | ||
| 67 | ls -lh ./bin/Release/sherpa-onnx.exe | 69 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 68 | 70 | ||
| 71 | + - name: Test online CTC | ||
| 72 | + shell: bash | ||
| 73 | + run: | | ||
| 74 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 75 | + export EXE=sherpa-onnx.exe | ||
| 76 | + | ||
| 77 | + .github/scripts/test-online-ctc.sh | ||
| 78 | + | ||
| 69 | - name: Test offline TTS | 79 | - name: Test offline TTS |
| 70 | shell: bash | 80 | shell: bash |
| 71 | run: | | 81 | run: | |
| @@ -12,6 +12,7 @@ on: | @@ -12,6 +12,7 @@ on: | ||
| 12 | - '.github/scripts/test-online-paraformer.sh' | 12 | - '.github/scripts/test-online-paraformer.sh' |
| 13 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | + - '.github/scripts/test-online-ctc.sh' | ||
| 15 | - '.github/scripts/test-offline-tts.sh' | 16 | - '.github/scripts/test-offline-tts.sh' |
| 16 | - 'CMakeLists.txt' | 17 | - 'CMakeLists.txt' |
| 17 | - 'cmake/**' | 18 | - 'cmake/**' |
| @@ -25,6 +26,7 @@ on: | @@ -25,6 +26,7 @@ on: | ||
| 25 | - '.github/scripts/test-online-paraformer.sh' | 26 | - '.github/scripts/test-online-paraformer.sh' |
| 26 | - '.github/scripts/test-offline-transducer.sh' | 27 | - '.github/scripts/test-offline-transducer.sh' |
| 27 | - '.github/scripts/test-offline-ctc.sh' | 28 | - '.github/scripts/test-offline-ctc.sh' |
| 29 | + - '.github/scripts/test-online-ctc.sh' | ||
| 28 | - '.github/scripts/test-offline-tts.sh' | 30 | - '.github/scripts/test-offline-tts.sh' |
| 29 | - 'CMakeLists.txt' | 31 | - 'CMakeLists.txt' |
| 30 | - 'cmake/**' | 32 | - 'cmake/**' |
| @@ -67,6 +69,14 @@ jobs: | @@ -67,6 +69,14 @@ jobs: | ||
| 67 | 69 | ||
| 68 | ls -lh ./bin/Release/sherpa-onnx.exe | 70 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 69 | 71 | ||
| 72 | + - name: Test online CTC | ||
| 73 | + shell: bash | ||
| 74 | + run: | | ||
| 75 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 76 | + export EXE=sherpa-onnx.exe | ||
| 77 | + | ||
| 78 | + .github/scripts/test-online-ctc.sh | ||
| 79 | + | ||
| 70 | - name: Test offline TTS | 80 | - name: Test offline TTS |
| 71 | shell: bash | 81 | shell: bash |
| 72 | run: | | 82 | run: | |
| @@ -13,6 +13,7 @@ on: | @@ -13,6 +13,7 @@ on: | ||
| 13 | - '.github/scripts/test-offline-transducer.sh' | 13 | - '.github/scripts/test-offline-transducer.sh' |
| 14 | - '.github/scripts/test-offline-ctc.sh' | 14 | - '.github/scripts/test-offline-ctc.sh' |
| 15 | - '.github/scripts/test-offline-tts.sh' | 15 | - '.github/scripts/test-offline-tts.sh' |
| 16 | + - '.github/scripts/test-online-ctc.sh' | ||
| 16 | - 'CMakeLists.txt' | 17 | - 'CMakeLists.txt' |
| 17 | - 'cmake/**' | 18 | - 'cmake/**' |
| 18 | - 'sherpa-onnx/csrc/*' | 19 | - 'sherpa-onnx/csrc/*' |
| @@ -26,6 +27,7 @@ on: | @@ -26,6 +27,7 @@ on: | ||
| 26 | - '.github/scripts/test-offline-transducer.sh' | 27 | - '.github/scripts/test-offline-transducer.sh' |
| 27 | - '.github/scripts/test-offline-ctc.sh' | 28 | - '.github/scripts/test-offline-ctc.sh' |
| 28 | - '.github/scripts/test-offline-tts.sh' | 29 | - '.github/scripts/test-offline-tts.sh' |
| 30 | + - '.github/scripts/test-online-ctc.sh' | ||
| 29 | - 'CMakeLists.txt' | 31 | - 'CMakeLists.txt' |
| 30 | - 'cmake/**' | 32 | - 'cmake/**' |
| 31 | - 'sherpa-onnx/csrc/*' | 33 | - 'sherpa-onnx/csrc/*' |
| @@ -67,6 +69,14 @@ jobs: | @@ -67,6 +69,14 @@ jobs: | ||
| 67 | 69 | ||
| 68 | ls -lh ./bin/Release/sherpa-onnx.exe | 70 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 69 | 71 | ||
| 72 | + - name: Test online CTC | ||
| 73 | + shell: bash | ||
| 74 | + run: | | ||
| 75 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 76 | + export EXE=sherpa-onnx.exe | ||
| 77 | + | ||
| 78 | + .github/scripts/test-online-ctc.sh | ||
| 79 | + | ||
| 70 | - name: Test offline TTS | 80 | - name: Test offline TTS |
| 71 | shell: bash | 81 | shell: bash |
| 72 | run: | | 82 | run: | |
| @@ -164,6 +164,7 @@ def main(): | @@ -164,6 +164,7 @@ def main(): | ||
| 164 | dynamic_axes={ | 164 | dynamic_axes={ |
| 165 | "x": {0: "N", 1: "T"}, | 165 | "x": {0: "N", 1: "T"}, |
| 166 | "attn_cache": {2: "T"}, | 166 | "attn_cache": {2: "T"}, |
| 167 | + "attn_mask": {2: "T"}, | ||
| 167 | "log_probs": {0: "N"}, | 168 | "log_probs": {0: "N"}, |
| 168 | "new_attn_cache": {2: "T"}, | 169 | "new_attn_cache": {2: "T"}, |
| 169 | }, | 170 | }, |
| @@ -49,6 +49,8 @@ set(sources | @@ -49,6 +49,8 @@ set(sources | ||
| 49 | offline-zipformer-ctc-model-config.cc | 49 | offline-zipformer-ctc-model-config.cc |
| 50 | offline-zipformer-ctc-model.cc | 50 | offline-zipformer-ctc-model.cc |
| 51 | online-conformer-transducer-model.cc | 51 | online-conformer-transducer-model.cc |
| 52 | + online-ctc-greedy-search-decoder.cc | ||
| 53 | + online-ctc-model.cc | ||
| 52 | online-lm-config.cc | 54 | online-lm-config.cc |
| 53 | online-lm.cc | 55 | online-lm.cc |
| 54 | online-lstm-transducer-model.cc | 56 | online-lstm-transducer-model.cc |
| @@ -64,6 +66,8 @@ set(sources | @@ -64,6 +66,8 @@ set(sources | ||
| 64 | online-transducer-model-config.cc | 66 | online-transducer-model-config.cc |
| 65 | online-transducer-model.cc | 67 | online-transducer-model.cc |
| 66 | online-transducer-modified-beam-search-decoder.cc | 68 | online-transducer-modified-beam-search-decoder.cc |
| 69 | + online-wenet-ctc-model-config.cc | ||
| 70 | + online-wenet-ctc-model.cc | ||
| 67 | online-zipformer-transducer-model.cc | 71 | online-zipformer-transducer-model.cc |
| 68 | online-zipformer2-transducer-model.cc | 72 | online-zipformer2-transducer-model.cc |
| 69 | onnx-utils.cc | 73 | onnx-utils.cc |
sherpa-onnx/csrc/online-ctc-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-ctc-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OnlineCtcDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int64_t> tokens; | ||
| 17 | + | ||
| 18 | + /// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
| 19 | + /// Note: The index is after subsampling | ||
| 20 | + std::vector<int32_t> timestamps; | ||
| 21 | + | ||
| 22 | + int32_t num_trailing_blanks = 0; | ||
| 23 | +}; | ||
| 24 | + | ||
| 25 | +class OnlineCtcDecoder { | ||
| 26 | + public: | ||
| 27 | + virtual ~OnlineCtcDecoder() = default; | ||
| 28 | + | ||
| 29 | + /** Run streaming CTC decoding given the output from the encoder model. | ||
| 30 | + * | ||
| 31 | + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing | ||
| 32 | + * log_probs. | ||
| 33 | + * | ||
| 34 | + * @param results Input & output parameter. | ||
| 35 | + */ | ||
| 36 | + virtual void Decode(Ort::Value log_probs, | ||
| 37 | + std::vector<OnlineCtcDecoderResult> *results) = 0; | ||
| 38 | +}; | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx | ||
| 41 | + | ||
| 42 | +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +void OnlineCtcGreedySearchDecoder::Decode( | ||
| 16 | + Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) { | ||
| 17 | + std::vector<int64_t> log_probs_shape = | ||
| 18 | + log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 19 | + | ||
| 20 | + if (log_probs_shape[0] != results->size()) { | ||
| 21 | + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", | ||
| 22 | + static_cast<int32_t>(log_probs_shape[0]), | ||
| 23 | + static_cast<int32_t>(results->size())); | ||
| 24 | + exit(-1); | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]); | ||
| 28 | + int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]); | ||
| 29 | + int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]); | ||
| 30 | + | ||
| 31 | + const float *p = log_probs.GetTensorData<float>(); | ||
| 32 | + | ||
| 33 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 34 | + auto &r = (*results)[b]; | ||
| 35 | + | ||
| 36 | + int32_t prev_id = -1; | ||
| 37 | + | ||
| 38 | + for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) { | ||
| 39 | + int32_t y = static_cast<int32_t>(std::distance( | ||
| 40 | + static_cast<const float *>(p), | ||
| 41 | + std::max_element(static_cast<const float *>(p), | ||
| 42 | + static_cast<const float *>(p) + vocab_size))); | ||
| 43 | + | ||
| 44 | + if (y == blank_id_) { | ||
| 45 | + r.num_trailing_blanks += 1; | ||
| 46 | + } else { | ||
| 47 | + r.num_trailing_blanks = 0; | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + if (y != blank_id_ && y != prev_id) { | ||
| 51 | + r.tokens.push_back(y); | ||
| 52 | + r.timestamps.push_back(t); | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + prev_id = y; | ||
| 56 | + } // for (int32_t t = 0; t != num_frames; ++t) { | ||
| 57 | + } // for (int32_t b = 0; b != batch_size; ++b) | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/online-ctc-decoder.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { | ||
| 15 | + public: | ||
| 16 | + explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) | ||
| 17 | + : blank_id_(blank_id) {} | ||
| 18 | + | ||
| 19 | + void Decode(Ort::Value log_probs, | ||
| 20 | + std::vector<OnlineCtcDecoderResult> *results) override; | ||
| 21 | + | ||
| 22 | + private: | ||
| 23 | + int32_t blank_id_; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ |
sherpa-onnx/csrc/online-ctc-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <sstream> | ||
| 10 | +#include <string> | ||
| 11 | + | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/online-wenet-ctc-model.h" | ||
| 14 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 15 | + | ||
| 16 | +namespace { | ||
| 17 | + | ||
| 18 | +enum class ModelType { | ||
| 19 | + kZipformerCtc, | ||
| 20 | + kWenetCtc, | ||
| 21 | + kUnkown, | ||
| 22 | +}; | ||
| 23 | + | ||
| 24 | +} // namespace | ||
| 25 | + | ||
| 26 | +namespace sherpa_onnx { | ||
| 27 | + | ||
| 28 | +static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 29 | + bool debug) { | ||
| 30 | + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); | ||
| 31 | + Ort::SessionOptions sess_opts; | ||
| 32 | + | ||
| 33 | + auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length, | ||
| 34 | + sess_opts); | ||
| 35 | + | ||
| 36 | + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); | ||
| 37 | + if (debug) { | ||
| 38 | + std::ostringstream os; | ||
| 39 | + PrintModelMetadata(os, meta_data); | ||
| 40 | + SHERPA_ONNX_LOGE("%s", os.str().c_str()); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 44 | + auto model_type = | ||
| 45 | + meta_data.LookupCustomMetadataMapAllocated("model_type", allocator); | ||
| 46 | + if (!model_type) { | ||
| 47 | + SHERPA_ONNX_LOGE( | ||
| 48 | + "No model_type in the metadata!\n" | ||
| 49 | + "If you are using models from WeNet, please refer to\n" | ||
| 50 | + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wenet/" | ||
| 51 | + "run.sh\n" | ||
| 52 | + "\n" | ||
| 53 | + "for how to add metadta to model.onnx\n"); | ||
| 54 | + return ModelType::kUnkown; | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + if (model_type.get() == std::string("zipformer2")) { | ||
| 58 | + return ModelType::kZipformerCtc; | ||
| 59 | + } else if (model_type.get() == std::string("wenet_ctc")) { | ||
| 60 | + return ModelType::kWenetCtc; | ||
| 61 | + } else { | ||
| 62 | + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | ||
| 63 | + return ModelType::kUnkown; | ||
| 64 | + } | ||
| 65 | +} | ||
| 66 | + | ||
| 67 | +std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( | ||
| 68 | + const OnlineModelConfig &config) { | ||
| 69 | + ModelType model_type = ModelType::kUnkown; | ||
| 70 | + | ||
| 71 | + std::string filename; | ||
| 72 | + if (!config.wenet_ctc.model.empty()) { | ||
| 73 | + filename = config.wenet_ctc.model; | ||
| 74 | + } else { | ||
| 75 | + SHERPA_ONNX_LOGE("Please specify a CTC model"); | ||
| 76 | + exit(-1); | ||
| 77 | + } | ||
| 78 | + | ||
| 79 | + { | ||
| 80 | + auto buffer = ReadFile(filename); | ||
| 81 | + | ||
| 82 | + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 83 | + } | ||
| 84 | + | ||
| 85 | + switch (model_type) { | ||
| 86 | + case ModelType::kZipformerCtc: | ||
| 87 | + return nullptr; | ||
| 88 | + // return std::make_unique<OnlineZipformerCtcModel>(config); | ||
| 89 | + break; | ||
| 90 | + case ModelType::kWenetCtc: | ||
| 91 | + return std::make_unique<OnlineWenetCtcModel>(config); | ||
| 92 | + break; | ||
| 93 | + case ModelType::kUnkown: | ||
| 94 | + SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); | ||
| 95 | + return nullptr; | ||
| 96 | + } | ||
| 97 | + | ||
| 98 | + return nullptr; | ||
| 99 | +} | ||
| 100 | + | ||
| 101 | +#if __ANDROID_API__ >= 9 | ||
| 102 | + | ||
| 103 | +std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create( | ||
| 104 | + AAssetManager *mgr, const OnlineModelConfig &config) { | ||
| 105 | + ModelType model_type = ModelType::kUnkown; | ||
| 106 | + | ||
| 107 | + std::string filename; | ||
| 108 | + if (!config.wenet_ctc.model.empty()) { | ||
| 109 | + filename = config.wenet_ctc.model; | ||
| 110 | + } else { | ||
| 111 | + SHERPA_ONNX_LOGE("Please specify a CTC model"); | ||
| 112 | + exit(-1); | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + { | ||
| 116 | + auto buffer = ReadFile(mgr, filename); | ||
| 117 | + | ||
| 118 | + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); | ||
| 119 | + } | ||
| 120 | + | ||
| 121 | + switch (model_type) { | ||
| 122 | + case ModelType::kZipformerCtc: | ||
| 123 | + return nullptr; | ||
| 124 | + // return std::make_unique<OnlineZipformerCtcModel>(mgr, config); | ||
| 125 | + break; | ||
| 126 | + case ModelType::kWenetCtc: | ||
| 127 | + return std::make_unique<OnlineWenetCtcModel>(mgr, config); | ||
| 128 | + break; | ||
| 129 | + case ModelType::kUnkown: | ||
| 130 | + SHERPA_ONNX_LOGE("Unknown model type in online CTC!"); | ||
| 131 | + return nullptr; | ||
| 132 | + } | ||
| 133 | + | ||
| 134 | + return nullptr; | ||
| 135 | +} | ||
| 136 | +#endif | ||
| 137 | + | ||
| 138 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-ctc-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 17 | +#include "sherpa-onnx/csrc/online-model-config.h" | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +class OnlineCtcModel { | ||
| 22 | + public: | ||
| 23 | + virtual ~OnlineCtcModel() = default; | ||
| 24 | + | ||
| 25 | + static std::unique_ptr<OnlineCtcModel> Create( | ||
| 26 | + const OnlineModelConfig &config); | ||
| 27 | + | ||
| 28 | +#if __ANDROID_API__ >= 9 | ||
| 29 | + static std::unique_ptr<OnlineCtcModel> Create( | ||
| 30 | + AAssetManager *mgr, const OnlineModelConfig &config); | ||
| 31 | +#endif | ||
| 32 | + | ||
| 33 | + // Return a list of tensors containing the initial states | ||
| 34 | + virtual std::vector<Ort::Value> GetInitStates() const = 0; | ||
| 35 | + | ||
| 36 | + /** | ||
| 37 | + * | ||
| 38 | + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. | ||
| 39 | + * @param states It is from GetInitStates() or returned from this method. | ||
| 40 | + * | ||
| 41 | + * @return Return a list of tensors | ||
| 42 | + * - ans[0] contains log_probs, of shape (N, T, C) | ||
| 43 | + * - ans[1:] contains next_states | ||
| 44 | + */ | ||
| 45 | + virtual std::vector<Ort::Value> Forward( | ||
| 46 | + Ort::Value x, std::vector<Ort::Value> states) const = 0; | ||
| 47 | + | ||
| 48 | + /** Return the vocabulary size of the model | ||
| 49 | + */ | ||
| 50 | + virtual int32_t VocabSize() const = 0; | ||
| 51 | + | ||
| 52 | + /** Return an allocator for allocating memory | ||
| 53 | + */ | ||
| 54 | + virtual OrtAllocator *Allocator() const = 0; | ||
| 55 | + | ||
| 56 | + // The model accepts this number of frames before subsampling as input | ||
| 57 | + virtual int32_t ChunkLength() const = 0; | ||
| 58 | + | ||
| 59 | + // Similar to frame_shift in feature extractor, after processing | ||
| 60 | + // ChunkLength() frames, we advance by ChunkShift() frames | ||
| 61 | + // before we process the next chunk. | ||
| 62 | + virtual int32_t ChunkShift() const = 0; | ||
| 63 | +}; | ||
| 64 | + | ||
| 65 | +} // namespace sherpa_onnx | ||
| 66 | + | ||
| 67 | +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ |
| @@ -13,6 +13,7 @@ namespace sherpa_onnx { | @@ -13,6 +13,7 @@ namespace sherpa_onnx { | ||
| 13 | void OnlineModelConfig::Register(ParseOptions *po) { | 13 | void OnlineModelConfig::Register(ParseOptions *po) { |
| 14 | transducer.Register(po); | 14 | transducer.Register(po); |
| 15 | paraformer.Register(po); | 15 | paraformer.Register(po); |
| 16 | + wenet_ctc.Register(po); | ||
| 16 | 17 | ||
| 17 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 18 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 18 | 19 | ||
| @@ -46,6 +47,10 @@ bool OnlineModelConfig::Validate() const { | @@ -46,6 +47,10 @@ bool OnlineModelConfig::Validate() const { | ||
| 46 | return paraformer.Validate(); | 47 | return paraformer.Validate(); |
| 47 | } | 48 | } |
| 48 | 49 | ||
| 50 | + if (!wenet_ctc.model.empty()) { | ||
| 51 | + return wenet_ctc.Validate(); | ||
| 52 | + } | ||
| 53 | + | ||
| 49 | return transducer.Validate(); | 54 | return transducer.Validate(); |
| 50 | } | 55 | } |
| 51 | 56 | ||
| @@ -55,6 +60,7 @@ std::string OnlineModelConfig::ToString() const { | @@ -55,6 +60,7 @@ std::string OnlineModelConfig::ToString() const { | ||
| 55 | os << "OnlineModelConfig("; | 60 | os << "OnlineModelConfig("; |
| 56 | os << "transducer=" << transducer.ToString() << ", "; | 61 | os << "transducer=" << transducer.ToString() << ", "; |
| 57 | os << "paraformer=" << paraformer.ToString() << ", "; | 62 | os << "paraformer=" << paraformer.ToString() << ", "; |
| 63 | + os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; | ||
| 58 | os << "tokens=\"" << tokens << "\", "; | 64 | os << "tokens=\"" << tokens << "\", "; |
| 59 | os << "num_threads=" << num_threads << ", "; | 65 | os << "num_threads=" << num_threads << ", "; |
| 60 | os << "debug=" << (debug ? "True" : "False") << ", "; | 66 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| @@ -8,12 +8,14 @@ | @@ -8,12 +8,14 @@ | ||
| 8 | 8 | ||
| 9 | #include "sherpa-onnx/csrc/online-paraformer-model-config.h" | 9 | #include "sherpa-onnx/csrc/online-paraformer-model-config.h" |
| 10 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 10 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 11 | +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" | ||
| 11 | 12 | ||
| 12 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 13 | 14 | ||
| 14 | struct OnlineModelConfig { | 15 | struct OnlineModelConfig { |
| 15 | OnlineTransducerModelConfig transducer; | 16 | OnlineTransducerModelConfig transducer; |
| 16 | OnlineParaformerModelConfig paraformer; | 17 | OnlineParaformerModelConfig paraformer; |
| 18 | + OnlineWenetCtcModelConfig wenet_ctc; | ||
| 17 | std::string tokens; | 19 | std::string tokens; |
| 18 | int32_t num_threads = 1; | 20 | int32_t num_threads = 1; |
| 19 | bool debug = false; | 21 | bool debug = false; |
| @@ -31,10 +33,12 @@ struct OnlineModelConfig { | @@ -31,10 +33,12 @@ struct OnlineModelConfig { | ||
| 31 | OnlineModelConfig() = default; | 33 | OnlineModelConfig() = default; |
| 32 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, | 34 | OnlineModelConfig(const OnlineTransducerModelConfig &transducer, |
| 33 | const OnlineParaformerModelConfig ¶former, | 35 | const OnlineParaformerModelConfig ¶former, |
| 36 | + const OnlineWenetCtcModelConfig &wenet_ctc, | ||
| 34 | const std::string &tokens, int32_t num_threads, bool debug, | 37 | const std::string &tokens, int32_t num_threads, bool debug, |
| 35 | const std::string &provider, const std::string &model_type) | 38 | const std::string &provider, const std::string &model_type) |
| 36 | : transducer(transducer), | 39 | : transducer(transducer), |
| 37 | paraformer(paraformer), | 40 | paraformer(paraformer), |
| 41 | + wenet_ctc(wenet_ctc), | ||
| 38 | tokens(tokens), | 42 | tokens(tokens), |
| 39 | num_threads(num_threads), | 43 | num_threads(num_threads), |
| 40 | debug(debug), | 44 | debug(debug), |
| 1 | +// sherpa-onnx/csrc/online-recognizer-ctc-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 15 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 16 | +#include "sherpa-onnx/csrc/online-ctc-decoder.h" | ||
| 17 | +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-ctc-model.h" | ||
| 19 | +#include "sherpa-onnx/csrc/online-recognizer-impl.h" | ||
| 20 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 21 | + | ||
| 22 | +namespace sherpa_onnx { | ||
| 23 | + | ||
| 24 | +static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, | ||
| 25 | + const SymbolTable &sym_table, | ||
| 26 | + float frame_shift_ms, | ||
| 27 | + int32_t subsampling_factor, | ||
| 28 | + int32_t segment, | ||
| 29 | + int32_t frames_since_start) { | ||
| 30 | + OnlineRecognizerResult r; | ||
| 31 | + r.tokens.reserve(src.tokens.size()); | ||
| 32 | + r.timestamps.reserve(src.tokens.size()); | ||
| 33 | + | ||
| 34 | + for (auto i : src.tokens) { | ||
| 35 | + auto sym = sym_table[i]; | ||
| 36 | + | ||
| 37 | + r.text.append(sym); | ||
| 38 | + r.tokens.push_back(std::move(sym)); | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 42 | + for (auto t : src.timestamps) { | ||
| 43 | + float time = frame_shift_s * t; | ||
| 44 | + r.timestamps.push_back(time); | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + r.segment = segment; | ||
| 48 | + r.start_time = frames_since_start * frame_shift_ms / 1000.; | ||
| 49 | + | ||
| 50 | + return r; | ||
| 51 | +} | ||
| 52 | + | ||
| 53 | +class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 54 | + public: | ||
| 55 | + explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) | ||
| 56 | + : config_(config), | ||
| 57 | + model_(OnlineCtcModel::Create(config.model_config)), | ||
| 58 | + sym_(config.model_config.tokens), | ||
| 59 | + endpoint_(config_.endpoint_config) { | ||
| 60 | + if (!config.model_config.wenet_ctc.model.empty()) { | ||
| 61 | + // WeNet CTC models assume input samples are in the range | ||
| 62 | + // [-32768, 32767], so we set normalize_samples to false | ||
| 63 | + config_.feat_config.normalize_samples = false; | ||
| 64 | + } | ||
| 65 | + | ||
| 66 | + InitDecoder(); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | +#if __ANDROID_API__ >= 9 | ||
| 70 | + explicit OnlineRecognizerCtcImpl(AAssetManager *mgr, | ||
| 71 | + const OnlineRecognizerConfig &config) | ||
| 72 | + : config_(config), | ||
| 73 | + model_(OnlineCtcModel::Create(mgr, config.model_config)), | ||
| 74 | + sym_(mgr, config.model_config.tokens), | ||
| 75 | + endpoint_(config_.endpoint_config) { | ||
| 76 | + if (!config.model_config.wenet_ctc.model.empty()) { | ||
| 77 | + // WeNet CTC models assume input samples are in the range | ||
| 78 | + // [-32768, 32767], so we set normalize_samples to false | ||
| 79 | + config_.feat_config.normalize_samples = false; | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + InitDecoder(); | ||
| 83 | + } | ||
| 84 | +#endif | ||
| 85 | + | ||
| 86 | + std::unique_ptr<OnlineStream> CreateStream() const override { | ||
| 87 | + auto stream = std::make_unique<OnlineStream>(config_.feat_config); | ||
| 88 | + stream->SetStates(model_->GetInitStates()); | ||
| 89 | + | ||
| 90 | + return stream; | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + bool IsReady(OnlineStream *s) const override { | ||
| 94 | + return s->GetNumProcessedFrames() + model_->ChunkLength() < | ||
| 95 | + s->NumFramesReady(); | ||
| 96 | + } | ||
| 97 | + | ||
| 98 | + void DecodeStreams(OnlineStream **ss, int32_t n) const override { | ||
| 99 | + for (int32_t i = 0; i != n; ++i) { | ||
| 100 | + DecodeStream(ss[i]); | ||
| 101 | + } | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + OnlineRecognizerResult GetResult(OnlineStream *s) const override { | ||
| 105 | + OnlineCtcDecoderResult decoder_result = s->GetCtcResult(); | ||
| 106 | + | ||
| 107 | + // TODO(fangjun): Remember to change these constants if needed | ||
| 108 | + int32_t frame_shift_ms = 10; | ||
| 109 | + int32_t subsampling_factor = 4; | ||
| 110 | + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, | ||
| 111 | + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); | ||
| 112 | + } | ||
| 113 | + | ||
| 114 | + bool IsEndpoint(OnlineStream *s) const override { | ||
| 115 | + if (!config_.enable_endpoint) { | ||
| 116 | + return false; | ||
| 117 | + } | ||
| 118 | + | ||
| 119 | + int32_t num_processed_frames = s->GetNumProcessedFrames(); | ||
| 120 | + | ||
| 121 | + // frame shift is 10 milliseconds | ||
| 122 | + float frame_shift_in_seconds = 0.01; | ||
| 123 | + | ||
| 124 | + // subsampling factor is 4 | ||
| 125 | + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4; | ||
| 126 | + | ||
| 127 | + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, | ||
| 128 | + frame_shift_in_seconds); | ||
| 129 | + } | ||
| 130 | + | ||
| 131 | + void Reset(OnlineStream *s) const override { | ||
| 132 | + // segment is incremented only when the last | ||
| 133 | + // result is not empty | ||
| 134 | + const auto &r = s->GetCtcResult(); | ||
| 135 | + if (!r.tokens.empty()) { | ||
| 136 | + s->GetCurrentSegment() += 1; | ||
| 137 | + } | ||
| 138 | + | ||
| 139 | + // clear result | ||
| 140 | + s->SetCtcResult({}); | ||
| 141 | + | ||
| 142 | + // clear states | ||
| 143 | + s->SetStates(model_->GetInitStates()); | ||
| 144 | + | ||
| 145 | + // Note: We only update counters. The underlying audio samples | ||
| 146 | + // are not discarded. | ||
| 147 | + s->Reset(); | ||
| 148 | + } | ||
| 149 | + | ||
| 150 | + private: | ||
| 151 | + void InitDecoder() { | ||
| 152 | + if (config_.decoding_method == "greedy_search") { | ||
| 153 | + if (!sym_.contains("<blk>") && !sym_.contains("<eps>") && | ||
| 154 | + !sym_.contains("<blank>")) { | ||
| 155 | + SHERPA_ONNX_LOGE( | ||
| 156 | + "We expect that tokens.txt contains " | ||
| 157 | + "the symbol <blk> or <eps> or <blank> and its ID."); | ||
| 158 | + exit(-1); | ||
| 159 | + } | ||
| 160 | + | ||
| 161 | + int32_t blank_id = 0; | ||
| 162 | + if (sym_.contains("<blk>")) { | ||
| 163 | + blank_id = sym_["<blk>"]; | ||
| 164 | + } else if (sym_.contains("<eps>")) { | ||
| 165 | + // for tdnn models of the yesno recipe from icefall | ||
| 166 | + blank_id = sym_["<eps>"]; | ||
| 167 | + } else if (sym_.contains("<blank>")) { | ||
| 168 | + // for WeNet CTC models | ||
| 169 | + blank_id = sym_["<blank>"]; | ||
| 170 | + } | ||
| 171 | + | ||
| 172 | + decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id); | ||
| 173 | + } else { | ||
| 174 | + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 175 | + config_.decoding_method.c_str()); | ||
| 176 | + exit(-1); | ||
| 177 | + } | ||
| 178 | + } | ||
| 179 | + | ||
| 180 | + void DecodeStream(OnlineStream *s) const { | ||
| 181 | + int32_t chunk_length = model_->ChunkLength(); | ||
| 182 | + int32_t chunk_shift = model_->ChunkShift(); | ||
| 183 | + | ||
| 184 | + int32_t feat_dim = s->FeatureDim(); | ||
| 185 | + | ||
| 186 | + const auto num_processed_frames = s->GetNumProcessedFrames(); | ||
| 187 | + std::vector<float> frames = | ||
| 188 | + s->GetFrames(num_processed_frames, chunk_length); | ||
| 189 | + s->GetNumProcessedFrames() += chunk_shift; | ||
| 190 | + | ||
| 191 | + auto memory_info = | ||
| 192 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 193 | + | ||
| 194 | + std::array<int64_t, 3> x_shape{1, chunk_length, feat_dim}; | ||
| 195 | + Ort::Value x = | ||
| 196 | + Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(), | ||
| 197 | + x_shape.data(), x_shape.size()); | ||
| 198 | + auto out = model_->Forward(std::move(x), std::move(s->GetStates())); | ||
| 199 | + int32_t num_states = static_cast<int32_t>(out.size()) - 1; | ||
| 200 | + | ||
| 201 | + std::vector<Ort::Value> states; | ||
| 202 | + states.reserve(num_states); | ||
| 203 | + | ||
| 204 | + for (int32_t i = 0; i != num_states; ++i) { | ||
| 205 | + states.push_back(std::move(out[i + 1])); | ||
| 206 | + } | ||
| 207 | + s->SetStates(std::move(states)); | ||
| 208 | + | ||
| 209 | + std::vector<OnlineCtcDecoderResult> results(1); | ||
| 210 | + results[0] = std::move(s->GetCtcResult()); | ||
| 211 | + | ||
| 212 | + decoder_->Decode(std::move(out[0]), &results); | ||
| 213 | + s->SetCtcResult(results[0]); | ||
| 214 | + } | ||
| 215 | + | ||
| 216 | + private: | ||
| 217 | + OnlineRecognizerConfig config_; | ||
| 218 | + std::unique_ptr<OnlineCtcModel> model_; | ||
| 219 | + std::unique_ptr<OnlineCtcDecoder> decoder_; | ||
| 220 | + SymbolTable sym_; | ||
| 221 | + Endpoint endpoint_; | ||
| 222 | +}; | ||
| 223 | + | ||
| 224 | +} // namespace sherpa_onnx | ||
| 225 | + | ||
| 226 | +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ |
| @@ -4,6 +4,7 @@ | @@ -4,6 +4,7 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 5 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| 6 | 6 | ||
| 7 | +#include "sherpa-onnx/csrc/online-recognizer-ctc-impl.h" | ||
| 7 | #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" | 8 | #include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h" |
| 8 | #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" | 9 | #include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h" |
| 9 | 10 | ||
| @@ -19,6 +20,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -19,6 +20,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 19 | return std::make_unique<OnlineRecognizerParaformerImpl>(config); | 20 | return std::make_unique<OnlineRecognizerParaformerImpl>(config); |
| 20 | } | 21 | } |
| 21 | 22 | ||
| 23 | + if (!config.model_config.wenet_ctc.model.empty()) { | ||
| 24 | + return std::make_unique<OnlineRecognizerCtcImpl>(config); | ||
| 25 | + } | ||
| 26 | + | ||
| 22 | SHERPA_ONNX_LOGE("Please specify a model"); | 27 | SHERPA_ONNX_LOGE("Please specify a model"); |
| 23 | exit(-1); | 28 | exit(-1); |
| 24 | } | 29 | } |
| @@ -34,6 +39,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | @@ -34,6 +39,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create( | ||
| 34 | return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); | 39 | return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config); |
| 35 | } | 40 | } |
| 36 | 41 | ||
| 42 | + if (!config.model_config.wenet_ctc.model.empty()) { | ||
| 43 | + return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config); | ||
| 44 | + } | ||
| 45 | + | ||
| 37 | SHERPA_ONNX_LOGE("Please specify a model"); | 46 | SHERPA_ONNX_LOGE("Please specify a model"); |
| 38 | exit(-1); | 47 | exit(-1); |
| 39 | } | 48 | } |
| @@ -120,11 +120,7 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | @@ -120,11 +120,7 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { | ||
| 120 | model_(mgr, config.model_config), | 120 | model_(mgr, config.model_config), |
| 121 | sym_(mgr, config.model_config.tokens), | 121 | sym_(mgr, config.model_config.tokens), |
| 122 | endpoint_(config_.endpoint_config) { | 122 | endpoint_(config_.endpoint_config) { |
| 123 | - if (config.decoding_method == "greedy_search") { | ||
| 124 | - // add greedy search decoder | ||
| 125 | - // SHERPA_ONNX_LOGE("to be implemented"); | ||
| 126 | - // exit(-1); | ||
| 127 | - } else { | 123 | + if (config.decoding_method != "greedy_search") { |
| 128 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 124 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 129 | config.decoding_method.c_str()); | 125 | config.decoding_method.c_str()); |
| 130 | exit(-1); | 126 | exit(-1); |
| @@ -51,6 +51,10 @@ class OnlineStream::Impl { | @@ -51,6 +51,10 @@ class OnlineStream::Impl { | ||
| 51 | 51 | ||
| 52 | OnlineTransducerDecoderResult &GetResult() { return result_; } | 52 | OnlineTransducerDecoderResult &GetResult() { return result_; } |
| 53 | 53 | ||
| 54 | + OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; } | ||
| 55 | + | ||
| 56 | + void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; } | ||
| 57 | + | ||
| 54 | void SetParaformerResult(const OnlineParaformerDecoderResult &r) { | 58 | void SetParaformerResult(const OnlineParaformerDecoderResult &r) { |
| 55 | paraformer_result_ = r; | 59 | paraformer_result_ = r; |
| 56 | } | 60 | } |
| @@ -89,7 +93,8 @@ class OnlineStream::Impl { | @@ -89,7 +93,8 @@ class OnlineStream::Impl { | ||
| 89 | int32_t start_frame_index_ = 0; // never reset | 93 | int32_t start_frame_index_ = 0; // never reset |
| 90 | int32_t segment_ = 0; | 94 | int32_t segment_ = 0; |
| 91 | OnlineTransducerDecoderResult result_; | 95 | OnlineTransducerDecoderResult result_; |
| 92 | - std::vector<Ort::Value> states_; | 96 | + OnlineCtcDecoderResult ctc_result_; |
| 97 | + std::vector<Ort::Value> states_; // states for transducer or ctc models | ||
| 93 | std::vector<float> paraformer_feat_cache_; | 98 | std::vector<float> paraformer_feat_cache_; |
| 94 | std::vector<float> paraformer_encoder_out_cache_; | 99 | std::vector<float> paraformer_encoder_out_cache_; |
| 95 | std::vector<float> paraformer_alpha_cache_; | 100 | std::vector<float> paraformer_alpha_cache_; |
| @@ -144,6 +149,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { | @@ -144,6 +149,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() { | ||
| 144 | return impl_->GetResult(); | 149 | return impl_->GetResult(); |
| 145 | } | 150 | } |
| 146 | 151 | ||
| 152 | +OnlineCtcDecoderResult &OnlineStream::GetCtcResult() { | ||
| 153 | + return impl_->GetCtcResult(); | ||
| 154 | +} | ||
| 155 | + | ||
| 156 | +void OnlineStream::SetCtcResult(const OnlineCtcDecoderResult &r) { | ||
| 157 | + impl_->SetCtcResult(r); | ||
| 158 | +} | ||
| 159 | + | ||
| 147 | void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { | 160 | void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { |
| 148 | impl_->SetParaformerResult(r); | 161 | impl_->SetParaformerResult(r); |
| 149 | } | 162 | } |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/context-graph.h" | 12 | #include "sherpa-onnx/csrc/context-graph.h" |
| 13 | #include "sherpa-onnx/csrc/features.h" | 13 | #include "sherpa-onnx/csrc/features.h" |
| 14 | +#include "sherpa-onnx/csrc/online-ctc-decoder.h" | ||
| 14 | #include "sherpa-onnx/csrc/online-paraformer-decoder.h" | 15 | #include "sherpa-onnx/csrc/online-paraformer-decoder.h" |
| 15 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" | 16 | #include "sherpa-onnx/csrc/online-transducer-decoder.h" |
| 16 | 17 | ||
| @@ -75,6 +76,9 @@ class OnlineStream { | @@ -75,6 +76,9 @@ class OnlineStream { | ||
| 75 | void SetResult(const OnlineTransducerDecoderResult &r); | 76 | void SetResult(const OnlineTransducerDecoderResult &r); |
| 76 | OnlineTransducerDecoderResult &GetResult(); | 77 | OnlineTransducerDecoderResult &GetResult(); |
| 77 | 78 | ||
| 79 | + void SetCtcResult(const OnlineCtcDecoderResult &r); | ||
| 80 | + OnlineCtcDecoderResult &GetCtcResult(); | ||
| 81 | + | ||
| 78 | void SetParaformerResult(const OnlineParaformerDecoderResult &r); | 82 | void SetParaformerResult(const OnlineParaformerDecoderResult &r); |
| 79 | OnlineParaformerDecoderResult &GetParaformerResult(); | 83 | OnlineParaformerDecoderResult &GetParaformerResult(); |
| 80 | 84 |
| 1 | +// sherpa-onnx/csrc/online-wenet-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OnlineWenetCtcModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("wenet-ctc-model", &model, | ||
| 14 | + "Path to CTC model.onnx from WeNet. Please see " | ||
| 15 | + "https://github.com/k2-fsa/sherpa-onnx/pull/425"); | ||
| 16 | + po->Register("wenet-ctc-chunk-size", &chunk_size, | ||
| 17 | + "Chunk size after subsampling used for decoding."); | ||
| 18 | + po->Register("wenet-ctc-num-left-chunks", &num_left_chunks, | ||
| 19 | + "Number of left chunks after subsampling used for decoding."); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +bool OnlineWenetCtcModelConfig::Validate() const { | ||
| 23 | + if (!FileExists(model)) { | ||
| 24 | + SHERPA_ONNX_LOGE("WeNet CTC model %s does not exist", model.c_str()); | ||
| 25 | + return false; | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + if (chunk_size <= 0) { | ||
| 29 | + SHERPA_ONNX_LOGE( | ||
| 30 | + "Please specify a positive value for --wenet-ctc-chunk-size. Currently " | ||
| 31 | + "given: %d", | ||
| 32 | + chunk_size); | ||
| 33 | + return false; | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + if (num_left_chunks <= 0) { | ||
| 37 | + SHERPA_ONNX_LOGE( | ||
| 38 | + "Please specify a positive value for --wenet-ctc-num-left-chunks. " | ||
| 39 | + "Currently given: %d. Note that if you want to use -1, please consider " | ||
| 40 | + "using a non-streaming model.", | ||
| 41 | + num_left_chunks); | ||
| 42 | + return false; | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + return true; | ||
| 46 | +} | ||
| 47 | + | ||
| 48 | +std::string OnlineWenetCtcModelConfig::ToString() const { | ||
| 49 | + std::ostringstream os; | ||
| 50 | + | ||
| 51 | + os << "OnlineWenetCtcModelConfig("; | ||
| 52 | + os << "model=\"" << model << "\", "; | ||
| 53 | + os << "chunk_size=" << chunk_size << ", "; | ||
| 54 | + os << "num_left_chunks=" << num_left_chunks << ")"; | ||
| 55 | + | ||
| 56 | + return os.str(); | ||
| 57 | +} | ||
| 58 | + | ||
| 59 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-wenet-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OnlineWenetCtcModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + // --chunk_size from wenet | ||
| 17 | + int32_t chunk_size = 16; | ||
| 18 | + | ||
| 19 | + // --num_left_chunks from wenet | ||
| 20 | + int32_t num_left_chunks = 4; | ||
| 21 | + | ||
| 22 | + OnlineWenetCtcModelConfig() = default; | ||
| 23 | + | ||
| 24 | + OnlineWenetCtcModelConfig(const std::string &model, int32_t chunk_size, | ||
| 25 | + int32_t num_left_chunks) | ||
| 26 | + : model(model), | ||
| 27 | + chunk_size(chunk_size), | ||
| 28 | + num_left_chunks(num_left_chunks) {} | ||
| 29 | + | ||
| 30 | + void Register(ParseOptions *po); | ||
| 31 | + bool Validate() const; | ||
| 32 | + | ||
| 33 | + std::string ToString() const; | ||
| 34 | +}; | ||
| 35 | + | ||
| 36 | +} // namespace sherpa_onnx | ||
| 37 | + | ||
| 38 | +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/online-wenet-ctc-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-wenet-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-wenet-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <cmath> | ||
| 9 | +#include <string> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 17 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 18 | +#include "sherpa-onnx/csrc/session.h" | ||
| 19 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 20 | + | ||
| 21 | +namespace sherpa_onnx { | ||
| 22 | + | ||
| 23 | +class OnlineWenetCtcModel::Impl { | ||
| 24 | + public: | ||
| 25 | + explicit Impl(const OnlineModelConfig &config) | ||
| 26 | + : config_(config), | ||
| 27 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 28 | + sess_opts_(GetSessionOptions(config)), | ||
| 29 | + allocator_{} { | ||
| 30 | + { | ||
| 31 | + auto buf = ReadFile(config.wenet_ctc.model); | ||
| 32 | + Init(buf.data(), buf.size()); | ||
| 33 | + } | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | +#if __ANDROID_API__ >= 9 | ||
| 37 | + Impl(AAssetManager *mgr, const OnlineModelConfig &config) | ||
| 38 | + : config_(config), | ||
| 39 | + env_(ORT_LOGGING_LEVEL_WARNING), | ||
| 40 | + sess_opts_(GetSessionOptions(config)), | ||
| 41 | + allocator_{} { | ||
| 42 | + { | ||
| 43 | + auto buf = ReadFile(mgr, config.wenet_ctc.model); | ||
| 44 | + Init(buf.data(), buf.size()); | ||
| 45 | + } | ||
| 46 | + } | ||
| 47 | +#endif | ||
| 48 | + | ||
| 49 | + std::vector<Ort::Value> Forward(Ort::Value x, | ||
| 50 | + std::vector<Ort::Value> states) { | ||
| 51 | + Ort::Value &attn_cache = states[0]; | ||
| 52 | + Ort::Value &conv_cache = states[1]; | ||
| 53 | + Ort::Value &offset = states[2]; | ||
| 54 | + | ||
| 55 | + int32_t chunk_size = config_.wenet_ctc.chunk_size; | ||
| 56 | + int32_t left_chunks = config_.wenet_ctc.num_left_chunks; | ||
| 57 | + // build attn_mask | ||
| 58 | + std::array<int64_t, 3> attn_mask_shape{1, 1, | ||
| 59 | + required_cache_size_ + chunk_size}; | ||
| 60 | + Ort::Value attn_mask = Ort::Value::CreateTensor<bool>( | ||
| 61 | + allocator_, attn_mask_shape.data(), attn_mask_shape.size()); | ||
| 62 | + bool *p = attn_mask.GetTensorMutableData<bool>(); | ||
| 63 | + int32_t chunk_idx = | ||
| 64 | + offset.GetTensorData<int64_t>()[0] / chunk_size - left_chunks; | ||
| 65 | + if (chunk_idx < left_chunks) { | ||
| 66 | + std::fill(p, p + required_cache_size_ - chunk_idx * chunk_size, 0); | ||
| 67 | + std::fill(p + required_cache_size_ - chunk_idx * chunk_size, | ||
| 68 | + p + attn_mask_shape[2], 1); | ||
| 69 | + } else { | ||
| 70 | + std::fill(p, p + attn_mask_shape[2], 1); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + std::array<Ort::Value, 6> inputs = {std::move(x), | ||
| 74 | + View(&offset), | ||
| 75 | + View(&required_cache_size_tensor_), | ||
| 76 | + std::move(attn_cache), | ||
| 77 | + std::move(conv_cache), | ||
| 78 | + std::move(attn_mask)}; | ||
| 79 | + | ||
| 80 | + auto out = | ||
| 81 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 82 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 83 | + | ||
| 84 | + offset.GetTensorMutableData<int64_t>()[0] += | ||
| 85 | + out[0].GetTensorTypeAndShapeInfo().GetShape()[1]; | ||
| 86 | + out.push_back(std::move(offset)); | ||
| 87 | + | ||
| 88 | + return out; | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 92 | + | ||
| 93 | + int32_t ChunkLength() const { | ||
| 94 | + // When chunk_size is 16, subsampling_factor_ is 4, right_context_ is 6, | ||
| 95 | + // the returned value is (16 - 1)*4 + 6 + 1 = 67 | ||
| 96 | + return (config_.wenet_ctc.chunk_size - 1) * subsampling_factor_ + | ||
| 97 | + right_context_ + 1; | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + int32_t ChunkShift() const { return required_cache_size_; } | ||
| 101 | + | ||
| 102 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 103 | + | ||
| 104 | + // Return a vector containing 3 tensors | ||
| 105 | + // - attn_cache | ||
| 106 | + // - conv_cache | ||
| 107 | + // - offset | ||
| 108 | + std::vector<Ort::Value> GetInitStates() const { | ||
| 109 | + std::vector<Ort::Value> ans; | ||
| 110 | + ans.reserve(3); | ||
| 111 | + ans.push_back(Clone(Allocator(), &attn_cache_)); | ||
| 112 | + ans.push_back(Clone(Allocator(), &conv_cache_)); | ||
| 113 | + | ||
| 114 | + int64_t offset_shape = 1; | ||
| 115 | + | ||
| 116 | + Ort::Value offset = | ||
| 117 | + Ort::Value::CreateTensor<int64_t>(allocator_, &offset_shape, 1); | ||
| 118 | + | ||
| 119 | + offset.GetTensorMutableData<int64_t>()[0] = required_cache_size_; | ||
| 120 | + | ||
| 121 | + ans.push_back(std::move(offset)); | ||
| 122 | + | ||
| 123 | + return ans; | ||
| 124 | + } | ||
| 125 | + | ||
| 126 | + private: | ||
| 127 | + void Init(void *model_data, size_t model_data_length) { | ||
| 128 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 129 | + sess_opts_); | ||
| 130 | + | ||
| 131 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 132 | + | ||
| 133 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 134 | + | ||
| 135 | + // get meta data | ||
| 136 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 137 | + if (config_.debug) { | ||
| 138 | + std::ostringstream os; | ||
| 139 | + PrintModelMetadata(os, meta_data); | ||
| 140 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 144 | + SHERPA_ONNX_READ_META_DATA(head_, "head"); | ||
| 145 | + SHERPA_ONNX_READ_META_DATA(num_blocks_, "num_blocks"); | ||
| 146 | + SHERPA_ONNX_READ_META_DATA(output_size_, "output_size"); | ||
| 147 | + SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel"); | ||
| 148 | + SHERPA_ONNX_READ_META_DATA(right_context_, "right_context"); | ||
| 149 | + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); | ||
| 150 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 151 | + | ||
| 152 | + required_cache_size_ = | ||
| 153 | + config_.wenet_ctc.chunk_size * config_.wenet_ctc.num_left_chunks; | ||
| 154 | + | ||
| 155 | + InitStates(); | ||
| 156 | + } | ||
| 157 | + | ||
| 158 | + void InitStates() { | ||
| 159 | + std::array<int64_t, 4> attn_cache_shape{ | ||
| 160 | + num_blocks_, head_, required_cache_size_, output_size_ / head_ * 2}; | ||
| 161 | + attn_cache_ = Ort::Value::CreateTensor<float>( | ||
| 162 | + allocator_, attn_cache_shape.data(), attn_cache_shape.size()); | ||
| 163 | + | ||
| 164 | + Fill<float>(&attn_cache_, 0); | ||
| 165 | + | ||
| 166 | + std::array<int64_t, 4> conv_cache_shape{num_blocks_, 1, output_size_, | ||
| 167 | + cnn_module_kernel_ - 1}; | ||
| 168 | + conv_cache_ = Ort::Value::CreateTensor<float>( | ||
| 169 | + allocator_, conv_cache_shape.data(), conv_cache_shape.size()); | ||
| 170 | + | ||
| 171 | + Fill<float>(&conv_cache_, 0); | ||
| 172 | + | ||
| 173 | + int64_t shape = 1; | ||
| 174 | + required_cache_size_tensor_ = | ||
| 175 | + Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1); | ||
| 176 | + | ||
| 177 | + required_cache_size_tensor_.GetTensorMutableData<int64_t>()[0] = | ||
| 178 | + required_cache_size_; | ||
| 179 | + } | ||
| 180 | + | ||
| 181 | + private: | ||
| 182 | + OnlineModelConfig config_; | ||
| 183 | + Ort::Env env_; | ||
| 184 | + Ort::SessionOptions sess_opts_; | ||
| 185 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 186 | + | ||
| 187 | + std::unique_ptr<Ort::Session> sess_; | ||
| 188 | + | ||
| 189 | + std::vector<std::string> input_names_; | ||
| 190 | + std::vector<const char *> input_names_ptr_; | ||
| 191 | + | ||
| 192 | + std::vector<std::string> output_names_; | ||
| 193 | + std::vector<const char *> output_names_ptr_; | ||
| 194 | + | ||
| 195 | + int32_t head_; | ||
| 196 | + int32_t num_blocks_; | ||
| 197 | + int32_t output_size_; | ||
| 198 | + int32_t cnn_module_kernel_; | ||
| 199 | + int32_t right_context_; | ||
| 200 | + int32_t subsampling_factor_; | ||
| 201 | + int32_t vocab_size_; | ||
| 202 | + | ||
| 203 | + int32_t required_cache_size_; | ||
| 204 | + | ||
| 205 | + Ort::Value attn_cache_{nullptr}; | ||
| 206 | + Ort::Value conv_cache_{nullptr}; | ||
| 207 | + Ort::Value required_cache_size_tensor_{nullptr}; | ||
| 208 | +}; | ||
| 209 | + | ||
| 210 | +OnlineWenetCtcModel::OnlineWenetCtcModel(const OnlineModelConfig &config) | ||
| 211 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 212 | + | ||
| 213 | +#if __ANDROID_API__ >= 9 | ||
| 214 | +OnlineWenetCtcModel::OnlineWenetCtcModel(AAssetManager *mgr, | ||
| 215 | + const OnlineModelConfig &config) | ||
| 216 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 217 | +#endif | ||
| 218 | + | ||
| 219 | +OnlineWenetCtcModel::~OnlineWenetCtcModel() = default; | ||
| 220 | + | ||
| 221 | +std::vector<Ort::Value> OnlineWenetCtcModel::Forward( | ||
| 222 | + Ort::Value x, std::vector<Ort::Value> states) const { | ||
| 223 | + return impl_->Forward(std::move(x), std::move(states)); | ||
| 224 | +} | ||
| 225 | + | ||
| 226 | +int32_t OnlineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 227 | + | ||
| 228 | +int32_t OnlineWenetCtcModel::ChunkLength() const { | ||
| 229 | + return impl_->ChunkLength(); | ||
| 230 | +} | ||
| 231 | + | ||
| 232 | +int32_t OnlineWenetCtcModel::ChunkShift() const { return impl_->ChunkShift(); } | ||
| 233 | + | ||
| 234 | +OrtAllocator *OnlineWenetCtcModel::Allocator() const { | ||
| 235 | + return impl_->Allocator(); | ||
| 236 | +} | ||
| 237 | + | ||
| 238 | +std::vector<Ort::Value> OnlineWenetCtcModel::GetInitStates() const { | ||
| 239 | + return impl_->GetInitStates(); | ||
| 240 | +} | ||
| 241 | + | ||
| 242 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-wenet-ctc-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-wenet-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 17 | +#include "sherpa-onnx/csrc/online-ctc-model.h" | ||
| 18 | +#include "sherpa-onnx/csrc/online-model-config.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +class OnlineWenetCtcModel : public OnlineCtcModel { | ||
| 23 | + public: | ||
| 24 | + explicit OnlineWenetCtcModel(const OnlineModelConfig &config); | ||
| 25 | + | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + OnlineWenetCtcModel(AAssetManager *mgr, const OnlineModelConfig &config); | ||
| 28 | +#endif | ||
| 29 | + | ||
| 30 | + ~OnlineWenetCtcModel() override; | ||
| 31 | + | ||
| 32 | + // A list of 3 tensors: | ||
| 33 | + // - attn_cache | ||
| 34 | + // - conv_cache | ||
| 35 | + // - offset | ||
| 36 | + std::vector<Ort::Value> GetInitStates() const override; | ||
| 37 | + | ||
| 38 | + /** | ||
| 39 | + * | ||
| 40 | + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. | ||
| 41 | + * @param states It is from GetInitStates() or returned from this method. | ||
| 42 | + * | ||
| 43 | + * @return Return a list of tensors | ||
| 44 | + * - ans[0] contains log_probs, of shape (N, T, C) | ||
| 45 | + * - ans[1:] contains next_states | ||
| 46 | + */ | ||
| 47 | + std::vector<Ort::Value> Forward( | ||
| 48 | + Ort::Value x, std::vector<Ort::Value> states) const override; | ||
| 49 | + | ||
| 50 | + /** Return the vocabulary size of the model | ||
| 51 | + */ | ||
| 52 | + int32_t VocabSize() const override; | ||
| 53 | + | ||
| 54 | + /** Return an allocator for allocating memory | ||
| 55 | + */ | ||
| 56 | + OrtAllocator *Allocator() const override; | ||
| 57 | + | ||
| 58 | + // The model accepts this number of frames before subsampling as input | ||
| 59 | + int32_t ChunkLength() const override; | ||
| 60 | + | ||
| 61 | + // Similar to frame_shift in feature extractor, after processing | ||
| 62 | + // ChunkLength() frames, we advance by ChunkShift() frames | ||
| 63 | + // before we process the next chunk. | ||
| 64 | + int32_t ChunkShift() const override; | ||
| 65 | + | ||
| 66 | + private: | ||
| 67 | + class Impl; | ||
| 68 | + std::unique_ptr<Impl> impl_; | ||
| 69 | +}; | ||
| 70 | + | ||
| 71 | +} // namespace sherpa_onnx | ||
| 72 | + | ||
| 73 | +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ |
| @@ -125,6 +125,34 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { | @@ -125,6 +125,34 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { | ||
| 125 | } | 125 | } |
| 126 | } | 126 | } |
| 127 | 127 | ||
| 128 | +Ort::Value View(Ort::Value *v) { | ||
| 129 | + auto type_and_shape = v->GetTensorTypeAndShapeInfo(); | ||
| 130 | + std::vector<int64_t> shape = type_and_shape.GetShape(); | ||
| 131 | + | ||
| 132 | + auto memory_info = | ||
| 133 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 134 | + switch (type_and_shape.GetElementType()) { | ||
| 135 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: | ||
| 136 | + return Ort::Value::CreateTensor( | ||
| 137 | + memory_info, v->GetTensorMutableData<int32_t>(), | ||
| 138 | + type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 139 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: | ||
| 140 | + return Ort::Value::CreateTensor( | ||
| 141 | + memory_info, v->GetTensorMutableData<int64_t>(), | ||
| 142 | + type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 143 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: | ||
| 144 | + return Ort::Value::CreateTensor( | ||
| 145 | + memory_info, v->GetTensorMutableData<float>(), | ||
| 146 | + type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 147 | + default: | ||
| 148 | + fprintf(stderr, "Unsupported type: %d\n", | ||
| 149 | + static_cast<int32_t>(type_and_shape.GetElementType())); | ||
| 150 | + exit(-1); | ||
| 151 | + // unreachable code | ||
| 152 | + return Ort::Value{nullptr}; | ||
| 153 | + } | ||
| 154 | +} | ||
| 155 | + | ||
| 128 | void Print1D(Ort::Value *v) { | 156 | void Print1D(Ort::Value *v) { |
| 129 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 157 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 130 | const float *d = v->GetTensorData<float>(); | 158 | const float *d = v->GetTensorData<float>(); |
| @@ -65,6 +65,9 @@ void PrintModelMetadata(std::ostream &os, | @@ -65,6 +65,9 @@ void PrintModelMetadata(std::ostream &os, | ||
| 65 | // Return a deep copy of v | 65 | // Return a deep copy of v |
| 66 | Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | 66 | Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); |
| 67 | 67 | ||
| 68 | +// Return a shallow copy of v: the result wraps v's data buffer (no data is copied), so v must outlive it | ||
| 69 | +Ort::Value View(Ort::Value *v); | ||
| 70 | + | ||
| 68 | // Print a 1-D tensor to stderr | 71 | // Print a 1-D tensor to stderr |
| 69 | void Print1D(Ort::Value *v); | 72 | void Print1D(Ort::Value *v); |
| 70 | 73 |
| @@ -26,6 +26,7 @@ pybind11_add_module(_sherpa_onnx | @@ -26,6 +26,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 26 | online-recognizer.cc | 26 | online-recognizer.cc |
| 27 | online-stream.cc | 27 | online-stream.cc |
| 28 | online-transducer-model-config.cc | 28 | online-transducer-model-config.cc |
| 29 | + online-wenet-ctc-model-config.cc | ||
| 29 | sherpa-onnx.cc | 30 | sherpa-onnx.cc |
| 30 | silero-vad-model-config.cc | 31 | silero-vad-model-config.cc |
| 31 | vad-model-config.cc | 32 | vad-model-config.cc |
| @@ -11,24 +11,29 @@ | @@ -11,24 +11,29 @@ | ||
| 11 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 11 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" | 12 | #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" |
| 14 | +#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" | ||
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| 17 | void PybindOnlineModelConfig(py::module *m) { | 18 | void PybindOnlineModelConfig(py::module *m) { |
| 18 | PybindOnlineTransducerModelConfig(m); | 19 | PybindOnlineTransducerModelConfig(m); |
| 19 | PybindOnlineParaformerModelConfig(m); | 20 | PybindOnlineParaformerModelConfig(m); |
| 21 | + PybindOnlineWenetCtcModelConfig(m); | ||
| 20 | 22 | ||
| 21 | using PyClass = OnlineModelConfig; | 23 | using PyClass = OnlineModelConfig; |
| 22 | py::class_<PyClass>(*m, "OnlineModelConfig") | 24 | py::class_<PyClass>(*m, "OnlineModelConfig") |
| 23 | .def(py::init<const OnlineTransducerModelConfig &, | 25 | .def(py::init<const OnlineTransducerModelConfig &, |
| 24 | - const OnlineParaformerModelConfig &, const std::string &, | 26 | + const OnlineParaformerModelConfig &, |
| 27 | + const OnlineWenetCtcModelConfig &, const std::string &, | ||
| 25 | int32_t, bool, const std::string &, const std::string &>(), | 28 | int32_t, bool, const std::string &, const std::string &>(), |
| 26 | py::arg("transducer") = OnlineTransducerModelConfig(), | 29 | py::arg("transducer") = OnlineTransducerModelConfig(), |
| 27 | py::arg("paraformer") = OnlineParaformerModelConfig(), | 30 | py::arg("paraformer") = OnlineParaformerModelConfig(), |
| 31 | + py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), | ||
| 28 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | 32 | py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, |
| 29 | py::arg("provider") = "cpu", py::arg("model_type") = "") | 33 | py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 30 | .def_readwrite("transducer", &PyClass::transducer) | 34 | .def_readwrite("transducer", &PyClass::transducer) |
| 31 | .def_readwrite("paraformer", &PyClass::paraformer) | 35 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 36 | + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) | ||
| 32 | .def_readwrite("tokens", &PyClass::tokens) | 37 | .def_readwrite("tokens", &PyClass::tokens) |
| 33 | .def_readwrite("num_threads", &PyClass::num_threads) | 38 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 34 | .def_readwrite("debug", &PyClass::debug) | 39 | .def_readwrite("debug", &PyClass::debug) |
| 1 | +// sherpa-onnx/python/csrc/online-wenet-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void PybindOnlineWenetCtcModelConfig(py::module *m) { | ||
| 15 | + using PyClass = OnlineWenetCtcModelConfig; | ||
| 16 | + py::class_<PyClass>(*m, "OnlineWenetCtcModelConfig") | ||
| 17 | + .def(py::init<const std::string &, int32_t, int32_t>(), py::arg("model"), | ||
| 18 | + py::arg("chunk_size") = 16, py::arg("num_left_chunks") = 4) | ||
| 19 | + .def_readwrite("model", &PyClass::model) | ||
| 20 | + .def_readwrite("chunk_size", &PyClass::chunk_size) | ||
| 21 | + .def_readwrite("num_left_chunks", &PyClass::num_left_chunks) | ||
| 22 | + .def("__str__", &PyClass::ToString); | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOnlineWenetCtcModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ |
-
请 注册 或 登录 后发表评论