Committed by
GitHub
Add HLG decoding for streaming CTC models (#731)
正在显示
28 个修改的文件
包含
668 行增加
和
82 行删除
| 1 | #!/usr/bin/env bash | 1 | #!/usr/bin/env bash |
| 2 | 2 | ||
| 3 | -set -e | 3 | +set -ex |
| 4 | 4 | ||
| 5 | log() { | 5 | log() { |
| 6 | # This function is from espnet | 6 | # This function is from espnet |
| @@ -14,6 +14,26 @@ echo "PATH: $PATH" | @@ -14,6 +14,26 @@ echo "PATH: $PATH" | ||
| 14 | which $EXE | 14 | which $EXE |
| 15 | 15 | ||
| 16 | log "------------------------------------------------------------" | 16 | log "------------------------------------------------------------" |
| 17 | +log "Run streaming Zipformer2 CTC HLG decoding " | ||
| 18 | +log "------------------------------------------------------------" | ||
| 19 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 20 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 21 | +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 22 | +repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 | ||
| 23 | +ls -lh $repo | ||
| 24 | +echo "pwd: $PWD" | ||
| 25 | + | ||
| 26 | +$EXE \ | ||
| 27 | + --zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ | ||
| 28 | + --ctc-graph=$repo/HLG.fst \ | ||
| 29 | + --tokens=$repo/tokens.txt \ | ||
| 30 | + $repo/test_wavs/0.wav \ | ||
| 31 | + $repo/test_wavs/1.wav \ | ||
| 32 | + $repo/test_wavs/8k.wav | ||
| 33 | + | ||
| 34 | +rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 | ||
| 35 | + | ||
| 36 | +log "------------------------------------------------------------" | ||
| 17 | log "Run streaming Zipformer2 CTC " | 37 | log "Run streaming Zipformer2 CTC " |
| 18 | log "------------------------------------------------------------" | 38 | log "------------------------------------------------------------" |
| 19 | 39 |
| 1 | #!/usr/bin/env bash | 1 | #!/usr/bin/env bash |
| 2 | 2 | ||
| 3 | -set -e | 3 | +set -ex |
| 4 | 4 | ||
| 5 | log() { | 5 | log() { |
| 6 | # This function is from espnet | 6 | # This function is from espnet |
| @@ -8,6 +8,23 @@ log() { | @@ -8,6 +8,23 @@ log() { | ||
| 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | 8 | echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" |
| 9 | } | 9 | } |
| 10 | 10 | ||
| 11 | +log "test streaming zipformer2 ctc HLG decoding" | ||
| 12 | + | ||
| 13 | +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 14 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 15 | +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 16 | +repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 | ||
| 17 | + | ||
| 18 | +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \ | ||
| 19 | + --debug 1 \ | ||
| 20 | + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \ | ||
| 21 | + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \ | ||
| 22 | + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ | ||
| 23 | + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav | ||
| 24 | + | ||
| 25 | +rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 | ||
| 26 | + | ||
| 27 | + | ||
| 11 | mkdir -p /tmp/icefall-models | 28 | mkdir -p /tmp/icefall-models |
| 12 | dir=/tmp/icefall-models | 29 | dir=/tmp/icefall-models |
| 13 | 30 |
| @@ -124,6 +124,14 @@ jobs: | @@ -124,6 +124,14 @@ jobs: | ||
| 124 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} | 124 | name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} |
| 125 | path: build/bin/* | 125 | path: build/bin/* |
| 126 | 126 | ||
| 127 | + - name: Test online CTC | ||
| 128 | + shell: bash | ||
| 129 | + run: | | ||
| 130 | + export PATH=$PWD/build/bin:$PATH | ||
| 131 | + export EXE=sherpa-onnx | ||
| 132 | + | ||
| 133 | + .github/scripts/test-online-ctc.sh | ||
| 134 | + | ||
| 127 | - name: Test C API | 135 | - name: Test C API |
| 128 | shell: bash | 136 | shell: bash |
| 129 | run: | | 137 | run: | |
| @@ -149,13 +157,6 @@ jobs: | @@ -149,13 +157,6 @@ jobs: | ||
| 149 | 157 | ||
| 150 | .github/scripts/test-kws.sh | 158 | .github/scripts/test-kws.sh |
| 151 | 159 | ||
| 152 | - - name: Test online CTC | ||
| 153 | - shell: bash | ||
| 154 | - run: | | ||
| 155 | - export PATH=$PWD/build/bin:$PATH | ||
| 156 | - export EXE=sherpa-onnx | ||
| 157 | - | ||
| 158 | - .github/scripts/test-online-ctc.sh | ||
| 159 | 160 | ||
| 160 | - name: Test offline Whisper | 161 | - name: Test offline Whisper |
| 161 | if: matrix.build_type != 'Debug' | 162 | if: matrix.build_type != 'Debug' |
| 1 | function(download_kaldi_decoder) | 1 | function(download_kaldi_decoder) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz") | ||
| 5 | - set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz") | ||
| 6 | - set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601") | 4 | + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz") |
| 5 | + set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz") | ||
| 6 | + set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff") | ||
| 7 | 7 | ||
| 8 | set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| 9 | set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE) | 9 | set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE) |
| @@ -12,11 +12,11 @@ function(download_kaldi_decoder) | @@ -12,11 +12,11 @@ function(download_kaldi_decoder) | ||
| 12 | # If you don't have access to the Internet, | 12 | # If you don't have access to the Internet, |
| 13 | # please pre-download kaldi-decoder | 13 | # please pre-download kaldi-decoder |
| 14 | set(possible_file_locations | 14 | set(possible_file_locations |
| 15 | - $ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz | ||
| 16 | - ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz | ||
| 17 | - ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz | ||
| 18 | - /tmp/kaldi-decoder-0.2.4.tar.gz | ||
| 19 | - /star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz | 15 | + $ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz |
| 16 | + ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz | ||
| 17 | + ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz | ||
| 18 | + /tmp/kaldi-decoder-0.2.5.tar.gz | ||
| 19 | + /star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz | ||
| 20 | ) | 20 | ) |
| 21 | 21 | ||
| 22 | foreach(f IN LISTS possible_file_locations) | 22 | foreach(f IN LISTS possible_file_locations) |
| 1 | +#!/usr/bin/env python3 | ||
| 2 | + | ||
| 3 | +# This file shows how to use a streaming zipformer CTC model and an HLG | ||
| 4 | +# graph for decoding. | ||
| 5 | +# | ||
| 6 | +# We use the following model as an example | ||
| 7 | +# | ||
| 8 | +""" | ||
| 9 | +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 10 | +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 11 | +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 | ||
| 12 | + | ||
| 13 | +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \ | ||
| 14 | + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \ | ||
| 15 | + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \ | ||
| 16 | + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ | ||
| 17 | + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav | ||
| 18 | + | ||
| 19 | +""" | ||
| 20 | +# (The above model is from https://github.com/k2-fsa/icefall/pull/1557) | ||
| 21 | + | ||
| 22 | +import argparse | ||
| 23 | +import time | ||
| 24 | +import wave | ||
| 25 | +from pathlib import Path | ||
| 26 | +from typing import List, Tuple | ||
| 27 | + | ||
| 28 | +import numpy as np | ||
| 29 | +import sherpa_onnx | ||
| 30 | + | ||
| 31 | + | ||
def get_args():
    """Parse the command-line arguments of this example script.

    Returns:
      An argparse.Namespace with the attributes tokens, model, graph,
      num_threads, provider, debug, and sound_file.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to the ONNX model",
    )

    parser.add_argument(
        "--graph",
        type=str,
        required=True,
        help="Path to H.fst, HL.fst, or HLG.fst",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=1,
        help="Number of threads for neural network computation",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )

    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        help="Valid values: 1, 0",
    )

    parser.add_argument(
        "sound_file",
        type=str,
        # Adjacent string literals are joined verbatim, so every piece has to
        # carry its own trailing space.
        help="The input sound file to decode. It must be of WAVE "
        "format with a single channel, and each sample has 16-bit, "
        "i.e., int16_t. "
        "The sample rate of the file can be arbitrary and does not need to "
        "be 16 kHz",
    )

    return parser.parse_args()
| 90 | + | ||
| 91 | + | ||
def assert_file_exists(filename: str):
    """Check that a path refers to an existing regular file.

    Args:
      filename:
        Path of the file to check.
    Raises:
      FileNotFoundError:
        If filename is not an existing regular file. An explicit raise is
        used instead of an assert statement so that the check is still
        performed when Python runs with -O.
    """
    if not Path(filename).is_file():
        raise FileNotFoundError(
            f"{filename} does not exist!\n"
            "Please refer to "
            "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
        )
| 98 | + | ||
| 99 | + | ||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read a single-channel, 16-bit wave file.

    Args:
      wave_filename:
        Path to a wave file. It should be single channel and each sample should
        be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 containing the samples, which are
         normalized to the range [-1, 1].
       - sample rate of the wave file
    Raises:
      ValueError:
        If the file has more than one channel or its sample width is not
        2 bytes.
    """
    with wave.open(wave_filename) as f:
        # Explicit checks (not assert) so they are not skipped under -O.
        if f.getnchannels() != 1:
            raise ValueError(
                f"Expected a single channel. Given: {f.getnchannels()}"
            )
        if f.getsampwidth() != 2:  # it is in bytes
            raise ValueError(
                f"Expected 16-bit samples. Given sample width in bytes: "
                f"{f.getsampwidth()}"
            )
        # Query the header while the file is still open.
        sample_rate = f.getframerate()
        raw = f.readframes(f.getnframes())

    # WAVE PCM data is little-endian; "<i2" decodes it correctly regardless of
    # the byte order of the host.
    samples_int16 = np.frombuffer(raw, dtype="<i2")
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate
| 123 | + | ||
| 124 | + | ||
def main():
    """Decode one wave file with a streaming zipformer2 CTC model and a graph.

    The whole file plus 0.66 seconds of trailing zeros is fed to a single
    stream, the stream is decoded until the recognizer reports that it is no
    longer ready, and the lower-cased result is printed together with timing
    statistics.
    """
    args = get_args()
    print(vars(args))

    # NOTE(review): args.debug is parsed by get_args() but is not read below.
    for required_file in (args.tokens, args.graph, args.model):
        assert_file_exists(required_file)

    recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
        tokens=args.tokens,
        model=args.model,
        ctc_graph=args.graph,
        num_threads=args.num_threads,
        provider=args.provider,
        sample_rate=16000,
        feature_dim=80,
    )

    assert_file_exists(args.sound_file)
    samples, sample_rate = read_wave(args.sound_file)
    duration = len(samples) / sample_rate

    print("Started")

    start_time = time.time()

    stream = recognizer.create_stream()
    stream.accept_waveform(sample_rate, samples)
    # Append 0.66 seconds of zeros after the real audio.
    num_padding_samples = int(0.66 * sample_rate)
    stream.accept_waveform(
        sample_rate, np.zeros(num_padding_samples, dtype=np.float32)
    )
    stream.input_finished()

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    result = recognizer.get_result(stream).lower()
    elapsed_seconds = time.time() - start_time

    rtf = elapsed_seconds / duration
    print(f"num_threads: {args.num_threads}")
    print(f"Wave duration: {duration:.3f} s")
    print(f"Elapsed time: {elapsed_seconds:.3f} s")
    print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
    print(result)
| 170 | + | ||
| 171 | +if __name__ == "__main__": | ||
| 172 | + main() |
| @@ -51,6 +51,8 @@ set(sources | @@ -51,6 +51,8 @@ set(sources | ||
| 51 | offline-zipformer-ctc-model-config.cc | 51 | offline-zipformer-ctc-model-config.cc |
| 52 | offline-zipformer-ctc-model.cc | 52 | offline-zipformer-ctc-model.cc |
| 53 | online-conformer-transducer-model.cc | 53 | online-conformer-transducer-model.cc |
| 54 | + online-ctc-fst-decoder-config.cc | ||
| 55 | + online-ctc-fst-decoder.cc | ||
| 54 | online-ctc-greedy-search-decoder.cc | 56 | online-ctc-greedy-search-decoder.cc |
| 55 | online-ctc-model.cc | 57 | online-ctc-model.cc |
| 56 | online-lm-config.cc | 58 | online-lm-config.cc |
| @@ -7,6 +7,9 @@ | @@ -7,6 +7,9 @@ | ||
| 7 | #include <sstream> | 7 | #include <sstream> |
| 8 | #include <string> | 8 | #include <string> |
| 9 | 9 | ||
| 10 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | + | ||
| 10 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 11 | 14 | ||
| 12 | std::string OfflineCtcFstDecoderConfig::ToString() const { | 15 | std::string OfflineCtcFstDecoderConfig::ToString() const { |
| @@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { | @@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { | ||
| 29 | "Decoder max active states. Larger->slower; more accurate"); | 32 | "Decoder max active states. Larger->slower; more accurate"); |
| 30 | } | 33 | } |
| 31 | 34 | ||
| 35 | +bool OfflineCtcFstDecoderConfig::Validate() const { | ||
| 36 | + if (!graph.empty() && !FileExists(graph)) { | ||
| 37 | + SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); | ||
| 38 | + return false; | ||
| 39 | + } | ||
| 40 | + return true; | ||
| 41 | +} | ||
| 42 | + | ||
| 32 | } // namespace sherpa_onnx | 43 | } // namespace sherpa_onnx |
| @@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig { | @@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig { | ||
| 24 | std::string ToString() const; | 24 | std::string ToString() const; |
| 25 | 25 | ||
| 26 | void Register(ParseOptions *po); | 26 | void Register(ParseOptions *po); |
| 27 | + bool Validate() const; | ||
| 27 | }; | 28 | }; |
| 28 | 29 | ||
| 29 | } // namespace sherpa_onnx | 30 | } // namespace sherpa_onnx |
| @@ -20,7 +20,7 @@ namespace sherpa_onnx { | @@ -20,7 +20,7 @@ namespace sherpa_onnx { | ||
| 20 | // @param filename Path to a StdVectorFst or StdConstFst graph | 20 | // @param filename Path to a StdVectorFst or StdConstFst graph |
| 21 | // @return The caller should free the returned pointer using `delete` to | 21 | // @return The caller should free the returned pointer using `delete` to |
| 22 | // avoid memory leak. | 22 | // avoid memory leak. |
| 23 | -static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) { | 23 | +fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) { |
| 24 | // read decoding network FST | 24 | // read decoding network FST |
| 25 | std::ifstream is(filename, std::ios::binary); | 25 | std::ifstream is(filename, std::ios::binary); |
| 26 | if (!is.good()) { | 26 | if (!is.good()) { |
| @@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const { | @@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const { | ||
| 67 | return false; | 67 | return false; |
| 68 | } | 68 | } |
| 69 | 69 | ||
| 70 | + if (!ctc_fst_decoder_config.graph.empty() && | ||
| 71 | + !ctc_fst_decoder_config.Validate()) { | ||
| 72 | + SHERPA_ONNX_LOGE("Errors in fst_decoder"); | ||
| 73 | + return false; | ||
| 74 | + } | ||
| 75 | + | ||
| 70 | return model_config.Validate(); | 76 | return model_config.Validate(); |
| 71 | } | 77 | } |
| 72 | 78 |
| @@ -5,12 +5,16 @@ | @@ -5,12 +5,16 @@ | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ |
| 6 | #define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ | 6 | #define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ |
| 7 | 7 | ||
| 8 | +#include <memory> | ||
| 8 | #include <vector> | 9 | #include <vector> |
| 9 | 10 | ||
| 11 | +#include "kaldi-decoder/csrc/faster-decoder.h" | ||
| 10 | #include "onnxruntime_cxx_api.h" // NOLINT | 12 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 11 | 13 | ||
| 12 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 13 | 15 | ||
| 16 | +class OnlineStream; | ||
| 17 | + | ||
| 14 | struct OnlineCtcDecoderResult { | 18 | struct OnlineCtcDecoderResult { |
| 15 | /// Number of frames after subsampling we have decoded so far | 19 | /// Number of frames after subsampling we have decoded so far |
| 16 | int32_t frame_offset = 0; | 20 | int32_t frame_offset = 0; |
| @@ -37,7 +41,13 @@ class OnlineCtcDecoder { | @@ -37,7 +41,13 @@ class OnlineCtcDecoder { | ||
| 37 | * @param results Input & Output parameters.. | 41 | * @param results Input & Output parameters.. |
| 38 | */ | 42 | */ |
| 39 | virtual void Decode(Ort::Value log_probs, | 43 | virtual void Decode(Ort::Value log_probs, |
| 40 | - std::vector<OnlineCtcDecoderResult> *results) = 0; | 44 | + std::vector<OnlineCtcDecoderResult> *results, |
| 45 | + OnlineStream **ss = nullptr, int32_t n = 0) = 0; | ||
| 46 | + | ||
| 47 | + virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder() | ||
| 48 | + const { | ||
| 49 | + return nullptr; | ||
| 50 | + } | ||
| 41 | }; | 51 | }; |
| 42 | 52 | ||
| 43 | } // namespace sherpa_onnx | 53 | } // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" | ||
| 6 | + | ||
| 7 | +#include <sstream> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 11 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +std::string OnlineCtcFstDecoderConfig::ToString() const { | ||
| 16 | + std::ostringstream os; | ||
| 17 | + | ||
| 18 | + os << "OnlineCtcFstDecoderConfig("; | ||
| 19 | + os << "graph=\"" << graph << "\", "; | ||
| 20 | + os << "max_active=" << max_active << ")"; | ||
| 21 | + | ||
| 22 | + return os.str(); | ||
| 23 | +} | ||
| 24 | + | ||
// Registers the command-line options --ctc-graph and --ctc-max-active with
// po, binding them to the graph and max_active members of this object.
void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
  po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");

  po->Register("ctc-max-active", &max_active,
               "Decoder max active states. Larger->slower; more accurate");
}
| 31 | + | ||
| 32 | +bool OnlineCtcFstDecoderConfig::Validate() const { | ||
| 33 | + if (!graph.empty() && !FileExists(graph)) { | ||
| 34 | + SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str()); | ||
| 35 | + return false; | ||
| 36 | + } | ||
| 37 | + return true; | ||
| 38 | +} | ||
| 39 | + | ||
| 40 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
// Configuration for decoding a streaming CTC model with an FST graph.
struct OnlineCtcFstDecoderConfig {
  // Path to H.fst, HL.fst or HLG.fst
  std::string graph;
  // Decoder max active states. Larger->slower; more accurate
  int32_t max_active = 3000;

  OnlineCtcFstDecoderConfig() = default;

  OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
      : graph(graph), max_active(max_active) {}

  // Returns a human-readable description of this config.
  std::string ToString() const;

  // Registers the options of this config with the given parser.
  void Register(ParseOptions *po);
  // Returns true if this config can be used; false otherwise.
  bool Validate() const;
};
| 29 | + | ||
| 30 | +} // namespace sherpa_onnx | ||
| 31 | + | ||
| 32 | +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ |
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-ctc-fst-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <memory> | ||
| 9 | +#include <string> | ||
| 10 | +#include <utility> | ||
| 11 | +#include <vector> | ||
| 12 | + | ||
| 13 | +#include "fst/fstlib.h" | ||
| 14 | +#include "kaldi-decoder/csrc/decodable-ctc.h" | ||
| 15 | +#include "kaldifst/csrc/fstext-utils.h" | ||
| 16 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 17 | +#include "sherpa-onnx/csrc/online-stream.h" | ||
| 18 | + | ||
| 19 | +namespace sherpa_onnx { | ||
| 20 | + | ||
| 21 | +// defined in ./offline-ctc-fst-decoder.cc | ||
| 22 | +fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename); | ||
| 23 | + | ||
// Reads the graph at config.graph with ReadGraph() and stores it in fst_,
// remembers blank_id, and copies config.max_active into the FasterDecoder
// options used by CreateFasterDecoder().
OnlineCtcFstDecoder::OnlineCtcFstDecoder(
    const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
    : config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
  options_.max_active = config_.max_active;
}
| 29 | + | ||
// Returns a new FasterDecoder constructed from the graph and the options held
// by this object. The decoder is given *fst_, so presumably this object must
// outlive the returned decoder -- TODO confirm against kaldi-decoder.
std::unique_ptr<kaldi_decoder::FasterDecoder>
OnlineCtcFstDecoder::CreateFasterDecoder() const {
  return std::make_unique<kaldi_decoder::FasterDecoder>(*fst_, options_);
}
| 34 | + | ||
| 35 | +static void DecodeOne(const float *log_probs, int32_t num_rows, | ||
| 36 | + int32_t num_cols, OnlineCtcDecoderResult *result, | ||
| 37 | + OnlineStream *s, int32_t blank_id) { | ||
| 38 | + int32_t &processed_frames = s->GetFasterDecoderProcessedFrames(); | ||
| 39 | + kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols, | ||
| 40 | + processed_frames); | ||
| 41 | + | ||
| 42 | + kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder(); | ||
| 43 | + if (processed_frames == 0) { | ||
| 44 | + decoder->InitDecoding(); | ||
| 45 | + } | ||
| 46 | + | ||
| 47 | + decoder->AdvanceDecoding(&decodable); | ||
| 48 | + | ||
| 49 | + if (decoder->ReachedFinal()) { | ||
| 50 | + fst::VectorFst<fst::LatticeArc> fst_out; | ||
| 51 | + bool ok = decoder->GetBestPath(&fst_out); | ||
| 52 | + if (ok) { | ||
| 53 | + std::vector<int32_t> isymbols_out; | ||
| 54 | + std::vector<int32_t> osymbols_out_unused; | ||
| 55 | + ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out, | ||
| 56 | + &osymbols_out_unused, nullptr); | ||
| 57 | + std::vector<int64_t> tokens; | ||
| 58 | + tokens.reserve(isymbols_out.size()); | ||
| 59 | + | ||
| 60 | + std::vector<int32_t> timestamps; | ||
| 61 | + timestamps.reserve(isymbols_out.size()); | ||
| 62 | + | ||
| 63 | + std::ostringstream os; | ||
| 64 | + int32_t prev_id = -1; | ||
| 65 | + int32_t num_trailing_blanks = 0; | ||
| 66 | + int32_t f = 0; // frame number | ||
| 67 | + | ||
| 68 | + for (auto i : isymbols_out) { | ||
| 69 | + i -= 1; | ||
| 70 | + | ||
| 71 | + if (i == blank_id) { | ||
| 72 | + num_trailing_blanks += 1; | ||
| 73 | + } else { | ||
| 74 | + num_trailing_blanks = 0; | ||
| 75 | + } | ||
| 76 | + | ||
| 77 | + if (i != blank_id && i != prev_id) { | ||
| 78 | + tokens.push_back(i); | ||
| 79 | + timestamps.push_back(f); | ||
| 80 | + } | ||
| 81 | + prev_id = i; | ||
| 82 | + f += 1; | ||
| 83 | + } | ||
| 84 | + | ||
| 85 | + result->tokens = std::move(tokens); | ||
| 86 | + result->timestamps = std::move(timestamps); | ||
| 87 | + // no need to set frame_offset | ||
| 88 | + } | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + processed_frames += num_rows; | ||
| 92 | +} | ||
| 93 | + | ||
| 94 | +void OnlineCtcFstDecoder::Decode(Ort::Value log_probs, | ||
| 95 | + std::vector<OnlineCtcDecoderResult> *results, | ||
| 96 | + OnlineStream **ss, int32_t n) { | ||
| 97 | + std::vector<int64_t> log_probs_shape = | ||
| 98 | + log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 99 | + | ||
| 100 | + if (log_probs_shape[0] != results->size()) { | ||
| 101 | + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", | ||
| 102 | + static_cast<int32_t>(log_probs_shape[0]), | ||
| 103 | + static_cast<int32_t>(results->size())); | ||
| 104 | + exit(-1); | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + if (log_probs_shape[0] != n) { | ||
| 108 | + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", | ||
| 109 | + static_cast<int32_t>(log_probs_shape[0]), n); | ||
| 110 | + exit(-1); | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]); | ||
| 114 | + int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]); | ||
| 115 | + int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]); | ||
| 116 | + | ||
| 117 | + const float *p = log_probs.GetTensorData<float>(); | ||
| 118 | + | ||
| 119 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 120 | + DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, | ||
| 121 | + &(*results)[i], ss[i], blank_id_); | ||
| 122 | + } | ||
| 123 | +} | ||
| 124 | + | ||
| 125 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-ctc-fst-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-ctc-fst-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "fst/fst.h" | ||
| 12 | +#include "sherpa-onnx/csrc/online-ctc-decoder.h" | ||
| 13 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
// An OnlineCtcDecoder that decodes with an FST graph (H.fst, HL.fst, or
// HLG.fst) using kaldi_decoder::FasterDecoder.
class OnlineCtcFstDecoder : public OnlineCtcDecoder {
 public:
  // @param config   Provides the path of the graph and max_active.
  // @param blank_id ID of the blank token.
  OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
                      int32_t blank_id);

  // @param log_probs Output of the model for the current chunk.
  // @param results   Input & output. One entry per stream.
  // @param ss        An array of n streams. Although it defaults to nullptr
  //                  to match the base class, this decoder reads ss[i] for
  //                  every batch item.
  // @param n         Number of streams in ss.
  void Decode(Ort::Value log_probs,
              std::vector<OnlineCtcDecoderResult> *results,
              OnlineStream **ss = nullptr, int32_t n = 0) override;

  // Returns a new FasterDecoder built from the graph held by this object.
  std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
      const override;

 private:
  OnlineCtcFstDecoderConfig config_;
  kaldi_decoder::FasterDecoderOptions options_;

  // The decoding graph read from config_.graph
  std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
  int32_t blank_id_ = 0;
};
| 36 | + | ||
| 37 | +} // namespace sherpa_onnx | ||
| 38 | + | ||
| 39 | +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ |
| @@ -13,7 +13,8 @@ | @@ -13,7 +13,8 @@ | ||
| 13 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| 14 | 14 | ||
| 15 | void OnlineCtcGreedySearchDecoder::Decode( | 15 | void OnlineCtcGreedySearchDecoder::Decode( |
| 16 | - Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) { | 16 | + Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results, |
| 17 | + OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { | ||
| 17 | std::vector<int64_t> log_probs_shape = | 18 | std::vector<int64_t> log_probs_shape = |
| 18 | log_probs.GetTensorTypeAndShapeInfo().GetShape(); | 19 | log_probs.GetTensorTypeAndShapeInfo().GetShape(); |
| 19 | 20 |
| @@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { | @@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { | ||
| 17 | : blank_id_(blank_id) {} | 17 | : blank_id_(blank_id) {} |
| 18 | 18 | ||
| 19 | void Decode(Ort::Value log_probs, | 19 | void Decode(Ort::Value log_probs, |
| 20 | - std::vector<OnlineCtcDecoderResult> *results) override; | 20 | + std::vector<OnlineCtcDecoderResult> *results, |
| 21 | + OnlineStream **ss = nullptr, int32_t n = 0) override; | ||
| 21 | 22 | ||
| 22 | private: | 23 | private: |
| 23 | int32_t blank_id_; | 24 | int32_t blank_id_; |
| @@ -16,6 +16,7 @@ | @@ -16,6 +16,7 @@ | ||
| 16 | #include "sherpa-onnx/csrc/file-utils.h" | 16 | #include "sherpa-onnx/csrc/file-utils.h" |
| 17 | #include "sherpa-onnx/csrc/macros.h" | 17 | #include "sherpa-onnx/csrc/macros.h" |
| 18 | #include "sherpa-onnx/csrc/online-ctc-decoder.h" | 18 | #include "sherpa-onnx/csrc/online-ctc-decoder.h" |
| 19 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" | ||
| 19 | #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" | 20 | #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" |
| 20 | #include "sherpa-onnx/csrc/online-ctc-model.h" | 21 | #include "sherpa-onnx/csrc/online-ctc-model.h" |
| 21 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" | 22 | #include "sherpa-onnx/csrc/online-recognizer-impl.h" |
| @@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 99 | std::unique_ptr<OnlineStream> CreateStream() const override { | 100 | std::unique_ptr<OnlineStream> CreateStream() const override { |
| 100 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); | 101 | auto stream = std::make_unique<OnlineStream>(config_.feat_config); |
| 101 | stream->SetStates(model_->GetInitStates()); | 102 | stream->SetStates(model_->GetInitStates()); |
| 103 | + stream->SetFasterDecoder(decoder_->CreateFasterDecoder()); | ||
| 102 | 104 | ||
| 103 | return stream; | 105 | return stream; |
| 104 | } | 106 | } |
| @@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 165 | std::vector<std::vector<Ort::Value>> next_states = | 167 | std::vector<std::vector<Ort::Value>> next_states = |
| 166 | model_->UnStackStates(std::move(out_states)); | 168 | model_->UnStackStates(std::move(out_states)); |
| 167 | 169 | ||
| 168 | - decoder_->Decode(std::move(out[0]), &results); | 170 | + decoder_->Decode(std::move(out[0]), &results, ss, n); |
| 169 | 171 | ||
| 170 | for (int32_t k = 0; k != n; ++k) { | 172 | for (int32_t k = 0; k != n; ++k) { |
| 171 | ss[k]->SetCtcResult(results[k]); | 173 | ss[k]->SetCtcResult(results[k]); |
| @@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 221 | 223 | ||
| 222 | private: | 224 | private: |
| 223 | void InitDecoder() { | 225 | void InitDecoder() { |
| 224 | - if (config_.decoding_method == "greedy_search") { | ||
| 225 | - if (!sym_.contains("<blk>") && !sym_.contains("<eps>") && | ||
| 226 | - !sym_.contains("<blank>")) { | ||
| 227 | - SHERPA_ONNX_LOGE( | ||
| 228 | - "We expect that tokens.txt contains " | ||
| 229 | - "the symbol <blk> or <eps> or <blank> and its ID."); | ||
| 230 | - exit(-1); | ||
| 231 | - } | 226 | + if (!sym_.contains("<blk>") && !sym_.contains("<eps>") && |
| 227 | + !sym_.contains("<blank>")) { | ||
| 228 | + SHERPA_ONNX_LOGE( | ||
| 229 | + "We expect that tokens.txt contains " | ||
| 230 | + "the symbol <blk> or <eps> or <blank> and its ID."); | ||
| 231 | + exit(-1); | ||
| 232 | + } | ||
| 232 | 233 | ||
| 233 | - int32_t blank_id = 0; | ||
| 234 | - if (sym_.contains("<blk>")) { | ||
| 235 | - blank_id = sym_["<blk>"]; | ||
| 236 | - } else if (sym_.contains("<eps>")) { | ||
| 237 | - // for tdnn models of the yesno recipe from icefall | ||
| 238 | - blank_id = sym_["<eps>"]; | ||
| 239 | - } else if (sym_.contains("<blank>")) { | ||
| 240 | - // for WeNet CTC models | ||
| 241 | - blank_id = sym_["<blank>"]; | ||
| 242 | - } | 234 | + int32_t blank_id = 0; |
| 235 | + if (sym_.contains("<blk>")) { | ||
| 236 | + blank_id = sym_["<blk>"]; | ||
| 237 | + } else if (sym_.contains("<eps>")) { | ||
| 238 | + // for tdnn models of the yesno recipe from icefall | ||
| 239 | + blank_id = sym_["<eps>"]; | ||
| 240 | + } else if (sym_.contains("<blank>")) { | ||
| 241 | + // for WeNet CTC models | ||
| 242 | + blank_id = sym_["<blank>"]; | ||
| 243 | + } | ||
| 243 | 244 | ||
| 245 | + if (!config_.ctc_fst_decoder_config.graph.empty()) { | ||
| 246 | + decoder_ = std::make_unique<OnlineCtcFstDecoder>( | ||
| 247 | + config_.ctc_fst_decoder_config, blank_id); | ||
| 248 | + } else if (config_.decoding_method == "greedy_search") { | ||
| 244 | decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id); | 249 | decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id); |
| 245 | } else { | 250 | } else { |
| 246 | - SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | ||
| 247 | - config_.decoding_method.c_str()); | 251 | + SHERPA_ONNX_LOGE( |
| 252 | + "Unsupported decoding method: %s for streaming CTC models", | ||
| 253 | + config_.decoding_method.c_str()); | ||
| 248 | exit(-1); | 254 | exit(-1); |
| 249 | } | 255 | } |
| 250 | } | 256 | } |
| @@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | @@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { | ||
| 281 | std::vector<OnlineCtcDecoderResult> results(1); | 287 | std::vector<OnlineCtcDecoderResult> results(1); |
| 282 | results[0] = std::move(s->GetCtcResult()); | 288 | results[0] = std::move(s->GetCtcResult()); |
| 283 | 289 | ||
| 284 | - decoder_->Decode(std::move(out[0]), &results); | 290 | + decoder_->Decode(std::move(out[0]), &results, &s, 1); |
| 285 | s->SetCtcResult(results[0]); | 291 | s->SetCtcResult(results[0]); |
| 286 | } | 292 | } |
| 287 | 293 |
| @@ -19,13 +19,13 @@ | @@ -19,13 +19,13 @@ | ||
| 19 | namespace sherpa_onnx { | 19 | namespace sherpa_onnx { |
| 20 | 20 | ||
| 21 | /// Helper for `OnlineRecognizerResult::AsJsonString()` | 21 | /// Helper for `OnlineRecognizerResult::AsJsonString()` |
| 22 | -template<typename T> | ||
| 23 | -std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { | 22 | +template <typename T> |
| 23 | +std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) { | ||
| 24 | std::ostringstream oss; | 24 | std::ostringstream oss; |
| 25 | oss << std::fixed << std::setprecision(precision); | 25 | oss << std::fixed << std::setprecision(precision); |
| 26 | oss << "[ "; | 26 | oss << "[ "; |
| 27 | std::string sep = ""; | 27 | std::string sep = ""; |
| 28 | - for (const auto& item : vec) { | 28 | + for (const auto &item : vec) { |
| 29 | oss << sep << item; | 29 | oss << sep << item; |
| 30 | sep = ", "; | 30 | sep = ", "; |
| 31 | } | 31 | } |
| @@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { | @@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { | ||
| 34 | } | 34 | } |
| 35 | 35 | ||
| 36 | /// Helper for `OnlineRecognizerResult::AsJsonString()` | 36 | /// Helper for `OnlineRecognizerResult::AsJsonString()` |
| 37 | -template<> // explicit specialization for T = std::string | ||
| 38 | -std::string VecToString<std::string>(const std::vector<std::string>& vec, | 37 | +template <> // explicit specialization for T = std::string |
| 38 | +std::string VecToString<std::string>(const std::vector<std::string> &vec, | ||
| 39 | int32_t) { // ignore 2nd arg | 39 | int32_t) { // ignore 2nd arg |
| 40 | std::ostringstream oss; | 40 | std::ostringstream oss; |
| 41 | oss << "[ "; | 41 | oss << "[ "; |
| 42 | std::string sep = ""; | 42 | std::string sep = ""; |
| 43 | - for (const auto& item : vec) { | 43 | + for (const auto &item : vec) { |
| 44 | oss << sep << "\"" << item << "\""; | 44 | oss << sep << "\"" << item << "\""; |
| 45 | sep = ", "; | 45 | sep = ", "; |
| 46 | } | 46 | } |
| @@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec, | @@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec, | ||
| 51 | std::string OnlineRecognizerResult::AsJsonString() const { | 51 | std::string OnlineRecognizerResult::AsJsonString() const { |
| 52 | std::ostringstream os; | 52 | std::ostringstream os; |
| 53 | os << "{ "; | 53 | os << "{ "; |
| 54 | - os << "\"text\": " << "\"" << text << "\"" << ", "; | 54 | + os << "\"text\": " |
| 55 | + << "\"" << text << "\"" | ||
| 56 | + << ", "; | ||
| 55 | os << "\"tokens\": " << VecToString(tokens) << ", "; | 57 | os << "\"tokens\": " << VecToString(tokens) << ", "; |
| 56 | os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; | 58 | os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; |
| 57 | os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; | 59 | os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; |
| 58 | os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; | 60 | os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; |
| 59 | os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; | 61 | os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; |
| 60 | os << "\"segment\": " << segment << ", "; | 62 | os << "\"segment\": " << segment << ", "; |
| 61 | - os << "\"start_time\": " << std::fixed << std::setprecision(2) | ||
| 62 | - << start_time << ", "; | 63 | + os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time |
| 64 | + << ", "; | ||
| 63 | os << "\"is_final\": " << (is_final ? "true" : "false"); | 65 | os << "\"is_final\": " << (is_final ? "true" : "false"); |
| 64 | os << "}"; | 66 | os << "}"; |
| 65 | return os.str(); | 67 | return os.str(); |
| @@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | @@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { | ||
| 70 | model_config.Register(po); | 72 | model_config.Register(po); |
| 71 | endpoint_config.Register(po); | 73 | endpoint_config.Register(po); |
| 72 | lm_config.Register(po); | 74 | lm_config.Register(po); |
| 75 | + ctc_fst_decoder_config.Register(po); | ||
| 73 | 76 | ||
| 74 | po->Register("enable-endpoint", &enable_endpoint, | 77 | po->Register("enable-endpoint", &enable_endpoint, |
| 75 | "True to enable endpoint detection. False to disable it."); | 78 | "True to enable endpoint detection. False to disable it."); |
| @@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const { | @@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const { | ||
| 116 | return false; | 119 | return false; |
| 117 | } | 120 | } |
| 118 | 121 | ||
| 122 | + if (!ctc_fst_decoder_config.graph.empty() && | ||
| 123 | + !ctc_fst_decoder_config.Validate()) { | ||
| 124 | + SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config"); | ||
| 125 | + return false; | ||
| 126 | + } | ||
| 127 | + | ||
| 119 | return model_config.Validate(); | 128 | return model_config.Validate(); |
| 120 | } | 129 | } |
| 121 | 130 | ||
| @@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const { | @@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const { | ||
| 127 | os << "model_config=" << model_config.ToString() << ", "; | 136 | os << "model_config=" << model_config.ToString() << ", "; |
| 128 | os << "lm_config=" << lm_config.ToString() << ", "; | 137 | os << "lm_config=" << lm_config.ToString() << ", "; |
| 129 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; | 138 | os << "endpoint_config=" << endpoint_config.ToString() << ", "; |
| 139 | + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", "; | ||
| 130 | os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; | 140 | os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; |
| 131 | os << "max_active_paths=" << max_active_paths << ", "; | 141 | os << "max_active_paths=" << max_active_paths << ", "; |
| 132 | os << "hotwords_score=" << hotwords_score << ", "; | 142 | os << "hotwords_score=" << hotwords_score << ", "; |
| @@ -16,6 +16,7 @@ | @@ -16,6 +16,7 @@ | ||
| 16 | 16 | ||
| 17 | #include "sherpa-onnx/csrc/endpoint.h" | 17 | #include "sherpa-onnx/csrc/endpoint.h" |
| 18 | #include "sherpa-onnx/csrc/features.h" | 18 | #include "sherpa-onnx/csrc/features.h" |
| 19 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" | ||
| 19 | #include "sherpa-onnx/csrc/online-lm-config.h" | 20 | #include "sherpa-onnx/csrc/online-lm-config.h" |
| 20 | #include "sherpa-onnx/csrc/online-model-config.h" | 21 | #include "sherpa-onnx/csrc/online-model-config.h" |
| 21 | #include "sherpa-onnx/csrc/online-stream.h" | 22 | #include "sherpa-onnx/csrc/online-stream.h" |
| @@ -80,6 +81,7 @@ struct OnlineRecognizerConfig { | @@ -80,6 +81,7 @@ struct OnlineRecognizerConfig { | ||
| 80 | OnlineModelConfig model_config; | 81 | OnlineModelConfig model_config; |
| 81 | OnlineLMConfig lm_config; | 82 | OnlineLMConfig lm_config; |
| 82 | EndpointConfig endpoint_config; | 83 | EndpointConfig endpoint_config; |
| 84 | + OnlineCtcFstDecoderConfig ctc_fst_decoder_config; | ||
| 83 | bool enable_endpoint = true; | 85 | bool enable_endpoint = true; |
| 84 | 86 | ||
| 85 | std::string decoding_method = "greedy_search"; | 87 | std::string decoding_method = "greedy_search"; |
| @@ -96,19 +98,19 @@ struct OnlineRecognizerConfig { | @@ -96,19 +98,19 @@ struct OnlineRecognizerConfig { | ||
| 96 | 98 | ||
| 97 | OnlineRecognizerConfig() = default; | 99 | OnlineRecognizerConfig() = default; |
| 98 | 100 | ||
| 99 | - OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config, | ||
| 100 | - const OnlineModelConfig &model_config, | ||
| 101 | - const OnlineLMConfig &lm_config, | ||
| 102 | - const EndpointConfig &endpoint_config, | ||
| 103 | - bool enable_endpoint, | ||
| 104 | - const std::string &decoding_method, | ||
| 105 | - int32_t max_active_paths, | ||
| 106 | - const std::string &hotwords_file, float hotwords_score, | ||
| 107 | - float blank_penalty) | 101 | + OnlineRecognizerConfig( |
| 102 | + const FeatureExtractorConfig &feat_config, | ||
| 103 | + const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, | ||
| 104 | + const EndpointConfig &endpoint_config, | ||
| 105 | + const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, | ||
| 106 | + bool enable_endpoint, const std::string &decoding_method, | ||
| 107 | + int32_t max_active_paths, const std::string &hotwords_file, | ||
| 108 | + float hotwords_score, float blank_penalty) | ||
| 108 | : feat_config(feat_config), | 109 | : feat_config(feat_config), |
| 109 | model_config(model_config), | 110 | model_config(model_config), |
| 110 | lm_config(lm_config), | 111 | lm_config(lm_config), |
| 111 | endpoint_config(endpoint_config), | 112 | endpoint_config(endpoint_config), |
| 113 | + ctc_fst_decoder_config(ctc_fst_decoder_config), | ||
| 112 | enable_endpoint(enable_endpoint), | 114 | enable_endpoint(enable_endpoint), |
| 113 | decoding_method(decoding_method), | 115 | decoding_method(decoding_method), |
| 114 | max_active_paths(max_active_paths), | 116 | max_active_paths(max_active_paths), |
| @@ -104,6 +104,18 @@ class OnlineStream::Impl { | @@ -104,6 +104,18 @@ class OnlineStream::Impl { | ||
| 104 | return paraformer_alpha_cache_; | 104 | return paraformer_alpha_cache_; |
| 105 | } | 105 | } |
| 106 | 106 | ||
| 107 | + void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) { | ||
| 108 | + faster_decoder_ = std::move(decoder); | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + kaldi_decoder::FasterDecoder *GetFasterDecoder() const { | ||
| 112 | + return faster_decoder_.get(); | ||
| 113 | + } | ||
| 114 | + | ||
| 115 | + int32_t &GetFasterDecoderProcessedFrames() { | ||
| 116 | + return faster_decoder_processed_frames_; | ||
| 117 | + } | ||
| 118 | + | ||
| 107 | private: | 119 | private: |
| 108 | FeatureExtractor feat_extractor_; | 120 | FeatureExtractor feat_extractor_; |
| 109 | /// For contextual-biasing | 121 | /// For contextual-biasing |
| @@ -121,6 +133,8 @@ class OnlineStream::Impl { | @@ -121,6 +133,8 @@ class OnlineStream::Impl { | ||
| 121 | std::vector<float> paraformer_encoder_out_cache_; | 133 | std::vector<float> paraformer_encoder_out_cache_; |
| 122 | std::vector<float> paraformer_alpha_cache_; | 134 | std::vector<float> paraformer_alpha_cache_; |
| 123 | OnlineParaformerDecoderResult paraformer_result_; | 135 | OnlineParaformerDecoderResult paraformer_result_; |
| 136 | + std::unique_ptr<kaldi_decoder::FasterDecoder> faster_decoder_; | ||
| 137 | + int32_t faster_decoder_processed_frames_ = 0; | ||
| 124 | }; | 138 | }; |
| 125 | 139 | ||
| 126 | OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, | 140 | OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, |
| @@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { | @@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { | ||
| 208 | return impl_->GetContextGraph(); | 222 | return impl_->GetContextGraph(); |
| 209 | } | 223 | } |
| 210 | 224 | ||
| 225 | +void OnlineStream::SetFasterDecoder( | ||
| 226 | + std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) { | ||
| 227 | + impl_->SetFasterDecoder(std::move(decoder)); | ||
| 228 | +} | ||
| 229 | + | ||
| 230 | +kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const { | ||
| 231 | + return impl_->GetFasterDecoder(); | ||
| 232 | +} | ||
| 233 | + | ||
| 234 | +int32_t &OnlineStream::GetFasterDecoderProcessedFrames() { | ||
| 235 | + return impl_->GetFasterDecoderProcessedFrames(); | ||
| 236 | +} | ||
| 237 | + | ||
| 211 | std::vector<float> &OnlineStream::GetParaformerFeatCache() { | 238 | std::vector<float> &OnlineStream::GetParaformerFeatCache() { |
| 212 | return impl_->GetParaformerFeatCache(); | 239 | return impl_->GetParaformerFeatCache(); |
| 213 | } | 240 | } |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include <memory> | 8 | #include <memory> |
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | +#include "kaldi-decoder/csrc/faster-decoder.h" | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 12 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | #include "sherpa-onnx/csrc/context-graph.h" | 13 | #include "sherpa-onnx/csrc/context-graph.h" |
| 13 | #include "sherpa-onnx/csrc/features.h" | 14 | #include "sherpa-onnx/csrc/features.h" |
| @@ -97,6 +98,11 @@ class OnlineStream { | @@ -97,6 +98,11 @@ class OnlineStream { | ||
| 97 | */ | 98 | */ |
| 98 | const ContextGraphPtr &GetContextGraph() const; | 99 | const ContextGraphPtr &GetContextGraph() const; |
| 99 | 100 | ||
| 101 | + // for online ctc decoder | ||
| 102 | + void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder); | ||
| 103 | + kaldi_decoder::FasterDecoder *GetFasterDecoder() const; | ||
| 104 | + int32_t &GetFasterDecoderProcessedFrames(); | ||
| 105 | + | ||
| 100 | // for streaming paraformer | 106 | // for streaming paraformer |
| 101 | std::vector<float> &GetParaformerFeatCache(); | 107 | std::vector<float> &GetParaformerFeatCache(); |
| 102 | std::vector<float> &GetParaformerEncoderOutCache(); | 108 | std::vector<float> &GetParaformerEncoderOutCache(); |
| @@ -18,6 +18,7 @@ set(srcs | @@ -18,6 +18,7 @@ set(srcs | ||
| 18 | offline-wenet-ctc-model-config.cc | 18 | offline-wenet-ctc-model-config.cc |
| 19 | offline-whisper-model-config.cc | 19 | offline-whisper-model-config.cc |
| 20 | offline-zipformer-ctc-model-config.cc | 20 | offline-zipformer-ctc-model-config.cc |
| 21 | + online-ctc-fst-decoder-config.cc | ||
| 21 | online-lm-config.cc | 22 | online-lm-config.cc |
| 22 | online-model-config.cc | 23 | online-model-config.cc |
| 23 | online-paraformer-model-config.cc | 24 | online-paraformer-model-config.cc |
| 1 | +// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindOnlineCtcFstDecoderConfig(py::module *m) { | ||
| 14 | + using PyClass = OnlineCtcFstDecoderConfig; | ||
| 15 | + py::class_<PyClass>(*m, "OnlineCtcFstDecoderConfig") | ||
| 16 | + .def(py::init<const std::string &, int32_t>(), py::arg("graph") = "", | ||
| 17 | + py::arg("max_active") = 3000) | ||
| 18 | + .def_readwrite("graph", &PyClass::graph) | ||
| 19 | + .def_readwrite("max_active", &PyClass::max_active) | ||
| 20 | + .def("__str__", &PyClass::ToString); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2024 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOnlineCtcFstDecoderConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ |
| @@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) { | @@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) { | ||
| 24 | "tokens", | 24 | "tokens", |
| 25 | [](PyClass &self) -> std::vector<std::string> { return self.tokens; }) | 25 | [](PyClass &self) -> std::vector<std::string> { return self.tokens; }) |
| 26 | .def_property_readonly( | 26 | .def_property_readonly( |
| 27 | - "start_time", | ||
| 28 | - [](PyClass &self) -> float { return self.start_time; }) | 27 | + "start_time", [](PyClass &self) -> float { return self.start_time; }) |
| 29 | .def_property_readonly( | 28 | .def_property_readonly( |
| 30 | "timestamps", | 29 | "timestamps", |
| 31 | [](PyClass &self) -> std::vector<float> { return self.timestamps; }) | 30 | [](PyClass &self) -> std::vector<float> { return self.timestamps; }) |
| @@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) { | @@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) { | ||
| 35 | .def_property_readonly( | 34 | .def_property_readonly( |
| 36 | "lm_probs", | 35 | "lm_probs", |
| 37 | [](PyClass &self) -> std::vector<float> { return self.lm_probs; }) | 36 | [](PyClass &self) -> std::vector<float> { return self.lm_probs; }) |
| 37 | + .def_property_readonly("context_scores", | ||
| 38 | + [](PyClass &self) -> std::vector<float> { | ||
| 39 | + return self.context_scores; | ||
| 40 | + }) | ||
| 38 | .def_property_readonly( | 41 | .def_property_readonly( |
| 39 | - "context_scores", | ||
| 40 | - [](PyClass &self) -> std::vector<float> { | ||
| 41 | - return self.context_scores; | ||
| 42 | - }) | 42 | + "segment", [](PyClass &self) -> int32_t { return self.segment; }) |
| 43 | .def_property_readonly( | 43 | .def_property_readonly( |
| 44 | - "segment", | ||
| 45 | - [](PyClass &self) -> int32_t { return self.segment; }) | ||
| 46 | - .def_property_readonly( | ||
| 47 | - "is_final", | ||
| 48 | - [](PyClass &self) -> bool { return self.is_final; }) | 44 | + "is_final", [](PyClass &self) -> bool { return self.is_final; }) |
| 49 | .def("as_json_string", &PyClass::AsJsonString, | 45 | .def("as_json_string", &PyClass::AsJsonString, |
| 50 | - py::call_guard<py::gil_scoped_release>()); | 46 | + py::call_guard<py::gil_scoped_release>()); |
| 51 | } | 47 | } |
| 52 | 48 | ||
| 53 | static void PybindOnlineRecognizerConfig(py::module *m) { | 49 | static void PybindOnlineRecognizerConfig(py::module *m) { |
| 54 | using PyClass = OnlineRecognizerConfig; | 50 | using PyClass = OnlineRecognizerConfig; |
| 55 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") | 51 | py::class_<PyClass>(*m, "OnlineRecognizerConfig") |
| 56 | - .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, | ||
| 57 | - const OnlineLMConfig &, const EndpointConfig &, bool, | ||
| 58 | - const std::string &, int32_t, const std::string &, float, | ||
| 59 | - float>(), | ||
| 60 | - py::arg("feat_config"), py::arg("model_config"), | ||
| 61 | - py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"), | ||
| 62 | - py::arg("enable_endpoint"), py::arg("decoding_method"), | ||
| 63 | - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | ||
| 64 | - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) | 52 | + .def( |
| 53 | + py::init<const FeatureExtractorConfig &, const OnlineModelConfig &, | ||
| 54 | + const OnlineLMConfig &, const EndpointConfig &, | ||
| 55 | + const OnlineCtcFstDecoderConfig &, bool, const std::string &, | ||
| 56 | + int32_t, const std::string &, float, float>(), | ||
| 57 | + py::arg("feat_config"), py::arg("model_config"), | ||
| 58 | + py::arg("lm_config") = OnlineLMConfig(), | ||
| 59 | + py::arg("endpoint_config") = EndpointConfig(), | ||
| 60 | + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), | ||
| 61 | + py::arg("enable_endpoint"), py::arg("decoding_method"), | ||
| 62 | + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | ||
| 63 | + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) | ||
| 65 | .def_readwrite("feat_config", &PyClass::feat_config) | 64 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 66 | .def_readwrite("model_config", &PyClass::model_config) | 65 | .def_readwrite("model_config", &PyClass::model_config) |
| 67 | .def_readwrite("lm_config", &PyClass::lm_config) | 66 | .def_readwrite("lm_config", &PyClass::lm_config) |
| 68 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) | 67 | .def_readwrite("endpoint_config", &PyClass::endpoint_config) |
| 68 | + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config) | ||
| 69 | .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) | 69 | .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) |
| 70 | .def_readwrite("decoding_method", &PyClass::decoding_method) | 70 | .def_readwrite("decoding_method", &PyClass::decoding_method) |
| 71 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 71 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| @@ -15,6 +15,7 @@ | @@ -15,6 +15,7 @@ | ||
| 15 | #include "sherpa-onnx/python/csrc/offline-model-config.h" | 15 | #include "sherpa-onnx/python/csrc/offline-model-config.h" |
| 16 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" | 16 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" |
| 17 | #include "sherpa-onnx/python/csrc/offline-stream.h" | 17 | #include "sherpa-onnx/python/csrc/offline-stream.h" |
| 18 | +#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" | ||
| 18 | #include "sherpa-onnx/python/csrc/online-lm-config.h" | 19 | #include "sherpa-onnx/python/csrc/online-lm-config.h" |
| 19 | #include "sherpa-onnx/python/csrc/online-model-config.h" | 20 | #include "sherpa-onnx/python/csrc/online-model-config.h" |
| 20 | #include "sherpa-onnx/python/csrc/online-recognizer.h" | 21 | #include "sherpa-onnx/python/csrc/online-recognizer.h" |
| @@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 36 | m.doc() = "pybind11 binding of sherpa-onnx"; | 37 | m.doc() = "pybind11 binding of sherpa-onnx"; |
| 37 | 38 | ||
| 38 | PybindFeatures(&m); | 39 | PybindFeatures(&m); |
| 40 | + PybindOnlineCtcFstDecoderConfig(&m); | ||
| 39 | PybindOnlineModelConfig(&m); | 41 | PybindOnlineModelConfig(&m); |
| 40 | PybindOnlineLMConfig(&m); | 42 | PybindOnlineLMConfig(&m); |
| 41 | PybindOnlineStream(&m); | 43 | PybindOnlineStream(&m); |
| @@ -16,6 +16,7 @@ from _sherpa_onnx import ( | @@ -16,6 +16,7 @@ from _sherpa_onnx import ( | ||
| 16 | OnlineTransducerModelConfig, | 16 | OnlineTransducerModelConfig, |
| 17 | OnlineWenetCtcModelConfig, | 17 | OnlineWenetCtcModelConfig, |
| 18 | OnlineZipformer2CtcModelConfig, | 18 | OnlineZipformer2CtcModelConfig, |
| 19 | + OnlineCtcFstDecoderConfig, | ||
| 19 | ) | 20 | ) |
| 20 | 21 | ||
| 21 | 22 | ||
| @@ -314,6 +315,8 @@ class OnlineRecognizer(object): | @@ -314,6 +315,8 @@ class OnlineRecognizer(object): | ||
| 314 | rule2_min_trailing_silence: float = 1.2, | 315 | rule2_min_trailing_silence: float = 1.2, |
| 315 | rule3_min_utterance_length: float = 20.0, | 316 | rule3_min_utterance_length: float = 20.0, |
| 316 | decoding_method: str = "greedy_search", | 317 | decoding_method: str = "greedy_search", |
| 318 | + ctc_graph: str = "", | ||
| 319 | + ctc_max_active: int = 3000, | ||
| 317 | provider: str = "cpu", | 320 | provider: str = "cpu", |
| 318 | ): | 321 | ): |
| 319 | """ | 322 | """ |
| @@ -355,6 +358,12 @@ class OnlineRecognizer(object): | @@ -355,6 +358,12 @@ class OnlineRecognizer(object): | ||
| 355 | is detected. | 358 | is detected. |
| 356 | decoding_method: | 359 | decoding_method: |
| 357 | The only valid value is greedy_search. | 360 | The only valid value is greedy_search. |
| 361 | + ctc_graph: | ||
| 362 | + If not empty, decoding_method is ignored. It contains the path to | ||
| 363 | + H.fst, HL.fst, or HLG.fst | ||
| 364 | + ctc_max_active: | ||
| 365 | + Used only when ctc_graph is not empty. It specifies the maximum | ||
| 366 | + active paths at a time. | ||
| 358 | provider: | 367 | provider: |
| 359 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. | 368 | onnxruntime execution providers. Valid values are: cpu, cuda, coreml. |
| 360 | """ | 369 | """ |
| @@ -384,10 +393,16 @@ class OnlineRecognizer(object): | @@ -384,10 +393,16 @@ class OnlineRecognizer(object): | ||
| 384 | rule3_min_utterance_length=rule3_min_utterance_length, | 393 | rule3_min_utterance_length=rule3_min_utterance_length, |
| 385 | ) | 394 | ) |
| 386 | 395 | ||
| 396 | + ctc_fst_decoder_config = OnlineCtcFstDecoderConfig( | ||
| 397 | + graph=ctc_graph, | ||
| 398 | + max_active=ctc_max_active, | ||
| 399 | + ) | ||
| 400 | + | ||
| 387 | recognizer_config = OnlineRecognizerConfig( | 401 | recognizer_config = OnlineRecognizerConfig( |
| 388 | feat_config=feat_config, | 402 | feat_config=feat_config, |
| 389 | model_config=model_config, | 403 | model_config=model_config, |
| 390 | endpoint_config=endpoint_config, | 404 | endpoint_config=endpoint_config, |
| 405 | + ctc_fst_decoder_config=ctc_fst_decoder_config, | ||
| 391 | enable_endpoint=enable_endpoint_detection, | 406 | enable_endpoint=enable_endpoint_detection, |
| 392 | decoding_method=decoding_method, | 407 | decoding_method=decoding_method, |
| 393 | ) | 408 | ) |
-
请 注册 或 登录 后发表评论