Fangjun Kuang
Committed by GitHub

Add HLG decoding for streaming CTC models (#731)

1 #!/usr/bin/env bash 1 #!/usr/bin/env bash
2 2
3 -set -e 3 +set -ex
4 4
5 log() { 5 log() {
6 # This function is from espnet 6 # This function is from espnet
@@ -14,6 +14,26 @@ echo "PATH: $PATH" @@ -14,6 +14,26 @@ echo "PATH: $PATH"
14 which $EXE 14 which $EXE
15 15
16 log "------------------------------------------------------------" 16 log "------------------------------------------------------------"
  17 +log "Run streaming Zipformer2 CTC HLG decoding "
  18 +log "------------------------------------------------------------"
  19 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  20 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  21 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  22 +repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
  23 +ls -lh $repo
  24 +echo "pwd: $PWD"
  25 +
  26 +$EXE \
  27 + --zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
  28 + --ctc-graph=$repo/HLG.fst \
  29 + --tokens=$repo/tokens.txt \
  30 + $repo/test_wavs/0.wav \
  31 + $repo/test_wavs/1.wav \
  32 + $repo/test_wavs/8k.wav
  33 +
  34 +rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
  35 +
  36 +log "------------------------------------------------------------"
17 log "Run streaming Zipformer2 CTC " 37 log "Run streaming Zipformer2 CTC "
18 log "------------------------------------------------------------" 38 log "------------------------------------------------------------"
19 39
1 #!/usr/bin/env bash 1 #!/usr/bin/env bash
2 2
3 -set -e 3 +set -ex
4 4
5 log() { 5 log() {
6 # This function is from espnet 6 # This function is from espnet
@@ -8,6 +8,23 @@ log() { @@ -8,6 +8,23 @@ log() {
8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" 8 echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
9 } 9 }
10 10
  11 +log "test streaming zipformer2 ctc HLG decoding"
  12 +
  13 +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  14 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  15 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  16 +repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
  17 +
  18 +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
  19 + --debug 1 \
  20 + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
  21 + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
  22 + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
  23 + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
  24 +
  25 +rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
  26 +
  27 +
11 mkdir -p /tmp/icefall-models 28 mkdir -p /tmp/icefall-models
12 dir=/tmp/icefall-models 29 dir=/tmp/icefall-models
13 30
@@ -124,6 +124,14 @@ jobs: @@ -124,6 +124,14 @@ jobs:
124 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} 124 name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
125 path: build/bin/* 125 path: build/bin/*
126 126
  127 + - name: Test online CTC
  128 + shell: bash
  129 + run: |
  130 + export PATH=$PWD/build/bin:$PATH
  131 + export EXE=sherpa-onnx
  132 +
  133 + .github/scripts/test-online-ctc.sh
  134 +
127 - name: Test C API 135 - name: Test C API
128 shell: bash 136 shell: bash
129 run: | 137 run: |
@@ -149,13 +157,6 @@ jobs: @@ -149,13 +157,6 @@ jobs:
149 157
150 .github/scripts/test-kws.sh 158 .github/scripts/test-kws.sh
151 159
152 - - name: Test online CTC  
153 - shell: bash  
154 - run: |  
155 - export PATH=$PWD/build/bin:$PATH  
156 - export EXE=sherpa-onnx  
157 -  
158 - .github/scripts/test-online-ctc.sh  
159 160
160 - name: Test offline Whisper 161 - name: Test offline Whisper
161 if: matrix.build_type != 'Debug' 162 if: matrix.build_type != 'Debug'
1 function(download_kaldi_decoder) 1 function(download_kaldi_decoder)
2 include(FetchContent) 2 include(FetchContent)
3 3
4 - set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")  
5 - set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")  
6 - set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601") 4 + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
  5 + set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
  6 + set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff")
7 7
8 set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) 8 set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
9 set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE) 9 set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_decoder) @@ -12,11 +12,11 @@ function(download_kaldi_decoder)
12 # If you don't have access to the Internet, 12 # If you don't have access to the Internet,
13 # please pre-download kaldi-decoder 13 # please pre-download kaldi-decoder
14 set(possible_file_locations 14 set(possible_file_locations
15 - $ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz  
16 - ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz  
17 - ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz  
18 - /tmp/kaldi-decoder-0.2.4.tar.gz  
19 - /star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz 15 + $ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz
  16 + ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz
  17 + ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz
  18 + /tmp/kaldi-decoder-0.2.5.tar.gz
  19 + /star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz
20 ) 20 )
21 21
22 foreach(f IN LISTS possible_file_locations) 22 foreach(f IN LISTS possible_file_locations)
  1 +#!/usr/bin/env python3
  2 +
  3 +# This file shows how to use a streaming zipformer CTC model and an HLG
  4 +# graph for decoding.
  5 +#
  6 +# We use the following model as an example
  7 +#
  8 +"""
  9 +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  10 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  11 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
  12 +
  13 +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
  14 + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
  15 + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
  16 + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
  17 + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
  18 +
  19 +"""
  20 +# (The above model is from https://github.com/k2-fsa/icefall/pull/1557)
  21 +
  22 +import argparse
  23 +import time
  24 +import wave
  25 +from pathlib import Path
  26 +from typing import List, Tuple
  27 +
  28 +import numpy as np
  29 +import sherpa_onnx
  30 +
  31 +
  32 +def get_args():
  33 + parser = argparse.ArgumentParser(
  34 + formatter_class=argparse.ArgumentDefaultsHelpFormatter
  35 + )
  36 +
  37 + parser.add_argument(
  38 + "--tokens",
  39 + type=str,
  40 + required=True,
  41 + help="Path to tokens.txt",
  42 + )
  43 +
  44 + parser.add_argument(
  45 + "--model",
  46 + type=str,
  47 + required=True,
  48 + help="Path to the ONNX model",
  49 + )
  50 +
  51 + parser.add_argument(
  52 + "--graph",
  53 + type=str,
  54 + required=True,
  55 + help="Path to H.fst, HL.fst, or HLG.fst",
  56 + )
  57 +
  58 + parser.add_argument(
  59 + "--num-threads",
  60 + type=int,
  61 + default=1,
  62 + help="Number of threads for neural network computation",
  63 + )
  64 +
  65 + parser.add_argument(
  66 + "--provider",
  67 + type=str,
  68 + default="cpu",
  69 + help="Valid values: cpu, cuda, coreml",
  70 + )
  71 +
  72 + parser.add_argument(
  73 + "--debug",
  74 + type=int,
  75 + default=0,
  76 + help="Valid values: 1, 0",
  77 + )
  78 +
  79 + parser.add_argument(
  80 + "sound_file",
  81 + type=str,
  82 + help="The input sound file to decode. It must be of WAVE"
  83 + "format with a single channel, and each sample has 16-bit, "
  84 + "i.e., int16_t. "
  85 + "The sample rate of the file can be arbitrary and does not need to "
  86 + "be 16 kHz",
  87 + )
  88 +
  89 + return parser.parse_args()
  90 +
  91 +
  92 +def assert_file_exists(filename: str):
  93 + assert Path(filename).is_file(), (
  94 + f"{filename} does not exist!\n"
  95 + "Please refer to "
  96 + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
  97 + )
  98 +
  99 +
  100 +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
  101 + """
  102 + Args:
  103 + wave_filename:
  104 + Path to a wave file. It should be single channel and each sample should
  105 + be 16-bit. Its sample rate does not need to be 16kHz.
  106 + Returns:
  107 + Return a tuple containing:
  108 + - A 1-D array of dtype np.float32 containing the samples, which are
  109 + normalized to the range [-1, 1].
  110 + - sample rate of the wave file
  111 + """
  112 +
  113 + with wave.open(wave_filename) as f:
  114 + assert f.getnchannels() == 1, f.getnchannels()
  115 + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  116 + num_samples = f.getnframes()
  117 + samples = f.readframes(num_samples)
  118 + samples_int16 = np.frombuffer(samples, dtype=np.int16)
  119 + samples_float32 = samples_int16.astype(np.float32)
  120 +
  121 + samples_float32 = samples_float32 / 32768
  122 + return samples_float32, f.getframerate()
  123 +
  124 +
  125 +def main():
  126 + args = get_args()
  127 + print(vars(args))
  128 +
  129 + assert_file_exists(args.tokens)
  130 + assert_file_exists(args.graph)
  131 + assert_file_exists(args.model)
  132 +
  133 + recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
  134 + tokens=args.tokens,
  135 + model=args.model,
  136 + num_threads=args.num_threads,
  137 + provider=args.provider,
  138 + sample_rate=16000,
  139 + feature_dim=80,
  140 + ctc_graph=args.graph,
  141 + )
  142 +
  143 + wave_filename = args.sound_file
  144 + assert_file_exists(wave_filename)
  145 + samples, sample_rate = read_wave(wave_filename)
  146 + duration = len(samples) / sample_rate
  147 +
  148 + print("Started")
  149 +
  150 + start_time = time.time()
  151 + s = recognizer.create_stream()
  152 + s.accept_waveform(sample_rate, samples)
  153 + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
  154 + s.accept_waveform(sample_rate, tail_paddings)
  155 + s.input_finished()
  156 + while recognizer.is_ready(s):
  157 + recognizer.decode_stream(s)
  158 +
  159 + result = recognizer.get_result(s).lower()
  160 + end_time = time.time()
  161 +
  162 + elapsed_seconds = end_time - start_time
  163 + rtf = elapsed_seconds / duration
  164 + print(f"num_threads: {args.num_threads}")
  165 + print(f"Wave duration: {duration:.3f} s")
  166 + print(f"Elapsed time: {elapsed_seconds:.3f} s")
  167 + print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
  168 + print(result)
  169 +
  170 +
  171 +if __name__ == "__main__":
  172 + main()
@@ -51,6 +51,8 @@ set(sources @@ -51,6 +51,8 @@ set(sources
51 offline-zipformer-ctc-model-config.cc 51 offline-zipformer-ctc-model-config.cc
52 offline-zipformer-ctc-model.cc 52 offline-zipformer-ctc-model.cc
53 online-conformer-transducer-model.cc 53 online-conformer-transducer-model.cc
  54 + online-ctc-fst-decoder-config.cc
  55 + online-ctc-fst-decoder.cc
54 online-ctc-greedy-search-decoder.cc 56 online-ctc-greedy-search-decoder.cc
55 online-ctc-model.cc 57 online-ctc-model.cc
56 online-lm-config.cc 58 online-lm-config.cc
@@ -7,6 +7,9 @@ @@ -7,6 +7,9 @@
7 #include <sstream> 7 #include <sstream>
8 #include <string> 8 #include <string>
9 9
  10 +#include "sherpa-onnx/csrc/file-utils.h"
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
10 namespace sherpa_onnx { 13 namespace sherpa_onnx {
11 14
12 std::string OfflineCtcFstDecoderConfig::ToString() const { 15 std::string OfflineCtcFstDecoderConfig::ToString() const {
@@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { @@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
29 "Decoder max active states. Larger->slower; more accurate"); 32 "Decoder max active states. Larger->slower; more accurate");
30 } 33 }
31 34
  35 +bool OfflineCtcFstDecoderConfig::Validate() const {
  36 + if (!graph.empty() && !FileExists(graph)) {
  37 + SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
  38 + return false;
  39 + }
  40 + return true;
  41 +}
  42 +
32 } // namespace sherpa_onnx 43 } // namespace sherpa_onnx
@@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig { @@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
24 std::string ToString() const; 24 std::string ToString() const;
25 25
26 void Register(ParseOptions *po); 26 void Register(ParseOptions *po);
  27 + bool Validate() const;
27 }; 28 };
28 29
29 } // namespace sherpa_onnx 30 } // namespace sherpa_onnx
@@ -20,7 +20,7 @@ namespace sherpa_onnx { @@ -20,7 +20,7 @@ namespace sherpa_onnx {
20 // @param filename Path to a StdVectorFst or StdConstFst graph 20 // @param filename Path to a StdVectorFst or StdConstFst graph
21 // @return The caller should free the returned pointer using `delete` to 21 // @return The caller should free the returned pointer using `delete` to
22 // avoid memory leak. 22 // avoid memory leak.
23 -static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) { 23 +fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
24 // read decoding network FST 24 // read decoding network FST
25 std::ifstream is(filename, std::ios::binary); 25 std::ifstream is(filename, std::ios::binary);
26 if (!is.good()) { 26 if (!is.good()) {
@@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const { @@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
67 return false; 67 return false;
68 } 68 }
69 69
  70 + if (!ctc_fst_decoder_config.graph.empty() &&
  71 + !ctc_fst_decoder_config.Validate()) {
  72 + SHERPA_ONNX_LOGE("Errors in fst_decoder");
  73 + return false;
  74 + }
  75 +
70 return model_config.Validate(); 76 return model_config.Validate();
71 } 77 }
72 78
@@ -5,12 +5,16 @@ @@ -5,12 +5,16 @@
5 #ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ 5 #ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
6 #define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ 6 #define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
7 7
  8 +#include <memory>
8 #include <vector> 9 #include <vector>
9 10
  11 +#include "kaldi-decoder/csrc/faster-decoder.h"
10 #include "onnxruntime_cxx_api.h" // NOLINT 12 #include "onnxruntime_cxx_api.h" // NOLINT
11 13
12 namespace sherpa_onnx { 14 namespace sherpa_onnx {
13 15
  16 +class OnlineStream;
  17 +
14 struct OnlineCtcDecoderResult { 18 struct OnlineCtcDecoderResult {
15 /// Number of frames after subsampling we have decoded so far 19 /// Number of frames after subsampling we have decoded so far
16 int32_t frame_offset = 0; 20 int32_t frame_offset = 0;
@@ -37,7 +41,13 @@ class OnlineCtcDecoder { @@ -37,7 +41,13 @@ class OnlineCtcDecoder {
37 * @param results Input & Output parameters.. 41 * @param results Input & Output parameters..
38 */ 42 */
39 virtual void Decode(Ort::Value log_probs, 43 virtual void Decode(Ort::Value log_probs,
40 - std::vector<OnlineCtcDecoderResult> *results) = 0; 44 + std::vector<OnlineCtcDecoderResult> *results,
  45 + OnlineStream **ss = nullptr, int32_t n = 0) = 0;
  46 +
  47 + virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
  48 + const {
  49 + return nullptr;
  50 + }
41 }; 51 };
42 52
43 } // namespace sherpa_onnx 53 } // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
  6 +
  7 +#include <sstream>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/file-utils.h"
  11 +#include "sherpa-onnx/csrc/macros.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +std::string OnlineCtcFstDecoderConfig::ToString() const {
  16 + std::ostringstream os;
  17 +
  18 + os << "OnlineCtcFstDecoderConfig(";
  19 + os << "graph=\"" << graph << "\", ";
  20 + os << "max_active=" << max_active << ")";
  21 +
  22 + return os.str();
  23 +}
  24 +
  25 +void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
  26 + po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
  27 +
  28 + po->Register("ctc-max-active", &max_active,
  29 + "Decoder max active states. Larger->slower; more accurate");
  30 +}
  31 +
  32 +bool OnlineCtcFstDecoderConfig::Validate() const {
  33 + if (!graph.empty() && !FileExists(graph)) {
  34 + SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
  35 + return false;
  36 + }
  37 + return true;
  38 +}
  39 +
  40 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OnlineCtcFstDecoderConfig {
  15 + // Path to H.fst, HL.fst or HLG.fst
  16 + std::string graph;
  17 + int32_t max_active = 3000;
  18 +
  19 + OnlineCtcFstDecoderConfig() = default;
  20 +
  21 + OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
  22 + : graph(graph), max_active(max_active) {}
  23 +
  24 + std::string ToString() const;
  25 +
  26 + void Register(ParseOptions *po);
  27 + bool Validate() const;
  28 +};
  29 +
  30 +} // namespace sherpa_onnx
  31 +
  32 +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
  1 +// sherpa-onnx/csrc/online-ctc-fst-decoder.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
  6 +
  7 +#include <algorithm>
  8 +#include <memory>
  9 +#include <string>
  10 +#include <utility>
  11 +#include <vector>
  12 +
  13 +#include "fst/fstlib.h"
  14 +#include "kaldi-decoder/csrc/decodable-ctc.h"
  15 +#include "kaldifst/csrc/fstext-utils.h"
  16 +#include "sherpa-onnx/csrc/macros.h"
  17 +#include "sherpa-onnx/csrc/online-stream.h"
  18 +
  19 +namespace sherpa_onnx {
  20 +
  21 +// defined in ./offline-ctc-fst-decoder.cc
  22 +fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename);
  23 +
  24 +OnlineCtcFstDecoder::OnlineCtcFstDecoder(
  25 + const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
  26 + : config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
  27 + options_.max_active = config_.max_active;
  28 +}
  29 +
  30 +std::unique_ptr<kaldi_decoder::FasterDecoder>
  31 +OnlineCtcFstDecoder::CreateFasterDecoder() const {
  32 + return std::make_unique<kaldi_decoder::FasterDecoder>(*fst_, options_);
  33 +}
  34 +
  35 +static void DecodeOne(const float *log_probs, int32_t num_rows,
  36 + int32_t num_cols, OnlineCtcDecoderResult *result,
  37 + OnlineStream *s, int32_t blank_id) {
  38 + int32_t &processed_frames = s->GetFasterDecoderProcessedFrames();
  39 + kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols,
  40 + processed_frames);
  41 +
  42 + kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder();
  43 + if (processed_frames == 0) {
  44 + decoder->InitDecoding();
  45 + }
  46 +
  47 + decoder->AdvanceDecoding(&decodable);
  48 +
  49 + if (decoder->ReachedFinal()) {
  50 + fst::VectorFst<fst::LatticeArc> fst_out;
  51 + bool ok = decoder->GetBestPath(&fst_out);
  52 + if (ok) {
  53 + std::vector<int32_t> isymbols_out;
  54 + std::vector<int32_t> osymbols_out_unused;
  55 + ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
  56 + &osymbols_out_unused, nullptr);
  57 + std::vector<int64_t> tokens;
  58 + tokens.reserve(isymbols_out.size());
  59 +
  60 + std::vector<int32_t> timestamps;
  61 + timestamps.reserve(isymbols_out.size());
  62 +
  63 + std::ostringstream os;
  64 + int32_t prev_id = -1;
  65 + int32_t num_trailing_blanks = 0;
  66 + int32_t f = 0; // frame number
  67 +
  68 + for (auto i : isymbols_out) {
  69 + i -= 1;
  70 +
  71 + if (i == blank_id) {
  72 + num_trailing_blanks += 1;
  73 + } else {
  74 + num_trailing_blanks = 0;
  75 + }
  76 +
  77 + if (i != blank_id && i != prev_id) {
  78 + tokens.push_back(i);
  79 + timestamps.push_back(f);
  80 + }
  81 + prev_id = i;
  82 + f += 1;
  83 + }
  84 +
  85 + result->tokens = std::move(tokens);
  86 + result->timestamps = std::move(timestamps);
  87 + // no need to set frame_offset
  88 + }
  89 + }
  90 +
  91 + processed_frames += num_rows;
  92 +}
  93 +
  94 +void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
  95 + std::vector<OnlineCtcDecoderResult> *results,
  96 + OnlineStream **ss, int32_t n) {
  97 + std::vector<int64_t> log_probs_shape =
  98 + log_probs.GetTensorTypeAndShapeInfo().GetShape();
  99 +
  100 + if (log_probs_shape[0] != results->size()) {
  101 + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
  102 + static_cast<int32_t>(log_probs_shape[0]),
  103 + static_cast<int32_t>(results->size()));
  104 + exit(-1);
  105 + }
  106 +
  107 + if (log_probs_shape[0] != n) {
  108 + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
  109 + static_cast<int32_t>(log_probs_shape[0]), n);
  110 + exit(-1);
  111 + }
  112 +
  113 + int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
  114 + int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
  115 + int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
  116 +
  117 + const float *p = log_probs.GetTensorData<float>();
  118 +
  119 + for (int32_t i = 0; i != batch_size; ++i) {
  120 + DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
  121 + &(*results)[i], ss[i], blank_id_);
  122 + }
  123 +}
  124 +
  125 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/online-ctc-fst-decoder.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
  7 +
  8 +#include <memory>
  9 +#include <vector>
  10 +
  11 +#include "fst/fst.h"
  12 +#include "sherpa-onnx/csrc/online-ctc-decoder.h"
  13 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
  14 +
  15 +namespace sherpa_onnx {
  16 +
  17 +class OnlineCtcFstDecoder : public OnlineCtcDecoder {
  18 + public:
  19 + OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
  20 + int32_t blank_id);
  21 +
  22 + void Decode(Ort::Value log_probs,
  23 + std::vector<OnlineCtcDecoderResult> *results,
  24 + OnlineStream **ss = nullptr, int32_t n = 0) override;
  25 +
  26 + std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
  27 + const override;
  28 +
  29 + private:
  30 + OnlineCtcFstDecoderConfig config_;
  31 + kaldi_decoder::FasterDecoderOptions options_;
  32 +
  33 + std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
  34 + int32_t blank_id_ = 0;
  35 +};
  36 +
  37 +} // namespace sherpa_onnx
  38 +
  39 +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
@@ -13,7 +13,8 @@ @@ -13,7 +13,8 @@
13 namespace sherpa_onnx { 13 namespace sherpa_onnx {
14 14
15 void OnlineCtcGreedySearchDecoder::Decode( 15 void OnlineCtcGreedySearchDecoder::Decode(
16 - Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) { 16 + Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
  17 + OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
17 std::vector<int64_t> log_probs_shape = 18 std::vector<int64_t> log_probs_shape =
18 log_probs.GetTensorTypeAndShapeInfo().GetShape(); 19 log_probs.GetTensorTypeAndShapeInfo().GetShape();
19 20
@@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { @@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
17 : blank_id_(blank_id) {} 17 : blank_id_(blank_id) {}
18 18
19 void Decode(Ort::Value log_probs, 19 void Decode(Ort::Value log_probs,
20 - std::vector<OnlineCtcDecoderResult> *results) override; 20 + std::vector<OnlineCtcDecoderResult> *results,
  21 + OnlineStream **ss = nullptr, int32_t n = 0) override;
21 22
22 private: 23 private:
23 int32_t blank_id_; 24 int32_t blank_id_;
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 #include "sherpa-onnx/csrc/file-utils.h" 16 #include "sherpa-onnx/csrc/file-utils.h"
17 #include "sherpa-onnx/csrc/macros.h" 17 #include "sherpa-onnx/csrc/macros.h"
18 #include "sherpa-onnx/csrc/online-ctc-decoder.h" 18 #include "sherpa-onnx/csrc/online-ctc-decoder.h"
  19 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
19 #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" 20 #include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
20 #include "sherpa-onnx/csrc/online-ctc-model.h" 21 #include "sherpa-onnx/csrc/online-ctc-model.h"
21 #include "sherpa-onnx/csrc/online-recognizer-impl.h" 22 #include "sherpa-onnx/csrc/online-recognizer-impl.h"
@@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
99 std::unique_ptr<OnlineStream> CreateStream() const override { 100 std::unique_ptr<OnlineStream> CreateStream() const override {
100 auto stream = std::make_unique<OnlineStream>(config_.feat_config); 101 auto stream = std::make_unique<OnlineStream>(config_.feat_config);
101 stream->SetStates(model_->GetInitStates()); 102 stream->SetStates(model_->GetInitStates());
  103 + stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
102 104
103 return stream; 105 return stream;
104 } 106 }
@@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
165 std::vector<std::vector<Ort::Value>> next_states = 167 std::vector<std::vector<Ort::Value>> next_states =
166 model_->UnStackStates(std::move(out_states)); 168 model_->UnStackStates(std::move(out_states));
167 169
168 - decoder_->Decode(std::move(out[0]), &results); 170 + decoder_->Decode(std::move(out[0]), &results, ss, n);
169 171
170 for (int32_t k = 0; k != n; ++k) { 172 for (int32_t k = 0; k != n; ++k) {
171 ss[k]->SetCtcResult(results[k]); 173 ss[k]->SetCtcResult(results[k]);
@@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
221 223
222 private: 224 private:
223 void InitDecoder() { 225 void InitDecoder() {
224 - if (config_.decoding_method == "greedy_search") {  
225 - if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&  
226 - !sym_.contains("<blank>")) {  
227 - SHERPA_ONNX_LOGE(  
228 - "We expect that tokens.txt contains "  
229 - "the symbol <blk> or <eps> or <blank> and its ID.");  
230 - exit(-1);  
231 - } 226 + if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
  227 + !sym_.contains("<blank>")) {
  228 + SHERPA_ONNX_LOGE(
  229 + "We expect that tokens.txt contains "
  230 + "the symbol <blk> or <eps> or <blank> and its ID.");
  231 + exit(-1);
  232 + }
232 233
233 - int32_t blank_id = 0;  
234 - if (sym_.contains("<blk>")) {  
235 - blank_id = sym_["<blk>"];  
236 - } else if (sym_.contains("<eps>")) {  
237 - // for tdnn models of the yesno recipe from icefall  
238 - blank_id = sym_["<eps>"];  
239 - } else if (sym_.contains("<blank>")) {  
240 - // for WeNet CTC models  
241 - blank_id = sym_["<blank>"];  
242 - } 234 + int32_t blank_id = 0;
  235 + if (sym_.contains("<blk>")) {
  236 + blank_id = sym_["<blk>"];
  237 + } else if (sym_.contains("<eps>")) {
  238 + // for tdnn models of the yesno recipe from icefall
  239 + blank_id = sym_["<eps>"];
  240 + } else if (sym_.contains("<blank>")) {
  241 + // for WeNet CTC models
  242 + blank_id = sym_["<blank>"];
  243 + }
243 244
  245 + if (!config_.ctc_fst_decoder_config.graph.empty()) {
  246 + decoder_ = std::make_unique<OnlineCtcFstDecoder>(
  247 + config_.ctc_fst_decoder_config, blank_id);
  248 + } else if (config_.decoding_method == "greedy_search") {
244 decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id); 249 decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
245 } else { 250 } else {
246 - SHERPA_ONNX_LOGE("Unsupported decoding method: %s",  
247 - config_.decoding_method.c_str()); 251 + SHERPA_ONNX_LOGE(
  252 + "Unsupported decoding method: %s for streaming CTC models",
  253 + config_.decoding_method.c_str());
248 exit(-1); 254 exit(-1);
249 } 255 }
250 } 256 }
@@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { @@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
281 std::vector<OnlineCtcDecoderResult> results(1); 287 std::vector<OnlineCtcDecoderResult> results(1);
282 results[0] = std::move(s->GetCtcResult()); 288 results[0] = std::move(s->GetCtcResult());
283 289
284 - decoder_->Decode(std::move(out[0]), &results); 290 + decoder_->Decode(std::move(out[0]), &results, &s, 1);
285 s->SetCtcResult(results[0]); 291 s->SetCtcResult(results[0]);
286 } 292 }
287 293
@@ -19,13 +19,13 @@ @@ -19,13 +19,13 @@
19 namespace sherpa_onnx { 19 namespace sherpa_onnx {
20 20
21 /// Helper for `OnlineRecognizerResult::AsJsonString()` 21 /// Helper for `OnlineRecognizerResult::AsJsonString()`
22 -template<typename T>  
23 -std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { 22 +template <typename T>
  23 +std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
24 std::ostringstream oss; 24 std::ostringstream oss;
25 oss << std::fixed << std::setprecision(precision); 25 oss << std::fixed << std::setprecision(precision);
26 oss << "[ "; 26 oss << "[ ";
27 std::string sep = ""; 27 std::string sep = "";
28 - for (const auto& item : vec) { 28 + for (const auto &item : vec) {
29 oss << sep << item; 29 oss << sep << item;
30 sep = ", "; 30 sep = ", ";
31 } 31 }
@@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) { @@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
34 } 34 }
35 35
36 /// Helper for `OnlineRecognizerResult::AsJsonString()` 36 /// Helper for `OnlineRecognizerResult::AsJsonString()`
37 -template<> // explicit specialization for T = std::string  
38 -std::string VecToString<std::string>(const std::vector<std::string>& vec, 37 +template <> // explicit specialization for T = std::string
  38 +std::string VecToString<std::string>(const std::vector<std::string> &vec,
39 int32_t) { // ignore 2nd arg 39 int32_t) { // ignore 2nd arg
40 std::ostringstream oss; 40 std::ostringstream oss;
41 oss << "[ "; 41 oss << "[ ";
42 std::string sep = ""; 42 std::string sep = "";
43 - for (const auto& item : vec) { 43 + for (const auto &item : vec) {
44 oss << sep << "\"" << item << "\""; 44 oss << sep << "\"" << item << "\"";
45 sep = ", "; 45 sep = ", ";
46 } 46 }
@@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec, @@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec,
51 std::string OnlineRecognizerResult::AsJsonString() const { 51 std::string OnlineRecognizerResult::AsJsonString() const {
52 std::ostringstream os; 52 std::ostringstream os;
53 os << "{ "; 53 os << "{ ";
54 - os << "\"text\": " << "\"" << text << "\"" << ", "; 54 + os << "\"text\": "
  55 + << "\"" << text << "\""
  56 + << ", ";
55 os << "\"tokens\": " << VecToString(tokens) << ", "; 57 os << "\"tokens\": " << VecToString(tokens) << ", ";
56 os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; 58 os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
57 os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; 59 os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
58 os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; 60 os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
59 os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; 61 os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
60 os << "\"segment\": " << segment << ", "; 62 os << "\"segment\": " << segment << ", ";
61 - os << "\"start_time\": " << std::fixed << std::setprecision(2)  
62 - << start_time << ", "; 63 + os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
  64 + << ", ";
63 os << "\"is_final\": " << (is_final ? "true" : "false"); 65 os << "\"is_final\": " << (is_final ? "true" : "false");
64 os << "}"; 66 os << "}";
65 return os.str(); 67 return os.str();
@@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
70 model_config.Register(po); 72 model_config.Register(po);
71 endpoint_config.Register(po); 73 endpoint_config.Register(po);
72 lm_config.Register(po); 74 lm_config.Register(po);
  75 + ctc_fst_decoder_config.Register(po);
73 76
74 po->Register("enable-endpoint", &enable_endpoint, 77 po->Register("enable-endpoint", &enable_endpoint,
75 "True to enable endpoint detection. False to disable it."); 78 "True to enable endpoint detection. False to disable it.");
@@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const { @@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const {
116 return false; 119 return false;
117 } 120 }
118 121
  122 + if (!ctc_fst_decoder_config.graph.empty() &&
  123 + !ctc_fst_decoder_config.Validate()) {
  124 + SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config");
  125 + return false;
  126 + }
  127 +
119 return model_config.Validate(); 128 return model_config.Validate();
120 } 129 }
121 130
@@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const {
127 os << "model_config=" << model_config.ToString() << ", "; 136 os << "model_config=" << model_config.ToString() << ", ";
128 os << "lm_config=" << lm_config.ToString() << ", "; 137 os << "lm_config=" << lm_config.ToString() << ", ";
129 os << "endpoint_config=" << endpoint_config.ToString() << ", "; 138 os << "endpoint_config=" << endpoint_config.ToString() << ", ";
  139 + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
130 os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; 140 os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
131 os << "max_active_paths=" << max_active_paths << ", "; 141 os << "max_active_paths=" << max_active_paths << ", ";
132 os << "hotwords_score=" << hotwords_score << ", "; 142 os << "hotwords_score=" << hotwords_score << ", ";
@@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
16 16
17 #include "sherpa-onnx/csrc/endpoint.h" 17 #include "sherpa-onnx/csrc/endpoint.h"
18 #include "sherpa-onnx/csrc/features.h" 18 #include "sherpa-onnx/csrc/features.h"
  19 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
19 #include "sherpa-onnx/csrc/online-lm-config.h" 20 #include "sherpa-onnx/csrc/online-lm-config.h"
20 #include "sherpa-onnx/csrc/online-model-config.h" 21 #include "sherpa-onnx/csrc/online-model-config.h"
21 #include "sherpa-onnx/csrc/online-stream.h" 22 #include "sherpa-onnx/csrc/online-stream.h"
@@ -80,6 +81,7 @@ struct OnlineRecognizerConfig { @@ -80,6 +81,7 @@ struct OnlineRecognizerConfig {
80 OnlineModelConfig model_config; 81 OnlineModelConfig model_config;
81 OnlineLMConfig lm_config; 82 OnlineLMConfig lm_config;
82 EndpointConfig endpoint_config; 83 EndpointConfig endpoint_config;
  84 + OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
83 bool enable_endpoint = true; 85 bool enable_endpoint = true;
84 86
85 std::string decoding_method = "greedy_search"; 87 std::string decoding_method = "greedy_search";
@@ -96,19 +98,19 @@ struct OnlineRecognizerConfig { @@ -96,19 +98,19 @@ struct OnlineRecognizerConfig {
96 98
97 OnlineRecognizerConfig() = default; 99 OnlineRecognizerConfig() = default;
98 100
99 - OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,  
100 - const OnlineModelConfig &model_config,  
101 - const OnlineLMConfig &lm_config,  
102 - const EndpointConfig &endpoint_config,  
103 - bool enable_endpoint,  
104 - const std::string &decoding_method,  
105 - int32_t max_active_paths,  
106 - const std::string &hotwords_file, float hotwords_score,  
107 - float blank_penalty) 101 + OnlineRecognizerConfig(
  102 + const FeatureExtractorConfig &feat_config,
  103 + const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
  104 + const EndpointConfig &endpoint_config,
  105 + const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
  106 + bool enable_endpoint, const std::string &decoding_method,
  107 + int32_t max_active_paths, const std::string &hotwords_file,
  108 + float hotwords_score, float blank_penalty)
108 : feat_config(feat_config), 109 : feat_config(feat_config),
109 model_config(model_config), 110 model_config(model_config),
110 lm_config(lm_config), 111 lm_config(lm_config),
111 endpoint_config(endpoint_config), 112 endpoint_config(endpoint_config),
  113 + ctc_fst_decoder_config(ctc_fst_decoder_config),
112 enable_endpoint(enable_endpoint), 114 enable_endpoint(enable_endpoint),
113 decoding_method(decoding_method), 115 decoding_method(decoding_method),
114 max_active_paths(max_active_paths), 116 max_active_paths(max_active_paths),
@@ -104,6 +104,18 @@ class OnlineStream::Impl { @@ -104,6 +104,18 @@ class OnlineStream::Impl {
104 return paraformer_alpha_cache_; 104 return paraformer_alpha_cache_;
105 } 105 }
106 106
  107 + void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
  108 + faster_decoder_ = std::move(decoder);
  109 + }
  110 +
  111 + kaldi_decoder::FasterDecoder *GetFasterDecoder() const {
  112 + return faster_decoder_.get();
  113 + }
  114 +
  115 + int32_t &GetFasterDecoderProcessedFrames() {
  116 + return faster_decoder_processed_frames_;
  117 + }
  118 +
107 private: 119 private:
108 FeatureExtractor feat_extractor_; 120 FeatureExtractor feat_extractor_;
109 /// For contextual-biasing 121 /// For contextual-biasing
@@ -121,6 +133,8 @@ class OnlineStream::Impl { @@ -121,6 +133,8 @@ class OnlineStream::Impl {
121 std::vector<float> paraformer_encoder_out_cache_; 133 std::vector<float> paraformer_encoder_out_cache_;
122 std::vector<float> paraformer_alpha_cache_; 134 std::vector<float> paraformer_alpha_cache_;
123 OnlineParaformerDecoderResult paraformer_result_; 135 OnlineParaformerDecoderResult paraformer_result_;
  136 + std::unique_ptr<kaldi_decoder::FasterDecoder> faster_decoder_;
  137 + int32_t faster_decoder_processed_frames_ = 0;
124 }; 138 };
125 139
126 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, 140 OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const { @@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
208 return impl_->GetContextGraph(); 222 return impl_->GetContextGraph();
209 } 223 }
210 224
  225 +void OnlineStream::SetFasterDecoder(
  226 + std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
  227 + impl_->SetFasterDecoder(std::move(decoder));
  228 +}
  229 +
  230 +kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const {
  231 + return impl_->GetFasterDecoder();
  232 +}
  233 +
  234 +int32_t &OnlineStream::GetFasterDecoderProcessedFrames() {
  235 + return impl_->GetFasterDecoderProcessedFrames();
  236 +}
  237 +
211 std::vector<float> &OnlineStream::GetParaformerFeatCache() { 238 std::vector<float> &OnlineStream::GetParaformerFeatCache() {
212 return impl_->GetParaformerFeatCache(); 239 return impl_->GetParaformerFeatCache();
213 } 240 }
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include <memory> 8 #include <memory>
9 #include <vector> 9 #include <vector>
10 10
  11 +#include "kaldi-decoder/csrc/faster-decoder.h"
11 #include "onnxruntime_cxx_api.h" // NOLINT 12 #include "onnxruntime_cxx_api.h" // NOLINT
12 #include "sherpa-onnx/csrc/context-graph.h" 13 #include "sherpa-onnx/csrc/context-graph.h"
13 #include "sherpa-onnx/csrc/features.h" 14 #include "sherpa-onnx/csrc/features.h"
@@ -97,6 +98,11 @@ class OnlineStream { @@ -97,6 +98,11 @@ class OnlineStream {
97 */ 98 */
98 const ContextGraphPtr &GetContextGraph() const; 99 const ContextGraphPtr &GetContextGraph() const;
99 100
  101 + // for online ctc decoder
  102 + void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder);
  103 + kaldi_decoder::FasterDecoder *GetFasterDecoder() const;
  104 + int32_t &GetFasterDecoderProcessedFrames();
  105 +
100 // for streaming paraformer 106 // for streaming paraformer
101 std::vector<float> &GetParaformerFeatCache(); 107 std::vector<float> &GetParaformerFeatCache();
102 std::vector<float> &GetParaformerEncoderOutCache(); 108 std::vector<float> &GetParaformerEncoderOutCache();
@@ -18,6 +18,7 @@ set(srcs @@ -18,6 +18,7 @@ set(srcs
18 offline-wenet-ctc-model-config.cc 18 offline-wenet-ctc-model-config.cc
19 offline-whisper-model-config.cc 19 offline-whisper-model-config.cc
20 offline-zipformer-ctc-model-config.cc 20 offline-zipformer-ctc-model-config.cc
  21 + online-ctc-fst-decoder-config.cc
21 online-lm-config.cc 22 online-lm-config.cc
22 online-model-config.cc 23 online-model-config.cc
23 online-paraformer-model-config.cc 24 online-paraformer-model-config.cc
  1 +// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOnlineCtcFstDecoderConfig(py::module *m) {
  14 + using PyClass = OnlineCtcFstDecoderConfig;
  15 + py::class_<PyClass>(*m, "OnlineCtcFstDecoderConfig")
  16 + .def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
  17 + py::arg("max_active") = 3000)
  18 + .def_readwrite("graph", &PyClass::graph)
  19 + .def_readwrite("max_active", &PyClass::max_active)
  20 + .def("__str__", &PyClass::ToString);
  21 +}
  22 +
  23 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
  2 +//
  3 +// Copyright (c) 2024 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOnlineCtcFstDecoderConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
@@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) { @@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) {
24 "tokens", 24 "tokens",
25 [](PyClass &self) -> std::vector<std::string> { return self.tokens; }) 25 [](PyClass &self) -> std::vector<std::string> { return self.tokens; })
26 .def_property_readonly( 26 .def_property_readonly(
27 - "start_time",  
28 - [](PyClass &self) -> float { return self.start_time; }) 27 + "start_time", [](PyClass &self) -> float { return self.start_time; })
29 .def_property_readonly( 28 .def_property_readonly(
30 "timestamps", 29 "timestamps",
31 [](PyClass &self) -> std::vector<float> { return self.timestamps; }) 30 [](PyClass &self) -> std::vector<float> { return self.timestamps; })
@@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) { @@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) {
35 .def_property_readonly( 34 .def_property_readonly(
36 "lm_probs", 35 "lm_probs",
37 [](PyClass &self) -> std::vector<float> { return self.lm_probs; }) 36 [](PyClass &self) -> std::vector<float> { return self.lm_probs; })
  37 + .def_property_readonly("context_scores",
  38 + [](PyClass &self) -> std::vector<float> {
  39 + return self.context_scores;
  40 + })
38 .def_property_readonly( 41 .def_property_readonly(
39 - "context_scores",  
40 - [](PyClass &self) -> std::vector<float> {  
41 - return self.context_scores;  
42 - }) 42 + "segment", [](PyClass &self) -> int32_t { return self.segment; })
43 .def_property_readonly( 43 .def_property_readonly(
44 - "segment",  
45 - [](PyClass &self) -> int32_t { return self.segment; })  
46 - .def_property_readonly(  
47 - "is_final",  
48 - [](PyClass &self) -> bool { return self.is_final; }) 44 + "is_final", [](PyClass &self) -> bool { return self.is_final; })
49 .def("as_json_string", &PyClass::AsJsonString, 45 .def("as_json_string", &PyClass::AsJsonString,
50 - py::call_guard<py::gil_scoped_release>()); 46 + py::call_guard<py::gil_scoped_release>());
51 } 47 }
52 48
53 static void PybindOnlineRecognizerConfig(py::module *m) { 49 static void PybindOnlineRecognizerConfig(py::module *m) {
54 using PyClass = OnlineRecognizerConfig; 50 using PyClass = OnlineRecognizerConfig;
55 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 51 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
56 - .def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,  
57 - const OnlineLMConfig &, const EndpointConfig &, bool,  
58 - const std::string &, int32_t, const std::string &, float,  
59 - float>(),  
60 - py::arg("feat_config"), py::arg("model_config"),  
61 - py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),  
62 - py::arg("enable_endpoint"), py::arg("decoding_method"),  
63 - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",  
64 - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) 52 + .def(
  53 + py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
  54 + const OnlineLMConfig &, const EndpointConfig &,
  55 + const OnlineCtcFstDecoderConfig &, bool, const std::string &,
  56 + int32_t, const std::string &, float, float>(),
  57 + py::arg("feat_config"), py::arg("model_config"),
  58 + py::arg("lm_config") = OnlineLMConfig(),
  59 + py::arg("endpoint_config") = EndpointConfig(),
  60 + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
  61 + py::arg("enable_endpoint"), py::arg("decoding_method"),
  62 + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
  63 + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
65 .def_readwrite("feat_config", &PyClass::feat_config) 64 .def_readwrite("feat_config", &PyClass::feat_config)
66 .def_readwrite("model_config", &PyClass::model_config) 65 .def_readwrite("model_config", &PyClass::model_config)
67 .def_readwrite("lm_config", &PyClass::lm_config) 66 .def_readwrite("lm_config", &PyClass::lm_config)
68 .def_readwrite("endpoint_config", &PyClass::endpoint_config) 67 .def_readwrite("endpoint_config", &PyClass::endpoint_config)
  68 + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
69 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) 69 .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
70 .def_readwrite("decoding_method", &PyClass::decoding_method) 70 .def_readwrite("decoding_method", &PyClass::decoding_method)
71 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 71 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
@@ -15,6 +15,7 @@ @@ -15,6 +15,7 @@
15 #include "sherpa-onnx/python/csrc/offline-model-config.h" 15 #include "sherpa-onnx/python/csrc/offline-model-config.h"
16 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 16 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
17 #include "sherpa-onnx/python/csrc/offline-stream.h" 17 #include "sherpa-onnx/python/csrc/offline-stream.h"
  18 +#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
18 #include "sherpa-onnx/python/csrc/online-lm-config.h" 19 #include "sherpa-onnx/python/csrc/online-lm-config.h"
19 #include "sherpa-onnx/python/csrc/online-model-config.h" 20 #include "sherpa-onnx/python/csrc/online-model-config.h"
20 #include "sherpa-onnx/python/csrc/online-recognizer.h" 21 #include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
36 m.doc() = "pybind11 binding of sherpa-onnx"; 37 m.doc() = "pybind11 binding of sherpa-onnx";
37 38
38 PybindFeatures(&m); 39 PybindFeatures(&m);
  40 + PybindOnlineCtcFstDecoderConfig(&m);
39 PybindOnlineModelConfig(&m); 41 PybindOnlineModelConfig(&m);
40 PybindOnlineLMConfig(&m); 42 PybindOnlineLMConfig(&m);
41 PybindOnlineStream(&m); 43 PybindOnlineStream(&m);
@@ -16,6 +16,7 @@ from _sherpa_onnx import ( @@ -16,6 +16,7 @@ from _sherpa_onnx import (
16 OnlineTransducerModelConfig, 16 OnlineTransducerModelConfig,
17 OnlineWenetCtcModelConfig, 17 OnlineWenetCtcModelConfig,
18 OnlineZipformer2CtcModelConfig, 18 OnlineZipformer2CtcModelConfig,
  19 + OnlineCtcFstDecoderConfig,
19 ) 20 )
20 21
21 22
@@ -314,6 +315,8 @@ class OnlineRecognizer(object): @@ -314,6 +315,8 @@ class OnlineRecognizer(object):
314 rule2_min_trailing_silence: float = 1.2, 315 rule2_min_trailing_silence: float = 1.2,
315 rule3_min_utterance_length: float = 20.0, 316 rule3_min_utterance_length: float = 20.0,
316 decoding_method: str = "greedy_search", 317 decoding_method: str = "greedy_search",
  318 + ctc_graph: str = "",
  319 + ctc_max_active: int = 3000,
317 provider: str = "cpu", 320 provider: str = "cpu",
318 ): 321 ):
319 """ 322 """
@@ -355,6 +358,12 @@ class OnlineRecognizer(object): @@ -355,6 +358,12 @@ class OnlineRecognizer(object):
355 is detected. 358 is detected.
356 decoding_method: 359 decoding_method:
357 The only valid value is greedy_search. 360 The only valid value is greedy_search.
  361 + ctc_graph:
  362 + If not empty, decoding_method is ignored. It contains the path to
  363 + H.fst, HL.fst, or HLG.fst
  364 + ctc_max_active:
  365 + Used only when ctc_graph is not empty. It specifies the maximum
  366 + active paths at a time.
358 provider: 367 provider:
359 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 368 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
360 """ 369 """
@@ -384,10 +393,16 @@ class OnlineRecognizer(object): @@ -384,10 +393,16 @@ class OnlineRecognizer(object):
384 rule3_min_utterance_length=rule3_min_utterance_length, 393 rule3_min_utterance_length=rule3_min_utterance_length,
385 ) 394 )
386 395
  396 + ctc_fst_decoder_config = OnlineCtcFstDecoderConfig(
  397 + graph=ctc_graph,
  398 + max_active=ctc_max_active,
  399 + )
  400 +
387 recognizer_config = OnlineRecognizerConfig( 401 recognizer_config = OnlineRecognizerConfig(
388 feat_config=feat_config, 402 feat_config=feat_config,
389 model_config=model_config, 403 model_config=model_config,
390 endpoint_config=endpoint_config, 404 endpoint_config=endpoint_config,
  405 + ctc_fst_decoder_config=ctc_fst_decoder_config,
391 enable_endpoint=enable_endpoint_detection, 406 enable_endpoint=enable_endpoint_detection,
392 decoding_method=decoding_method, 407 decoding_method=decoding_method,
393 ) 408 )