Fangjun Kuang
Committed by GitHub

Add CTC HLG decoding using OpenFst (#349)

Showing 39 changed files with 963 additions and 55 deletions.
@@ -89,3 +89,48 @@ time $EXE \
89 $repo/test_wavs/8k.wav
90
91 rm -rf $repo
  92 +
  93 +log "------------------------------------------------------------"
  94 +log "Run Librispeech zipformer CTC H/HL/HLG decoding (English) "
  95 +log "------------------------------------------------------------"
  96 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
  97 +log "Start testing ${repo_url}"
  98 +repo=$(basename $repo_url)
  99 +log "Download pretrained model and test-data from $repo_url"
  100 +
  101 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  102 +pushd $repo
  103 +git lfs pull --include "*.onnx"
  104 +git lfs pull --include "*.fst"
  105 +ls -lh
  106 +popd
  107 +
  108 +graphs=(
  109 +$repo/H.fst
  110 +$repo/HL.fst
  111 +$repo/HLG.fst
  112 +)
  113 +
  114 +for graph in ${graphs[@]}; do
  115 + log "test float32 models with $graph"
  116 + time $EXE \
  117 + --model-type=zipformer2_ctc \
  118 + --ctc.graph=$graph \
  119 + --zipformer-ctc-model=$repo/model.onnx \
  120 + --tokens=$repo/tokens.txt \
  121 + $repo/test_wavs/0.wav \
  122 + $repo/test_wavs/1.wav \
  123 + $repo/test_wavs/2.wav
  124 +
  125 + log "test int8 models with $graph"
  126 + time $EXE \
  127 + --model-type=zipformer2_ctc \
  128 + --ctc.graph=$graph \
  129 + --zipformer-ctc-model=$repo/model.int8.onnx \
  130 + --tokens=$repo/tokens.txt \
  131 + $repo/test_wavs/0.wav \
  132 + $repo/test_wavs/1.wav \
  133 + $repo/test_wavs/2.wav
  134 +done
  135 +
  136 +rm -rf $repo
@@ -18,7 +18,7 @@ permissions:
18 jobs:
19 python_online_websocket_server:
20 runs-on: ${{ matrix.os }}
21 - name: ${{ matrix.os }} ${{ matrix.python-version }}
21 + name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
22 strategy:
23 fail-fast: false
24 matrix:
@@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
154 endif()
155
156 include(kaldi-native-fbank)
157 +include(kaldi-decoder)
158 include(onnxruntime)
159
160 if(SHERPA_ONNX_ENABLE_PORTAUDIO)
  1 +function(download_eigen)
  2 + include(FetchContent)
  3 +
  4 + set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz")
  5 + set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz")
  6 + set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72")
  7 +
  8 + # If you don't have access to the Internet,
  9 + # please pre-download eigen
  10 + set(possible_file_locations
  11 + $ENV{HOME}/Downloads/eigen-3.4.0.tar.gz
  12 + ${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz
  13 + ${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz
  14 + /tmp/eigen-3.4.0.tar.gz
  15 + /star-fj/fangjun/download/github/eigen-3.4.0.tar.gz
  16 + )
  17 +
  18 + foreach(f IN LISTS possible_file_locations)
  19 + if(EXISTS ${f})
  20 + set(eigen_URL "${f}")
  21 + file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL)
  22 + message(STATUS "Found local downloaded eigen: ${eigen_URL}")
  23 + set(eigen_URL2)
  24 + break()
  25 + endif()
  26 + endforeach()
  27 +
  28 + set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
  29 + set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE)
  30 +
  31 + FetchContent_Declare(eigen
  32 + URL ${eigen_URL}
  33 + URL_HASH ${eigen_HASH}
  34 + )
  35 +
  36 + FetchContent_GetProperties(eigen)
  37 + if(NOT eigen_POPULATED)
  38 + message(STATUS "Downloading eigen from ${eigen_URL}")
  39 + FetchContent_Populate(eigen)
  40 + endif()
  41 + message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}")
  42 + message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}")
  43 +
  44 + add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL)
  45 +endfunction()
  46 +
  47 +download_eigen()
  48 +
  1 +function(download_kaldi_decoder)
  2 + include(FetchContent)
  3 +
  4 + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.3.tar.gz")
  5 + set(kaldi_decoder_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.3.tar.gz")
  6 + set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")
  7 +
  8 + set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
  9 + set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
  10 + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
  11 +
  12 + # If you don't have access to the Internet,
  13 + # please pre-download kaldi-decoder
  14 + set(possible_file_locations
  15 + $ENV{HOME}/Downloads/kaldi-decoder-0.2.3.tar.gz
  16 + ${PROJECT_SOURCE_DIR}/kaldi-decoder-0.2.3.tar.gz
  17 + ${PROJECT_BINARY_DIR}/kaldi-decoder-0.2.3.tar.gz
  18 + /tmp/kaldi-decoder-0.2.3.tar.gz
  19 + /star-fj/fangjun/download/github/kaldi-decoder-0.2.3.tar.gz
  20 + )
  21 +
  22 + foreach(f IN LISTS possible_file_locations)
  23 + if(EXISTS ${f})
  24 + set(kaldi_decoder_URL "${f}")
  25 + file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL)
  26 + message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}")
  27 + set(kaldi_decoder_URL2 )
  28 + break()
  29 + endif()
  30 + endforeach()
  31 +
  32 + FetchContent_Declare(kaldi_decoder
  33 + URL
  34 + ${kaldi_decoder_URL}
  35 + ${kaldi_decoder_URL2}
  36 + URL_HASH ${kaldi_decoder_HASH}
  37 + )
  38 +
  39 + FetchContent_GetProperties(kaldi_decoder)
  40 + if(NOT kaldi_decoder_POPULATED)
  41 + message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}")
  42 + FetchContent_Populate(kaldi_decoder)
  43 + endif()
  44 + message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}")
  45 + message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}")
  46 +
  47 + include_directories(${kaldi_decoder_SOURCE_DIR})
  48 + add_subdirectory(${kaldi_decoder_SOURCE_DIR} ${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL)
  49 +
  50 + target_include_directories(kaldi-decoder-core
  51 + INTERFACE
  52 + ${kaldi-decoder_SOURCE_DIR}/
  53 + )
  54 + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
  55 + install(TARGETS
  56 + kaldi-decoder-core
  57 + kaldifst_core
  58 + fst
  59 + DESTINATION ..)
  60 + else()
  61 + install(TARGETS
  62 + kaldi-decoder-core
  63 + kaldifst_core
  64 + fst
  65 + DESTINATION lib)
  66 + endif()
  67 +
  68 + if(WIN32 AND BUILD_SHARED_LIBS)
  69 + install(TARGETS
  70 + kaldi-decoder-core
  71 + kaldifst_core
  72 + fst
  73 + DESTINATION bin)
  74 + endif()
  75 +endfunction()
  76 +
  77 +download_kaldi_decoder()
  78 +
  1 +function(download_kaldifst)
  2 + include(FetchContent)
  3 +
  4 + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz")
  5 + set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz")
  6 + set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2")
  7 +
  8 + # If you don't have access to the Internet,
  9 + # please pre-download kaldifst
  10 + set(possible_file_locations
  11 + $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz
  12 + ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz
  13 + ${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz
  14 + /tmp/kaldifst-1.7.6.tar.gz
  15 + /star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz
  16 + )
  17 +
  18 + foreach(f IN LISTS possible_file_locations)
  19 + if(EXISTS ${f})
  20 + set(kaldifst_URL "${f}")
  21 + file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL)
  22 + message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}")
  23 + set(kaldifst_URL2)
  24 + break()
  25 + endif()
  26 + endforeach()
  27 +
  28 + set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE)
  29 + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
  30 +
  31 + FetchContent_Declare(kaldifst
  32 + URL ${kaldifst_URL}
  33 + URL_HASH ${kaldifst_HASH}
  34 + )
  35 +
  36 + FetchContent_GetProperties(kaldifst)
  37 + if(NOT kaldifst_POPULATED)
  38 + message(STATUS "Downloading kaldifst from ${kaldifst_URL}")
  39 + FetchContent_Populate(kaldifst)
  40 + endif()
  41 + message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}")
  42 + message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}")
  43 +
  44 + list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake)
  45 +
  46 + add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL)
  47 +
  48 + target_include_directories(kaldifst_core
  49 + PUBLIC
  50 + ${kaldifst_SOURCE_DIR}/
  51 + )
  52 +
  53 + target_include_directories(fst
  54 + PUBLIC
  55 + ${openfst_SOURCE_DIR}/src/include
  56 + )
  57 +
  58 + set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-onnx-kaldifst-core")
  59 + set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst")
  60 +endfunction()
  61 +
  62 +download_kaldifst()
@@ -13,4 +13,4 @@ Cflags: -I"${includedir}"
13 # Note: -lcargs is required only for the following file
14 # https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c
15 # We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c
16 -Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
16 +Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
@@ -9,6 +9,9 @@
9 sherpa-onnx-portaudio_static.lib;
10 sherpa-onnx-c-api.lib;
11 sherpa-onnx-core.lib;
12 + kaldi-decoder-core.lib;
13 + sherpa-onnx-kaldifst-core.lib;
14 + sherpa-onnx-fst.lib;
15 kaldi-native-fbank-core.lib;
16 absl_base.lib;
17 absl_city.lib;
@@ -9,6 +9,9 @@
9 sherpa-onnx-portaudio_static.lib;
10 sherpa-onnx-c-api.lib;
11 sherpa-onnx-core.lib;
12 + kaldi-decoder-core.lib;
13 + sherpa-onnx-kaldifst-core.lib;
14 + sherpa-onnx-fst.lib;
15 kaldi-native-fbank-core.lib;
16 absl_base.lib;
17 absl_city.lib;
@@ -19,6 +19,8 @@ set(sources
19 features.cc
20 file-utils.cc
21 hypothesis.cc
22 + offline-ctc-fst-decoder-config.cc
23 + offline-ctc-fst-decoder.cc
24 offline-ctc-greedy-search-decoder.cc
25 offline-ctc-model.cc
26 offline-lm-config.cc
@@ -42,6 +44,8 @@ set(sources
44 offline-whisper-greedy-search-decoder.cc
45 offline-whisper-model-config.cc
46 offline-whisper-model.cc
47 + offline-zipformer-ctc-model-config.cc
48 + offline-zipformer-ctc-model.cc
49 online-conformer-transducer-model.cc
50 online-lm-config.cc
51 online-lm.cc
@@ -97,6 +101,8 @@ endif()
101
102 target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core)
103
104 +target_link_libraries(sherpa-onnx-core kaldi-decoder-core)
105 +
106 if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm)
107 target_link_libraries(sherpa-onnx-core onnxruntime)
108 else()
  1 +// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
  6 +
  7 +#include <sstream>
  8 +#include <string>
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +std::string OfflineCtcFstDecoderConfig::ToString() const {
  13 + std::ostringstream os;
  14 +
  15 + os << "OfflineCtcFstDecoderConfig(";
  16 + os << "graph=\"" << graph << "\", ";
  17 + os << "max_active=" << max_active << ")";
  18 +
  19 + return os.str();
  20 +}
  21 +
  22 +void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
  23 + std::string prefix = "ctc";
  24 + ParseOptions p(prefix, po);
  25 +
  26 + p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
  27 +
  28 + p.Register("max-active", &max_active,
  29 + "Decoder max active states. Larger->slower; more accurate");
  30 +}
  31 +
  32 +} // namespace sherpa_onnx
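Because Register() wraps the caller's ParseOptions in a nested one with the "ctc" prefix, the two options above surface on the command line as --ctc.graph and --ctc.max-active (the flags used by the test script at the top of this commit). A minimal stand-alone driver, assuming sherpa_onnx::ParseOptions keeps the Kaldi-style interface (usage-string constructor plus Read(argc, argv)); this is an illustration, not part of the commit:

// Hypothetical driver for OfflineCtcFstDecoderConfig (illustration only).
#include <cstdio>

#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::OfflineCtcFstDecoderConfig config;

  sherpa_onnx::ParseOptions po("Usage: demo --ctc.graph=/path/to/HLG.fst");
  config.Register(&po);  // registers --ctc.graph and --ctc.max-active
  po.Read(argc, argv);

  printf("%s\n", config.ToString().c_str());
  return 0;
}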
  1 +// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
  7 +
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/parse-options.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineCtcFstDecoderConfig {
  15 + // Path to H.fst, HL.fst or HLG.fst
  16 + std::string graph;
  17 + int32_t max_active = 3000;
  18 +
  19 + OfflineCtcFstDecoderConfig() = default;
  20 +
  21 + OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
  22 + : graph(graph), max_active(max_active) {}
  23 +
  24 + std::string ToString() const;
  25 +
  26 + void Register(ParseOptions *po);
  27 +};
  28 +
  29 +} // namespace sherpa_onnx
  30 +
  31 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
  6 +
  7 +#include <string>
  8 +#include <utility>
  9 +
  10 +#include "fst/fstlib.h"
  11 +#include "kaldi-decoder/csrc/decodable-ctc.h"
  12 +#include "kaldi-decoder/csrc/eigen.h"
  13 +#include "kaldi-decoder/csrc/faster-decoder.h"
  14 +#include "sherpa-onnx/csrc/macros.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +// This function is copied from kaldi.
  19 +//
  20 +// @param filename Path to a StdVectorFst or StdConstFst graph
  21 +// @return The caller should free the returned pointer using `delete` to
  22 +// avoid memory leak.
  23 +static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
  24 + // read decoding network FST
  25 + std::ifstream is(filename, std::ios::binary);
  26 + if (!is.good()) {
  27 + SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str());
  28 + }
  29 +
  30 + fst::FstHeader hdr;
  31 + if (!hdr.Read(is, "<unknown>")) {
  32 + SHERPA_ONNX_LOGE("Reading FST: error reading FST header.");
  33 + }
  34 +
  35 + if (hdr.ArcType() != fst::StdArc::Type()) {
  36 + SHERPA_ONNX_LOGE("FST with arc type %s not supported",
  37 + hdr.ArcType().c_str());
  38 + }
  39 + fst::FstReadOptions ropts("<unspecified>", &hdr);
  40 +
  41 + fst::Fst<fst::StdArc> *decode_fst = nullptr;
  42 +
  43 + if (hdr.FstType() == "vector") {
  44 + decode_fst = fst::VectorFst<fst::StdArc>::Read(is, ropts);
  45 + } else if (hdr.FstType() == "const") {
  46 + decode_fst = fst::ConstFst<fst::StdArc>::Read(is, ropts);
  47 + } else {
  48 + SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s",
  49 + hdr.FstType().c_str());
  50 + }
  51 +
  52 + if (decode_fst == nullptr) { // fst code will warn.
  53 + SHERPA_ONNX_LOGE("Error reading FST (after reading header).");
  54 + return nullptr;
  55 + } else {
  56 + return decode_fst;
  57 + }
  58 +}
  59 +
  60 +/**
  61 + * @param decoder
  62 + * @param p Pointer to a 2-d array of shape (num_frames, vocab_size)
  63 + * @param num_frames Number of rows in the 2-d array.
  64 + * @param vocab_size Number of columns in the 2-d array.
  65 + * @return Return the decoded result.
  66 + */
  67 +static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
  68 + const float *p, int32_t num_frames,
  69 + int32_t vocab_size) {
  70 + OfflineCtcDecoderResult r;
  71 + kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size);
  72 +
  73 + decoder->Decode(&decodable);
  74 +
  75 + if (!decoder->ReachedFinal()) {
  76 + SHERPA_ONNX_LOGE("Not reached final!");
  77 + return r;
  78 + }
  79 +
  80 + fst::VectorFst<fst::LatticeArc> decoded; // linear FST.
  81 + decoder->GetBestPath(&decoded);
  82 +
  83 + if (decoded.NumStates() == 0) {
  84 + SHERPA_ONNX_LOGE("Empty best path!");
  85 + return r;
  86 + }
  87 +
  88 + auto cur_state = decoded.Start();
  89 +
  90 + int32_t blank_id = 0;
  91 +
  92 + for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) {
  93 + fst::ArcIterator<fst::Fst<fst::LatticeArc>> iter(decoded, cur_state);
  94 + const auto &arc = iter.Value();
  95 +
  96 + cur_state = arc.nextstate;
  97 +
  98 + if (arc.ilabel == prev) {
  99 + continue;
  100 + }
  101 +
  102 + // 0 is epsilon here
  103 + if (arc.ilabel == 0 || arc.ilabel == blank_id + 1) {
  104 + prev = arc.ilabel;
  105 + continue;
  106 + }
  107 +
  108 + // -1 here since the input labels are incremented during graph
  109 + // construction
  110 + r.tokens.push_back(arc.ilabel - 1);
  111 +
  112 + r.timestamps.push_back(t);
  113 + prev = arc.ilabel;
  114 + }
  115 +
  116 + return r;
  117 +}
  118 +
  119 +OfflineCtcFstDecoder::OfflineCtcFstDecoder(
  120 + const OfflineCtcFstDecoderConfig &config)
  121 + : config_(config), fst_(ReadGraph(config_.graph)) {}
  122 +
  123 +std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode(
  124 + Ort::Value log_probs, Ort::Value log_probs_length) {
  125 + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
  126 +
  127 + assert(static_cast<int32_t>(shape.size()) == 3);
  128 + int32_t batch_size = shape[0];
  129 + int32_t T = shape[1];
  130 + int32_t vocab_size = shape[2];
  131 +
  132 + std::vector<int64_t> length_shape =
  133 + log_probs_length.GetTensorTypeAndShapeInfo().GetShape();
  134 + assert(static_cast<int32_t>(length_shape.size()) == 1);
  135 +
  136 + assert(shape[0] == length_shape[0]);
  137 +
  138 + kaldi_decoder::FasterDecoderOptions opts;
  139 + opts.max_active = config_.max_active;
  140 + kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts);
  141 +
  142 + const float *start = log_probs.GetTensorData<float>();
  143 +
  144 + std::vector<OfflineCtcDecoderResult> ans;
  145 + ans.reserve(batch_size);
  146 +
  147 + for (int32_t i = 0; i != batch_size; ++i) {
  148 + const float *p = start + i * T * vocab_size;
  149 + int32_t num_frames = log_probs_length.GetTensorData<int64_t>()[i];
  150 + auto r = DecodeOne(&faster_decoder, p, num_frames, vocab_size);
  151 + ans.push_back(std::move(r));
  152 + }
  153 +
  154 + return ans;
  155 +}
  156 +
  157 +} // namespace sherpa_onnx
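DecodeOne() above relies on the label convention of the decoding graph: input label 0 is epsilon, blank is stored as blank_id + 1, every other token id is shifted up by one, and consecutive identical labels collapse to a single token. A self-contained sketch of that collapse rule (illustration only, not part of the commit):

// Collapse a sequence of FST input labels into token ids, mirroring the
// loop in DecodeOne(): drop repeats, skip epsilon/blank, undo the +1 shift.
#include <cstdio>
#include <vector>

std::vector<int> CollapseCtcLabels(const std::vector<int> &ilabels,
                                   int blank_id = 0) {
  std::vector<int> tokens;
  int prev = -1;
  for (int ilabel : ilabels) {
    if (ilabel == prev) continue;                 // repeated label
    if (ilabel == 0 || ilabel == blank_id + 1) {  // epsilon or blank
      prev = ilabel;
      continue;
    }
    tokens.push_back(ilabel - 1);                 // undo the +1 shift
    prev = ilabel;
  }
  return tokens;
}

int main() {
  // ilabels {0, 3, 3, 1, 5, 0, 5} -> tokens {2, 4, 4}
  for (int t : CollapseCtcLabels({0, 3, 3, 1, 5, 0, 5})) printf("%d ", t);
  printf("\n");
  return 0;
}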
  1 +// sherpa-onnx/csrc/offline-ctc-fst-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
  7 +
  8 +#include <memory>
  9 +#include <vector>
  10 +
  11 +#include "fst/fst.h"
  12 +#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
  13 +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
  14 +#include "sherpa-onnx/csrc/parse-options.h"
  15 +
  16 +namespace sherpa_onnx {
  17 +
  18 +class OfflineCtcFstDecoder : public OfflineCtcDecoder {
  19 + public:
  20 + explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config);
  21 +
  22 + std::vector<OfflineCtcDecoderResult> Decode(
  23 + Ort::Value log_probs, Ort::Value log_probs_length) override;
  24 +
  25 + private:
  26 + OfflineCtcFstDecoderConfig config_;
  27 +
  28 + std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
  29 +};
  30 +
  31 +} // namespace sherpa_onnx
  32 +
  33 +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
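A rough usage sketch of the new decoder on its own, assuming an H.fst/HL.fst/HLG.fst built for the model's token set is available locally; Decode() still expects the Ort::Value log-probs and lengths produced by OfflineCtcModel::Forward() (illustration only, not part of the commit):

#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"

int main() {
  sherpa_onnx::OfflineCtcFstDecoderConfig config;
  config.graph = "./HLG.fst";  // hypothetical path; H.fst and HL.fst also work
  config.max_active = 3000;    // forwarded to kaldi_decoder::FasterDecoderOptions

  // The constructor loads the graph via ReadGraph(); Decode() then runs
  // FasterDecoder over each utterance's (num_frames, vocab_size) log-prob matrix.
  sherpa_onnx::OfflineCtcFstDecoder decoder(config);
  return 0;
}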
@@ -12,6 +12,7 @@
12 #include "sherpa-onnx/csrc/macros.h"
13 #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
14 #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
15 +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
16 #include "sherpa-onnx/csrc/onnx-utils.h"
17
18 namespace {
@@ -19,6 +20,7 @@ namespace {
20 enum class ModelType {
21 kEncDecCTCModelBPE,
22 kTdnn,
23 + kZipformerCtc,
24 kUnkown,
25 };
26
@@ -59,6 +61,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
61 return ModelType::kEncDecCTCModelBPE;
62 } else if (model_type.get() == std::string("tdnn")) {
63 return ModelType::kTdnn;
64 + } else if (model_type.get() == std::string("zipformer2_ctc")) {
65 + return ModelType::kZipformerCtc;
66 } else {
67 SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
68 return ModelType::kUnkown;
@@ -74,6 +78,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
78 filename = config.nemo_ctc.model;
79 } else if (!config.tdnn.model.empty()) {
80 filename = config.tdnn.model;
81 + } else if (!config.zipformer_ctc.model.empty()) {
82 + filename = config.zipformer_ctc.model;
83 } else {
84 SHERPA_ONNX_LOGE("Please specify a CTC model");
85 exit(-1);
@@ -92,6 +98,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
98 case ModelType::kTdnn:
99 return std::make_unique<OfflineTdnnCtcModel>(config);
100 break;
101 + case ModelType::kZipformerCtc:
102 + return std::make_unique<OfflineZipformerCtcModel>(config);
103 + break;
104 case ModelType::kUnkown:
105 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
106 return nullptr;
@@ -111,6 +120,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
120 filename = config.nemo_ctc.model;
121 } else if (!config.tdnn.model.empty()) {
122 filename = config.tdnn.model;
123 + } else if (!config.zipformer_ctc.model.empty()) {
124 + filename = config.zipformer_ctc.model;
125 } else {
126 SHERPA_ONNX_LOGE("Please specify a CTC model");
127 exit(-1);
@@ -129,6 +140,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
140 case ModelType::kTdnn:
141 return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
142 break;
143 + case ModelType::kZipformerCtc:
144 + return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
145 + break;
146 case ModelType::kUnkown:
147 SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
148 return nullptr;
@@ -6,7 +6,7 @@
6
7 #include <memory>
8 #include <string>
9 -#include <utility>
9 +#include <vector>
10
11 #if __ANDROID_API__ >= 9
12 #include "android/asset_manager.h"
@@ -32,17 +32,17 @@ class OfflineCtcModel {
32
33 /** Run the forward method of the model.
34 *
35 - * @param features A tensor of shape (N, T, C). It is changed in-place.
35 + * @param features A tensor of shape (N, T, C).
36 * @param features_length A 1-D tensor of shape (N,) containing number of
37 * valid frames in `features` before padding.
38 * Its dtype is int64_t.
39 *
40 - * @return Return a pair containing:
40 + * @return Return a vector containing:
41 * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
42 * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
43 */
44 - virtual std::pair<Ort::Value, Ort::Value> Forward(
45 - Ort::Value features, Ort::Value features_length) = 0;
44 + virtual std::vector<Ort::Value> Forward(Ort::Value features,
45 + Ort::Value features_length) = 0;
46
47 /** Return the vocabulary size of the model
48 */
@@ -16,6 +16,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
16 nemo_ctc.Register(po);
17 whisper.Register(po);
18 tdnn.Register(po);
19 + zipformer_ctc.Register(po);
20
21 po->Register("tokens", &tokens, "Path to tokens.txt");
22
@@ -31,7 +32,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
32 po->Register("model-type", &model_type,
33 "Specify it to reduce model initialization time. "
34 "Valid values are: transducer, paraformer, nemo_ctc, whisper, "
34 - "tdnn."
35 + "tdnn, zipformer2_ctc"
36 "All other values lead to loading the model twice.");
37 }
38
@@ -62,6 +63,10 @@ bool OfflineModelConfig::Validate() const {
63 return tdnn.Validate();
64 }
65
66 + if (!zipformer_ctc.model.empty()) {
67 + return zipformer_ctc.Validate();
68 + }
69 +
70 return transducer.Validate();
71 }
72
@@ -74,6 +79,7 @@ std::string OfflineModelConfig::ToString() const {
79 os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
80 os << "whisper=" << whisper.ToString() << ", ";
81 os << "tdnn=" << tdnn.ToString() << ", ";
82 + os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
83 os << "tokens=\"" << tokens << "\", ";
84 os << "num_threads=" << num_threads << ", ";
85 os << "debug=" << (debug ? "True" : "False") << ", ";
@@ -11,6 +11,7 @@
11 #include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
12 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
13 #include "sherpa-onnx/csrc/offline-whisper-model-config.h"
14 +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
15
16 namespace sherpa_onnx {
17
@@ -20,6 +21,7 @@ struct OfflineModelConfig {
21 OfflineNemoEncDecCtcModelConfig nemo_ctc;
22 OfflineWhisperModelConfig whisper;
23 OfflineTdnnModelConfig tdnn;
24 + OfflineZipformerCtcModelConfig zipformer_ctc;
25
26 std::string tokens;
27 int32_t num_threads = 2;
@@ -43,6 +45,7 @@ struct OfflineModelConfig {
45 const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
46 const OfflineWhisperModelConfig &whisper,
47 const OfflineTdnnModelConfig &tdnn,
48 + const OfflineZipformerCtcModelConfig &zipformer_ctc,
49 const std::string &tokens, int32_t num_threads, bool debug,
50 const std::string &provider, const std::string &model_type)
51 : transducer(transducer),
@@ -50,6 +53,7 @@ struct OfflineModelConfig {
53 nemo_ctc(nemo_ctc),
54 whisper(whisper),
55 tdnn(tdnn),
56 + zipformer_ctc(zipformer_ctc),
57 tokens(tokens),
58 num_threads(num_threads),
59 debug(debug),
@@ -34,7 +34,7 @@ class OfflineNemoEncDecCtcModel::Impl {
34 }
35 #endif
36
37 - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
37 + std::vector<Ort::Value> Forward(Ort::Value features,
38 Ort::Value features_length) {
39 std::vector<int64_t> shape =
40 features_length.GetTensorTypeAndShapeInfo().GetShape();
@@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl {
57 sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
58 output_names_ptr_.data(), output_names_ptr_.size());
59
60 - return {std::move(out[0]), std::move(out_features_length)};
60 + std::vector<Ort::Value> ans;
61 + ans.reserve(2);
62 + ans.push_back(std::move(out[0]));
63 + ans.push_back(std::move(out_features_length));
64 + return ans;
65 }
66
67 int32_t VocabSize() const { return vocab_size_; }
@@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
126
127 OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
128
125 -std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
129 +std::vector<Ort::Value> OfflineNemoEncDecCtcModel::Forward(
130 Ort::Value features, Ort::Value features_length) {
131 return impl_->Forward(std::move(features), std::move(features_length));
132 }
@@ -38,17 +38,17 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
38
39 /** Run the forward method of the model.
40 *
41 - * @param features A tensor of shape (N, T, C). It is changed in-place.
41 + * @param features A tensor of shape (N, T, C).
42 * @param features_length A 1-D tensor of shape (N,) containing number of
43 * valid frames in `features` before padding.
44 * Its dtype is int64_t.
45 *
46 - * @return Return a pair containing:
46 + * @return Return a vector containing:
47 * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
48 * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
49 */
50 - std::pair<Ort::Value, Ort::Value> Forward(
51 - Ort::Value features, Ort::Value features_length) override;
50 + std::vector<Ort::Value> Forward(Ort::Value features,
51 + Ort::Value features_length) override;
52
53 /** Return the vocabulary size of the model
54 */
@@ -16,6 +16,7 @@
16 #endif
17
18 #include "sherpa-onnx/csrc/offline-ctc-decoder.h"
19 +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
20 #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
21 #include "sherpa-onnx/csrc/offline-ctc-model.h"
22 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@@ -25,9 +26,12 @@
26 namespace sherpa_onnx {
27
28 static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
28 - const SymbolTable &sym_table) {
29 + const SymbolTable &sym_table,
30 + int32_t frame_shift_ms,
31 + int32_t subsampling_factor) {
32 OfflineRecognitionResult r;
33 r.tokens.reserve(src.tokens.size());
34 + r.timestamps.reserve(src.timestamps.size());
35
36 std::string text;
37
@@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
46 }
47 r.text = std::move(text);
48
49 + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
50 + for (auto t : src.timestamps) {
51 + float time = frame_shift_s * t;
52 + r.timestamps.push_back(time);
53 + }
54 +
55 return r;
56 }
57
@@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
78 config_.feat_config.nemo_normalize_type =
79 model_->FeatureNormalizationMethod();
80
71 - if (config_.decoding_method == "greedy_search") {
81 + if (!config_.ctc_fst_decoder_config.graph.empty()) {
82 + // TODO(fangjun): Support android to read the graph from
83 + // asset_manager
84 + decoder_ = std::make_unique<OfflineCtcFstDecoder>(
85 + config_.ctc_fst_decoder_config);
86 + } else if (config_.decoding_method == "greedy_search") {
87 if (!symbol_table_.contains("<blk>") &&
88 !symbol_table_.contains("<eps>")) {
89 SHERPA_ONNX_LOGE(
@@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
154 -23.025850929940457f);
155 auto t = model_->Forward(std::move(x), std::move(x_length));
156
142 - auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
157 + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
158
159 + int32_t frame_shift_ms = 10;
160 for (int32_t i = 0; i != n; ++i) {
145 - auto r = Convert(results[i], symbol_table_);
161 + auto r = Convert(results[i], symbol_table_, frame_shift_ms,
162 + model_->SubsamplingFactor());
163 ss[i]->SetResult(r);
164 }
165 }
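The timestamp conversion added to Convert() is simple arithmetic: an output frame index t corresponds to t * (frame_shift_ms / 1000) * subsampling_factor seconds. A worked example (illustration only) using the 10 ms frame shift hard-coded above and the subsampling factor of 4 reported by the zipformer2 CTC model:

#include <cstdint>
#include <cstdio>

int main() {
  int32_t frame_shift_ms = 10;     // feature frame shift used by the recognizer
  int32_t subsampling_factor = 4;  // model_->SubsamplingFactor() for zipformer2 CTC
  float frame_shift_s = frame_shift_ms / 1000.f * subsampling_factor;  // 0.04 s

  int32_t t = 25;                       // token emitted at output frame 25
  printf("%.2f\n", frame_shift_s * t);  // prints 1.00 (seconds)
  return 0;
}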
@@ -25,9 +25,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
25 return std::make_unique<OfflineRecognizerTransducerImpl>(config);
26 } else if (model_type == "paraformer") {
27 return std::make_unique<OfflineRecognizerParaformerImpl>(config);
28 - } else if (model_type == "nemo_ctc") {
29 - return std::make_unique<OfflineRecognizerCtcImpl>(config);
30 - } else if (model_type == "tdnn") {
28 + } else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
29 + model_type == "zipformer2_ctc") {
30 return std::make_unique<OfflineRecognizerCtcImpl>(config);
31 } else if (model_type == "whisper") {
32 return std::make_unique<OfflineRecognizerWhisperImpl>(config);
@@ -50,6 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
49 model_filename = config.model_config.nemo_ctc.model;
50 } else if (!config.model_config.tdnn.model.empty()) {
51 model_filename = config.model_config.tdnn.model;
52 + } else if (!config.model_config.zipformer_ctc.model.empty()) {
53 + model_filename = config.model_config.zipformer_ctc.model;
54 } else if (!config.model_config.whisper.encoder.empty()) {
55 model_filename = config.model_config.whisper.encoder;
56 } else {
@@ -93,6 +94,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
94 "\n "
95 "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
96 "\n"
97 + "(5) Zipformer CTC models from icefall"
98 + "\n "
99 + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
100 + "zipformer/export-onnx-ctc.py"
101 + "\n"
102 "\n");
103 exit(-1);
104 }
@@ -107,11 +113,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
113 return std::make_unique<OfflineRecognizerParaformerImpl>(config);
114 }
115
110 - if (model_type == "EncDecCTCModelBPE") {
111 - return std::make_unique<OfflineRecognizerCtcImpl>(config);
112 - }
113 -
114 - if (model_type == "tdnn") {
116 + if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
117 + model_type == "zipformer2_ctc") {
118 return std::make_unique<OfflineRecognizerCtcImpl>(config);
119 }
120
@@ -126,7 +129,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
129 " - Non-streaming Paraformer models from FunASR\n"
130 " - EncDecCTCModelBPE models from NeMo\n"
131 " - Whisper models\n"
129 - " - Tdnn models\n",
132 + " - Tdnn models\n"
133 + " - Zipformer CTC models\n",
134 model_type.c_str());
135
136 exit(-1);
@@ -141,9 +145,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
145 return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
146 } else if (model_type == "paraformer") {
147 return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
144 - } else if (model_type == "nemo_ctc") {
145 - return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
146 - } else if (model_type == "tdnn") {
148 + } else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
149 + model_type == "zipformer2_ctc") {
150 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
151 } else if (model_type == "whisper") {
152 return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
@@ -166,6 +169,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
169 model_filename = config.model_config.nemo_ctc.model;
170 } else if (!config.model_config.tdnn.model.empty()) {
171 model_filename = config.model_config.tdnn.model;
172 + } else if (!config.model_config.zipformer_ctc.model.empty()) {
173 + model_filename = config.model_config.zipformer_ctc.model;
174 } else if (!config.model_config.whisper.encoder.empty()) {
175 model_filename = config.model_config.whisper.encoder;
176 } else {
@@ -209,6 +214,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
214 "\n "
215 "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
216 "\n"
217 + "(5) Zipformer CTC models from icefall"
218 + "\n "
219 + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
220 + "zipformer/export-onnx-ctc.py"
221 + "\n"
222 "\n");
223 exit(-1);
224 }
@@ -223,11 +233,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
233 return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
234 }
235
226 - if (model_type == "EncDecCTCModelBPE") {
227 - return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
228 - }
229 -
230 - if (model_type == "tdnn") {
236 + if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
237 + model_type == "zipformer2_ctc") {
238 return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
239 }
240
@@ -242,7 +249,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
249 " - Non-streaming Paraformer models from FunASR\n"
250 " - EncDecCTCModelBPE models from NeMo\n"
251 " - Whisper models\n"
245 - " - Tdnn models\n",
252 + " - Tdnn models\n"
253 + " - Zipformer CTC models\n",
254 model_type.c_str());
255
256 exit(-1);
@@ -17,6 +17,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
17 feat_config.Register(po);
18 model_config.Register(po);
19 lm_config.Register(po);
20 + ctc_fst_decoder_config.Register(po);
21
22 po->Register(
23 "decoding-method", &decoding_method,
@@ -69,6 +70,7 @@ std::string OfflineRecognizerConfig::ToString() const {
70 os << "feat_config=" << feat_config.ToString() << ", ";
71 os << "model_config=" << model_config.ToString() << ", ";
72 os << "lm_config=" << lm_config.ToString() << ", ";
73 + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
74 os << "decoding_method=\"" << decoding_method << "\", ";
75 os << "max_active_paths=" << max_active_paths << ", ";
76 os << "hotwords_file=\"" << hotwords_file << "\", ";
@@ -14,6 +14,7 @@
14 #include "android/asset_manager_jni.h"
15 #endif
16
17 +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
18 #include "sherpa-onnx/csrc/offline-lm-config.h"
19 #include "sherpa-onnx/csrc/offline-model-config.h"
20 #include "sherpa-onnx/csrc/offline-stream.h"
@@ -28,6 +29,7 @@ struct OfflineRecognizerConfig {
29 OfflineFeatureExtractorConfig feat_config;
30 OfflineModelConfig model_config;
31 OfflineLMConfig lm_config;
32 + OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
33
34 std::string decoding_method = "greedy_search";
35 int32_t max_active_paths = 4;
@@ -39,16 +41,16 @@ struct OfflineRecognizerConfig {
41 // TODO(fangjun): Implement modified_beam_search
42
43 OfflineRecognizerConfig() = default;
42 - OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
43 - const OfflineModelConfig &model_config,
44 - const OfflineLMConfig &lm_config,
45 - const std::string &decoding_method,
46 - int32_t max_active_paths,
47 - const std::string &hotwords_file,
48 - float hotwords_score)
44 + OfflineRecognizerConfig(
45 + const OfflineFeatureExtractorConfig &feat_config,
46 + const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
47 + const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
48 + const std::string &decoding_method, int32_t max_active_paths,
49 + const std::string &hotwords_file, float hotwords_score)
50 : feat_config(feat_config),
51 model_config(model_config),
52 lm_config(lm_config),
53 + ctc_fst_decoder_config(ctc_fst_decoder_config),
54 decoding_method(decoding_method),
55 max_active_paths(max_active_paths),
56 hotwords_file(hotwords_file),
@@ -4,6 +4,8 @@
4
5 #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
6
7 +#include <utility>
8 +
9 #include "sherpa-onnx/csrc/macros.h"
10 #include "sherpa-onnx/csrc/onnx-utils.h"
11 #include "sherpa-onnx/csrc/session.h"
@@ -34,7 +36,7 @@ class OfflineTdnnCtcModel::Impl {
36 }
37 #endif
38
37 - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
39 + std::vector<Ort::Value> Forward(Ort::Value features) {
40 auto nnet_out =
41 sess_->Run({}, input_names_ptr_.data(), &features, 1,
42 output_names_ptr_.data(), output_names_ptr_.size());
@@ -52,7 +54,11 @@ class OfflineTdnnCtcModel::Impl {
54 memory_info, out_length_vec.data(), out_length_vec.size(),
55 out_length_shape.data(), out_length_shape.size());
56
55 - return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
57 + std::vector<Ort::Value> ans;
58 + ans.reserve(2);
59 + ans.push_back(std::move(nnet_out[0]));
60 + ans.push_back(Clone(Allocator(), &nnet_out_length));
61 + return ans;
62 }
63
64 int32_t VocabSize() const { return vocab_size_; }
@@ -108,7 +114,7 @@ OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr,
114
115 OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
116
111 -std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
117 +std::vector<Ort::Value> OfflineTdnnCtcModel::Forward(
118 Ort::Value features, Ort::Value /*features_length*/) {
119 return impl_->Forward(std::move(features));
120 }
@@ -5,7 +5,6 @@
5 #define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
6 #include <memory>
7 #include <string>
8 -#include <utility>
8 #include <vector>
9
10 #if __ANDROID_API__ >= 9
@@ -36,7 +35,7 @@ class OfflineTdnnCtcModel : public OfflineCtcModel {
35
36 /** Run the forward method of the model.
37 *
39 - * @param features A tensor of shape (N, T, C). It is changed in-place.
38 + * @param features A tensor of shape (N, T, C).
39 * @param features_length A 1-D tensor of shape (N,) containing number of
40 * valid frames in `features` before padding.
41 * Its dtype is int64_t.
@@ -45,8 +44,8 @@ class OfflineTdnnCtcModel : public OfflineCtcModel {
44 * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
45 * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
46 */
48 - std::pair<Ort::Value, Ort::Value> Forward(
49 - Ort::Value features, Ort::Value /*features_length*/) override;
47 + std::vector<Ort::Value> Forward(Ort::Value features,
48 + Ort::Value /*features_length*/) override;
49
50 /** Return the vocabulary size of the model
51 */
  1 +// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) {
  13 + po->Register("zipformer-ctc-model", &model, "Path to zipformer CTC model");
  14 +}
  15 +
  16 +bool OfflineZipformerCtcModelConfig::Validate() const {
  17 + if (!FileExists(model)) {
  18 + SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist",
  19 + model.c_str());
  20 + return false;
  21 + }
  22 +
  23 + return true;
  24 +}
  25 +
  26 +std::string OfflineZipformerCtcModelConfig::ToString() const {
  27 + std::ostringstream os;
  28 +
  29 + os << "OfflineZipformerCtcModelConfig(";
  30 + os << "model=\"" << model << "\")";
  31 +
  32 + return os.str();
  33 +}
  34 +
  35 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +// for
  14 +// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
  15 +struct OfflineZipformerCtcModelConfig {
  16 + std::string model;
  17 +
  18 + OfflineZipformerCtcModelConfig() = default;
  19 +
  20 + explicit OfflineZipformerCtcModelConfig(const std::string &model)
  21 + : model(model) {}
  22 +
  23 + void Register(ParseOptions *po);
  24 +
  25 + bool Validate() const;
  26 +
  27 + std::string ToString() const;
  28 +};
  29 +
  30 +} // namespace sherpa_onnx
  31 +
  32 +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
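For callers of the C++ API rather than the command-line tools, the new model type is selected by filling in the zipformer_ctc field of OfflineModelConfig. A minimal sketch, assuming a model exported with export-onnx-ctc.py and its tokens.txt sit in the current directory (the paths are hypothetical, not part of the commit):

#include "sherpa-onnx/csrc/offline-model-config.h"

int main() {
  sherpa_onnx::OfflineModelConfig model_config;
  model_config.zipformer_ctc.model = "./model.onnx";  // from export-onnx-ctc.py
  model_config.tokens = "./tokens.txt";
  model_config.model_type = "zipformer2_ctc";  // optional; avoids probing the ONNX metadata

  // Validate() now takes the zipformer_ctc branch added in this commit.
  return model_config.Validate() ? 0 : 1;
}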
  1 +// sherpa-onnx/csrc/offline-zipformer-ctc-model.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
  6 +
  7 +#include "sherpa-onnx/csrc/macros.h"
  8 +#include "sherpa-onnx/csrc/onnx-utils.h"
  9 +#include "sherpa-onnx/csrc/session.h"
  10 +#include "sherpa-onnx/csrc/text-utils.h"
  11 +#include "sherpa-onnx/csrc/transpose.h"
  12 +
  13 +namespace sherpa_onnx {
  14 +
  15 +class OfflineZipformerCtcModel::Impl {
  16 + public:
  17 + explicit Impl(const OfflineModelConfig &config)
  18 + : config_(config),
  19 + env_(ORT_LOGGING_LEVEL_ERROR),
  20 + sess_opts_(GetSessionOptions(config)),
  21 + allocator_{} {
  22 + auto buf = ReadFile(config_.zipformer_ctc.model);
  23 + Init(buf.data(), buf.size());
  24 + }
  25 +
  26 +#if __ANDROID_API__ >= 9
  27 + Impl(AAssetManager *mgr, const OfflineModelConfig &config)
  28 + : config_(config),
  29 + env_(ORT_LOGGING_LEVEL_ERROR),
  30 + sess_opts_(GetSessionOptions(config)),
  31 + allocator_{} {
  32 + auto buf = ReadFile(mgr, config_.zipformer_ctc.model);
  33 + Init(buf.data(), buf.size());
  34 + }
  35 +#endif
  36 +
  37 + std::vector<Ort::Value> Forward(Ort::Value features,
  38 + Ort::Value features_length) {
  39 + std::array<Ort::Value, 2> inputs = {std::move(features),
  40 + std::move(features_length)};
  41 +
  42 + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  43 + output_names_ptr_.data(), output_names_ptr_.size());
  44 + }
  45 +
  46 + int32_t VocabSize() const { return vocab_size_; }
  47 + int32_t SubsamplingFactor() const { return 4; }
  48 +
  49 + OrtAllocator *Allocator() const { return allocator_; }
  50 +
  51 + private:
  52 + void Init(void *model_data, size_t model_data_length) {
  53 + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
  54 + sess_opts_);
  55 +
  56 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  57 +
  58 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  59 +
  60 + // get meta data
  61 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  62 + if (config_.debug) {
  63 + std::ostringstream os;
  64 + PrintModelMetadata(os, meta_data);
  65 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  66 + }
  67 +
  68 + // get vocab size from the output[0].shape, which is (N, T, vocab_size)
  69 + vocab_size_ =
  70 + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2];
  71 + }
  72 +
  73 + private:
  74 + OfflineModelConfig config_;
  75 + Ort::Env env_;
  76 + Ort::SessionOptions sess_opts_;
  77 + Ort::AllocatorWithDefaultOptions allocator_;
  78 +
  79 + std::unique_ptr<Ort::Session> sess_;
  80 +
  81 + std::vector<std::string> input_names_;
  82 + std::vector<const char *> input_names_ptr_;
  83 +
  84 + std::vector<std::string> output_names_;
  85 + std::vector<const char *> output_names_ptr_;
  86 +
  87 + int32_t vocab_size_ = 0;
  88 +};
  89 +
  90 +OfflineZipformerCtcModel::OfflineZipformerCtcModel(
  91 + const OfflineModelConfig &config)
  92 + : impl_(std::make_unique<Impl>(config)) {}
  93 +
  94 +#if __ANDROID_API__ >= 9
  95 +OfflineZipformerCtcModel::OfflineZipformerCtcModel(
  96 + AAssetManager *mgr, const OfflineModelConfig &config)
  97 + : impl_(std::make_unique<Impl>(mgr, config)) {}
  98 +#endif
  99 +
  100 +OfflineZipformerCtcModel::~OfflineZipformerCtcModel() = default;
  101 +
  102 +std::vector<Ort::Value> OfflineZipformerCtcModel::Forward(
  103 + Ort::Value features, Ort::Value features_length) {
  104 + return impl_->Forward(std::move(features), std::move(features_length));
  105 +}
  106 +
  107 +int32_t OfflineZipformerCtcModel::VocabSize() const {
  108 + return impl_->VocabSize();
  109 +}
  110 +
  111 +OrtAllocator *OfflineZipformerCtcModel::Allocator() const {
  112 + return impl_->Allocator();
  113 +}
  114 +
  115 +int32_t OfflineZipformerCtcModel::SubsamplingFactor() const {
  116 + return impl_->SubsamplingFactor();
  117 +}
  118 +
  119 +} // namespace sherpa_onnx
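
Init() above derives the vocabulary size from the third dimension of the model's first output, and SubsamplingFactor() is hard-coded to 4 for this export. The same shape information can be inspected from Python with onnxruntime before wiring a model into the recognizer. A rough sketch, assuming onnxruntime is installed and using a placeholder model path:

import onnxruntime as ort

# Placeholder path to a zipformer CTC model exported by the icefall script
# referenced in offline-zipformer-ctc-model-config.h.
sess = ort.InferenceSession("./model.onnx", providers=["CPUExecutionProvider"])

# The model takes (features, features_length) and returns (log_probs, log_probs_length),
# matching the Forward() documentation in offline-zipformer-ctc-model.h.
print([i.name for i in sess.get_inputs()])
print([o.name for o in sess.get_outputs()])

# log_probs has shape (N, T', vocab_size); the last dimension is what
# OfflineZipformerCtcModel::Impl::Init() stores as vocab_size_.
log_probs_shape = sess.get_outputs()[0].shape
print("vocab_size:", log_probs_shape[2])
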
  1 +// sherpa-onnx/csrc/offline-zipformer-ctc-model.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
  6 +#include <memory>
  7 +#include <string>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#if __ANDROID_API__ >= 9
  12 +#include "android/asset_manager.h"
  13 +#include "android/asset_manager_jni.h"
  14 +#endif
  15 +
  16 +#include "onnxruntime_cxx_api.h" // NOLINT
  17 +#include "sherpa-onnx/csrc/offline-ctc-model.h"
  18 +#include "sherpa-onnx/csrc/offline-model-config.h"
  19 +
  20 +namespace sherpa_onnx {
  21 +
  22 +/** This class implements the zipformer CTC model of the librispeech recipe
  23 + * from icefall.
  24 + *
  25 + * See
  26 + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
  27 + */
  28 +class OfflineZipformerCtcModel : public OfflineCtcModel {
  29 + public:
  30 + explicit OfflineZipformerCtcModel(const OfflineModelConfig &config);
  31 +
  32 +#if __ANDROID_API__ >= 9
  33 + OfflineZipformerCtcModel(AAssetManager *mgr,
  34 + const OfflineModelConfig &config);
  35 +#endif
  36 +
  37 + ~OfflineZipformerCtcModel() override;
  38 +
  39 + /** Run the forward method of the model.
  40 + *
  41 + * @param features A tensor of shape (N, T, C).
  42 + * @param features_length A 1-D tensor of shape (N,) containing number of
  43 + * valid frames in `features` before padding.
  44 + * Its dtype is int64_t.
  45 + *
  46 + * @return Return a vector containing:
  47 + * - log_probs: A 3-D tensor of shape (N, T', vocab_size).
  48 + * - log_probs_length: A 1-D tensor of shape (N,). Its dtype is int64_t.
  49 + */
  50 + std::vector<Ort::Value> Forward(Ort::Value features,
  51 + Ort::Value features_length) override;
  52 +
  53 + /** Return the vocabulary size of the model
  54 + */
  55 + int32_t VocabSize() const override;
  56 +
  57 + /** Return an allocator for allocating memory
  58 + */
  59 + OrtAllocator *Allocator() const override;
  60 +
  61 + int32_t SubsamplingFactor() const override;
  62 +
  63 + private:
  64 + class Impl;
  65 + std::unique_ptr<Impl> impl_;
  66 +};
  67 +
  68 +} // namespace sherpa_onnx
  69 +
  70 +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
5 display.cc 5 display.cc
6 endpoint.cc 6 endpoint.cc
7 features.cc 7 features.cc
  8 + offline-ctc-fst-decoder-config.cc
8 offline-lm-config.cc 9 offline-lm-config.cc
9 offline-model-config.cc 10 offline-model-config.cc
10 offline-nemo-enc-dec-ctc-model-config.cc 11 offline-nemo-enc-dec-ctc-model-config.cc
@@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx @@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
14 offline-tdnn-model-config.cc 15 offline-tdnn-model-config.cc
15 offline-transducer-model-config.cc 16 offline-transducer-model-config.cc
16 offline-whisper-model-config.cc 17 offline-whisper-model-config.cc
  18 + offline-zipformer-ctc-model-config.cc
17 online-lm-config.cc 19 online-lm-config.cc
18 online-model-config.cc 20 online-model-config.cc
19 online-paraformer-model-config.cc 21 online-paraformer-model-config.cc
  1 +// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOfflineCtcFstDecoderConfig(py::module *m) {
  14 + using PyClass = OfflineCtcFstDecoderConfig;
  15 + py::class_<PyClass>(*m, "OfflineCtcFstDecoderConfig")
  16 + .def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
  17 + py::arg("max_active") = 3000)
  18 + .def_readwrite("graph", &PyClass::graph)
  19 + .def_readwrite("max_active", &PyClass::max_active)
  20 + .def("__str__", &PyClass::ToString);
  21 +}
  22 +
  23 +} // namespace sherpa_onnx
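
The binding above exposes the two tunables of the FST decoder config, graph and max_active, with max_active defaulting to 3000. A minimal Python sketch, assuming a build from this commit and a placeholder graph path:

from _sherpa_onnx import OfflineCtcFstDecoderConfig

# graph defaults to "" and max_active to 3000 (see the py::arg defaults above).
default_cfg = OfflineCtcFstDecoderConfig()
print(default_cfg)

# Point graph at a decoding FST (e.g. an HLG graph); max_active bounds the
# number of active states kept during decoding.
hlg_cfg = OfflineCtcFstDecoderConfig(graph="./HLG.fst", max_active=3000)
print(hlg_cfg.graph, hlg_cfg.max_active)
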
  1 +// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineCtcFstDecoderConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
@@ -13,6 +13,7 @@ @@ -13,6 +13,7 @@
13 #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
14 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" 14 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
15 #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" 15 #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
  16 +#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
16 17
17 namespace sherpa_onnx { 18 namespace sherpa_onnx {
18 19
@@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) { @@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) {
22 PybindOfflineNemoEncDecCtcModelConfig(m); 23 PybindOfflineNemoEncDecCtcModelConfig(m);
23 PybindOfflineWhisperModelConfig(m); 24 PybindOfflineWhisperModelConfig(m);
24 PybindOfflineTdnnModelConfig(m); 25 PybindOfflineTdnnModelConfig(m);
  26 + PybindOfflineZipformerCtcModelConfig(m);
25 27
26 using PyClass = OfflineModelConfig; 28 using PyClass = OfflineModelConfig;
27 py::class_<PyClass>(*m, "OfflineModelConfig") 29 py::class_<PyClass>(*m, "OfflineModelConfig")
@@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) { @@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) {
29 const OfflineParaformerModelConfig &, 31 const OfflineParaformerModelConfig &,
30 const OfflineNemoEncDecCtcModelConfig &, 32 const OfflineNemoEncDecCtcModelConfig &,
31 const OfflineWhisperModelConfig &, 33 const OfflineWhisperModelConfig &,
32 - const OfflineTdnnModelConfig &, const std::string &, 34 + const OfflineTdnnModelConfig &,
  35 + const OfflineZipformerCtcModelConfig &, const std::string &,
33 int32_t, bool, const std::string &, const std::string &>(), 36 int32_t, bool, const std::string &, const std::string &>(),
34 py::arg("transducer") = OfflineTransducerModelConfig(), 37 py::arg("transducer") = OfflineTransducerModelConfig(),
35 py::arg("paraformer") = OfflineParaformerModelConfig(), 38 py::arg("paraformer") = OfflineParaformerModelConfig(),
36 py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), 39 py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
37 py::arg("whisper") = OfflineWhisperModelConfig(), 40 py::arg("whisper") = OfflineWhisperModelConfig(),
38 - py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"),  
39 - py::arg("num_threads"), py::arg("debug") = false, 41 + py::arg("tdnn") = OfflineTdnnModelConfig(),
  42 + py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
  43 + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
40 py::arg("provider") = "cpu", py::arg("model_type") = "") 44 py::arg("provider") = "cpu", py::arg("model_type") = "")
41 .def_readwrite("transducer", &PyClass::transducer) 45 .def_readwrite("transducer", &PyClass::transducer)
42 .def_readwrite("paraformer", &PyClass::paraformer) 46 .def_readwrite("paraformer", &PyClass::paraformer)
43 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) 47 .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
44 .def_readwrite("whisper", &PyClass::whisper) 48 .def_readwrite("whisper", &PyClass::whisper)
45 .def_readwrite("tdnn", &PyClass::tdnn) 49 .def_readwrite("tdnn", &PyClass::tdnn)
  50 + .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
46 .def_readwrite("tokens", &PyClass::tokens) 51 .def_readwrite("tokens", &PyClass::tokens)
47 .def_readwrite("num_threads", &PyClass::num_threads) 52 .def_readwrite("num_threads", &PyClass::num_threads)
48 .def_readwrite("debug", &PyClass::debug) 53 .def_readwrite("debug", &PyClass::debug)
@@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) { @@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
16 py::class_<PyClass>(*m, "OfflineRecognizerConfig") 16 py::class_<PyClass>(*m, "OfflineRecognizerConfig")
17 .def(py::init<const OfflineFeatureExtractorConfig &, 17 .def(py::init<const OfflineFeatureExtractorConfig &,
18 const OfflineModelConfig &, const OfflineLMConfig &, 18 const OfflineModelConfig &, const OfflineLMConfig &,
19 - const std::string &, int32_t, const std::string &, float>(), 19 + const OfflineCtcFstDecoderConfig &, const std::string &,
  20 + int32_t, const std::string &, float>(),
20 py::arg("feat_config"), py::arg("model_config"), 21 py::arg("feat_config"), py::arg("model_config"),
21 py::arg("lm_config") = OfflineLMConfig(), 22 py::arg("lm_config") = OfflineLMConfig(),
  23 + py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
22 py::arg("decoding_method") = "greedy_search", 24 py::arg("decoding_method") = "greedy_search",
23 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", 25 py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
24 py::arg("hotwords_score") = 1.5) 26 py::arg("hotwords_score") = 1.5)
25 .def_readwrite("feat_config", &PyClass::feat_config) 27 .def_readwrite("feat_config", &PyClass::feat_config)
26 .def_readwrite("model_config", &PyClass::model_config) 28 .def_readwrite("model_config", &PyClass::model_config)
27 .def_readwrite("lm_config", &PyClass::lm_config) 29 .def_readwrite("lm_config", &PyClass::lm_config)
  30 + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
28 .def_readwrite("decoding_method", &PyClass::decoding_method) 31 .def_readwrite("decoding_method", &PyClass::decoding_method)
29 .def_readwrite("max_active_paths", &PyClass::max_active_paths) 32 .def_readwrite("max_active_paths", &PyClass::max_active_paths)
30 .def_readwrite("hotwords_file", &PyClass::hotwords_file) 33 .def_readwrite("hotwords_file", &PyClass::hotwords_file)
  1 +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void PybindOfflineZipformerCtcModelConfig(py::module *m) {
  14 + using PyClass = OfflineZipformerCtcModelConfig;
  15 + py::class_<PyClass>(*m, "OfflineZipformerCtcModelConfig")
  16 + .def(py::init<>())
  17 + .def(py::init<const std::string &>(), py::arg("model"))
  18 + .def_readwrite("model", &PyClass::model)
  19 + .def("__str__", &PyClass::ToString);
  20 +}
  21 +
  22 +} // namespace sherpa_onnx
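
Besides the two constructors, the binding above makes model a read/write attribute and maps __str__ to ToString(). A small sketch:

from _sherpa_onnx import OfflineZipformerCtcModelConfig

cfg = OfflineZipformerCtcModelConfig()
cfg.model = "./model.onnx"  # placeholder path; def_readwrite makes the field assignable
print(cfg)                  # prints OfflineZipformerCtcModelConfig(model="./model.onnx")
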
  1 +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
  6 +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
  7 +
  8 +#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void PybindOfflineZipformerCtcModelConfig(py::module *m);
  13 +
  14 +}
  15 +
  16 +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
@@ -8,6 +8,7 @@ @@ -8,6 +8,7 @@
8 #include "sherpa-onnx/python/csrc/display.h" 8 #include "sherpa-onnx/python/csrc/display.h"
9 #include "sherpa-onnx/python/csrc/endpoint.h" 9 #include "sherpa-onnx/python/csrc/endpoint.h"
10 #include "sherpa-onnx/python/csrc/features.h" 10 #include "sherpa-onnx/python/csrc/features.h"
  11 +#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
11 #include "sherpa-onnx/python/csrc/offline-lm-config.h" 12 #include "sherpa-onnx/python/csrc/offline-lm-config.h"
12 #include "sherpa-onnx/python/csrc/offline-model-config.h" 13 #include "sherpa-onnx/python/csrc/offline-model-config.h"
13 #include "sherpa-onnx/python/csrc/offline-recognizer.h" 14 #include "sherpa-onnx/python/csrc/offline-recognizer.h"
@@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { @@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
37 PybindOfflineStream(&m); 38 PybindOfflineStream(&m);
38 PybindOfflineLMConfig(&m); 39 PybindOfflineLMConfig(&m);
39 PybindOfflineModelConfig(&m); 40 PybindOfflineModelConfig(&m);
  41 + PybindOfflineCtcFstDecoderConfig(&m);
40 PybindOfflineRecognizer(&m); 42 PybindOfflineRecognizer(&m);
41 43
42 PybindVadModelConfig(&m); 44 PybindVadModelConfig(&m);
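
Once PybindOfflineCtcFstDecoderConfig() runs during module initialization, the new classes are importable straight from the extension module. A quick smoke check, assuming the extension is built from this commit:

import _sherpa_onnx

# Both classes added in this change should now be registered on the module.
assert hasattr(_sherpa_onnx, "OfflineCtcFstDecoderConfig")
assert hasattr(_sherpa_onnx, "OfflineZipformerCtcModelConfig")
print(_sherpa_onnx.OfflineCtcFstDecoderConfig())
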
@@ -4,12 +4,14 @@ from pathlib import Path @@ -4,12 +4,14 @@ from pathlib import Path
4 from typing import List, Optional 4 from typing import List, Optional
5 5
6 from _sherpa_onnx import ( 6 from _sherpa_onnx import (
  7 + OfflineCtcFstDecoderConfig,
7 OfflineFeatureExtractorConfig, 8 OfflineFeatureExtractorConfig,
8 OfflineModelConfig, 9 OfflineModelConfig,
9 OfflineNemoEncDecCtcModelConfig, 10 OfflineNemoEncDecCtcModelConfig,
10 OfflineParaformerModelConfig, 11 OfflineParaformerModelConfig,
11 OfflineTdnnModelConfig, 12 OfflineTdnnModelConfig,
12 OfflineWhisperModelConfig, 13 OfflineWhisperModelConfig,
  14 + OfflineZipformerCtcModelConfig,
13 ) 15 )
14 from _sherpa_onnx import OfflineRecognizer as _Recognizer 16 from _sherpa_onnx import OfflineRecognizer as _Recognizer
15 from _sherpa_onnx import ( 17 from _sherpa_onnx import (