Add CTC HLG decoding using OpenFst (#349)
Showing 39 changed files with 963 additions and 55 deletions.
| @@ -89,3 +89,48 @@ time $EXE \ | @@ -89,3 +89,48 @@ time $EXE \ | ||
| 89 | $repo/test_wavs/8k.wav | 89 | $repo/test_wavs/8k.wav |
| 90 | 90 | ||
| 91 | rm -rf $repo | 91 | rm -rf $repo |
| 92 | + | ||
| 93 | +log "------------------------------------------------------------" | ||
| 94 | +log "Run Librispeech zipformer CTC H/HL/HLG decoding (English) " | ||
| 95 | +log "------------------------------------------------------------" | ||
| 96 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 | ||
| 97 | +log "Start testing ${repo_url}" | ||
| 98 | +repo=$(basename $repo_url) | ||
| 99 | +log "Download pretrained model and test-data from $repo_url" | ||
| 100 | + | ||
| 101 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 102 | +pushd $repo | ||
| 103 | +git lfs pull --include "*.onnx" | ||
| 104 | +git lfs pull --include "*.fst" | ||
| 105 | +ls -lh | ||
| 106 | +popd | ||
| 107 | + | ||
| 108 | +graphs=( | ||
| 109 | +$repo/H.fst | ||
| 110 | +$repo/HL.fst | ||
| 111 | +$repo/HLG.fst | ||
| 112 | +) | ||
| 113 | + | ||
| 114 | +for graph in ${graphs[@]}; do | ||
| 115 | + log "test float32 models with $graph" | ||
| 116 | + time $EXE \ | ||
| 117 | + --model-type=zipformer2_ctc \ | ||
| 118 | + --ctc.graph=$graph \ | ||
| 119 | + --zipformer-ctc-model=$repo/model.onnx \ | ||
| 120 | + --tokens=$repo/tokens.txt \ | ||
| 121 | + $repo/test_wavs/0.wav \ | ||
| 122 | + $repo/test_wavs/1.wav \ | ||
| 123 | + $repo/test_wavs/2.wav | ||
| 124 | + | ||
| 125 | + log "test int8 models with $graph" | ||
| 126 | + time $EXE \ | ||
| 127 | + --model-type=zipformer2_ctc \ | ||
| 128 | + --ctc.graph=$graph \ | ||
| 129 | + --zipformer-ctc-model=$repo/model.int8.onnx \ | ||
| 130 | + --tokens=$repo/tokens.txt \ | ||
| 131 | + $repo/test_wavs/0.wav \ | ||
| 132 | + $repo/test_wavs/1.wav \ | ||
| 133 | + $repo/test_wavs/2.wav | ||
| 134 | +done | ||
| 135 | + | ||
| 136 | +rm -rf $repo |
| @@ -18,7 +18,7 @@ permissions: | @@ -18,7 +18,7 @@ permissions: | ||
| 18 | jobs: | 18 | jobs: |
| 19 | python_online_websocket_server: | 19 | python_online_websocket_server: |
| 20 | runs-on: ${{ matrix.os }} | 20 | runs-on: ${{ matrix.os }} |
| 21 | - name: ${{ matrix.os }} ${{ matrix.python-version }} | 21 | + name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }} |
| 22 | strategy: | 22 | strategy: |
| 23 | fail-fast: false | 23 | fail-fast: false |
| 24 | matrix: | 24 | matrix: |
| @@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) | @@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) | ||
| 154 | endif() | 154 | endif() |
| 155 | 155 | ||
| 156 | include(kaldi-native-fbank) | 156 | include(kaldi-native-fbank) |
| 157 | +include(kaldi-decoder) | ||
| 157 | include(onnxruntime) | 158 | include(onnxruntime) |
| 158 | 159 | ||
| 159 | if(SHERPA_ONNX_ENABLE_PORTAUDIO) | 160 | if(SHERPA_ONNX_ENABLE_PORTAUDIO) |
cmake/eigen.cmake
0 → 100644
| 1 | +function(download_eigen) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz") | ||
| 5 | + set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz") | ||
| 6 | + set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download eigen | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/eigen-3.4.0.tar.gz | ||
| 12 | + ${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz | ||
| 13 | + ${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz | ||
| 14 | + /tmp/eigen-3.4.0.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/eigen-3.4.0.tar.gz | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(eigen_URL "${f}") | ||
| 21 | + file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL) | ||
| 22 | + message(STATUS "Found local downloaded eigen: ${eigen_URL}") | ||
| 23 | + set(eigen_URL2) | ||
| 24 | + break() | ||
| 25 | + endif() | ||
| 26 | + endforeach() | ||
| 27 | + | ||
| 28 | + set(BUILD_TESTING OFF CACHE BOOL "" FORCE) | ||
| 29 | + set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE) | ||
| 30 | + | ||
| 31 | + FetchContent_Declare(eigen | ||
| 32 | + URL ${eigen_URL} | ||
| 33 | + URL_HASH ${eigen_HASH} | ||
| 34 | + ) | ||
| 35 | + | ||
| 36 | + FetchContent_GetProperties(eigen) | ||
| 37 | + if(NOT eigen_POPULATED) | ||
| 38 | + message(STATUS "Downloading eigen from ${eigen_URL}") | ||
| 39 | + FetchContent_Populate(eigen) | ||
| 40 | + endif() | ||
| 41 | + message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}") | ||
| 42 | + message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}") | ||
| 43 | + | ||
| 44 | + add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 45 | +endfunction() | ||
| 46 | + | ||
| 47 | +download_eigen() | ||
| 48 | + |
cmake/kaldi-decoder.cmake
0 → 100644
| 1 | +function(download_kaldi_decoder) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.3.tar.gz") | ||
| 5 | + set(kaldi_decoder_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.3.tar.gz") | ||
| 6 | + set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d") | ||
| 7 | + | ||
| 8 | + set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 9 | + set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 10 | + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 11 | + | ||
| 12 | + # If you don't have access to the Internet, | ||
| 13 | + # please pre-download kaldi-decoder | ||
| 14 | + set(possible_file_locations | ||
| 15 | + $ENV{HOME}/Downloads/kaldi-decoder-0.2.3.tar.gz | ||
| 16 | + ${PROJECT_SOURCE_DIR}/kaldi-decoder-0.2.3.tar.gz | ||
| 17 | + ${PROJECT_BINARY_DIR}/kaldi-decoder-0.2.3.tar.gz | ||
| 18 | + /tmp/kaldi-decoder-0.2.3.tar.gz | ||
| 19 | + /star-fj/fangjun/download/github/kaldi-decoder-0.2.3.tar.gz | ||
| 20 | + ) | ||
| 21 | + | ||
| 22 | + foreach(f IN LISTS possible_file_locations) | ||
| 23 | + if(EXISTS ${f}) | ||
| 24 | + set(kaldi_decoder_URL "${f}") | ||
| 25 | + file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL) | ||
| 26 | + message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}") | ||
| 27 | + set(kaldi_decoder_URL2 ) | ||
| 28 | + break() | ||
| 29 | + endif() | ||
| 30 | + endforeach() | ||
| 31 | + | ||
| 32 | + FetchContent_Declare(kaldi_decoder | ||
| 33 | + URL | ||
| 34 | + ${kaldi_decoder_URL} | ||
| 35 | + ${kaldi_decoder_URL2} | ||
| 36 | + URL_HASH ${kaldi_decoder_HASH} | ||
| 37 | + ) | ||
| 38 | + | ||
| 39 | + FetchContent_GetProperties(kaldi_decoder) | ||
| 40 | + if(NOT kaldi_decoder_POPULATED) | ||
| 41 | + message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}") | ||
| 42 | + FetchContent_Populate(kaldi_decoder) | ||
| 43 | + endif() | ||
| 44 | + message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}") | ||
| 45 | + message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}") | ||
| 46 | + | ||
| 47 | + include_directories(${kaldi_decoder_SOURCE_DIR}) | ||
| 48 | + add_subdirectory(${kaldi_decoder_SOURCE_DIR} ${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 49 | + | ||
| 50 | + target_include_directories(kaldi-decoder-core | ||
| 51 | + INTERFACE | ||
| 52 | + ${kaldi-decoder_SOURCE_DIR}/ | ||
| 53 | + ) | ||
| 54 | + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) | ||
| 55 | + install(TARGETS | ||
| 56 | + kaldi-decoder-core | ||
| 57 | + kaldifst_core | ||
| 58 | + fst | ||
| 59 | + DESTINATION ..) | ||
| 60 | + else() | ||
| 61 | + install(TARGETS | ||
| 62 | + kaldi-decoder-core | ||
| 63 | + kaldifst_core | ||
| 64 | + fst | ||
| 65 | + DESTINATION lib) | ||
| 66 | + endif() | ||
| 67 | + | ||
| 68 | + if(WIN32 AND BUILD_SHARED_LIBS) | ||
| 69 | + install(TARGETS | ||
| 70 | + kaldi-decoder-core | ||
| 71 | + kaldifst_core | ||
| 72 | + fst | ||
| 73 | + DESTINATION bin) | ||
| 74 | + endif() | ||
| 75 | +endfunction() | ||
| 76 | + | ||
| 77 | +download_kaldi_decoder() | ||
| 78 | + |
cmake/kaldifst.cmake
0 → 100644
| 1 | +function(download_kaldifst) | ||
| 2 | + include(FetchContent) | ||
| 3 | + | ||
| 4 | + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz") | ||
| 5 | + set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz") | ||
| 6 | + set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2") | ||
| 7 | + | ||
| 8 | + # If you don't have access to the Internet, | ||
| 9 | + # please pre-download kaldifst | ||
| 10 | + set(possible_file_locations | ||
| 11 | + $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz | ||
| 12 | + ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz | ||
| 13 | + ${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz | ||
| 14 | + /tmp/kaldifst-1.7.6.tar.gz | ||
| 15 | + /star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz | ||
| 16 | + ) | ||
| 17 | + | ||
| 18 | + foreach(f IN LISTS possible_file_locations) | ||
| 19 | + if(EXISTS ${f}) | ||
| 20 | + set(kaldifst_URL "${f}") | ||
| 21 | + file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL) | ||
| 22 | + message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}") | ||
| 23 | + set(kaldifst_URL2) | ||
| 24 | + break() | ||
| 25 | + endif() | ||
| 26 | + endforeach() | ||
| 27 | + | ||
| 28 | + set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE) | ||
| 29 | + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | ||
| 30 | + | ||
| 31 | + FetchContent_Declare(kaldifst | ||
| 32 | + URL ${kaldifst_URL} | ||
| 33 | + URL_HASH ${kaldifst_HASH} | ||
| 34 | + ) | ||
| 35 | + | ||
| 36 | + FetchContent_GetProperties(kaldifst) | ||
| 37 | + if(NOT kaldifst_POPULATED) | ||
| 38 | + message(STATUS "Downloading kaldifst from ${kaldifst_URL}") | ||
| 39 | + FetchContent_Populate(kaldifst) | ||
| 40 | + endif() | ||
| 41 | + message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}") | ||
| 42 | + message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}") | ||
| 43 | + | ||
| 44 | + list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake) | ||
| 45 | + | ||
| 46 | + add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL) | ||
| 47 | + | ||
| 48 | + target_include_directories(kaldifst_core | ||
| 49 | + PUBLIC | ||
| 50 | + ${kaldifst_SOURCE_DIR}/ | ||
| 51 | + ) | ||
| 52 | + | ||
| 53 | + target_include_directories(fst | ||
| 54 | + PUBLIC | ||
| 55 | + ${openfst_SOURCE_DIR}/src/include | ||
| 56 | + ) | ||
| 57 | + | ||
| 58 | + set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-onnx-kaldifst-core") | ||
| 59 | + set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst") | ||
| 60 | +endfunction() | ||
| 61 | + | ||
| 62 | +download_kaldifst() |
| @@ -13,4 +13,4 @@ Cflags: -I"${includedir}" | @@ -13,4 +13,4 @@ Cflags: -I"${includedir}" | ||
| 13 | # Note: -lcargs is required only for the following file | 13 | # Note: -lcargs is required only for the following file |
| 14 | # https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c | 14 | # https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c |
| 15 | # We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c | 15 | # We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c |
| 16 | -Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@ | 16 | +Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@ |
| @@ -9,6 +9,9 @@ | @@ -9,6 +9,9 @@ | ||
| 9 | sherpa-onnx-portaudio_static.lib; | 9 | sherpa-onnx-portaudio_static.lib; |
| 10 | sherpa-onnx-c-api.lib; | 10 | sherpa-onnx-c-api.lib; |
| 11 | sherpa-onnx-core.lib; | 11 | sherpa-onnx-core.lib; |
| 12 | + kaldi-decoder-core.lib; | ||
| 13 | + sherpa-onnx-kaldifst-core.lib; | ||
| 14 | + sherpa-onnx-fst.lib; | ||
| 12 | kaldi-native-fbank-core.lib; | 15 | kaldi-native-fbank-core.lib; |
| 13 | absl_base.lib; | 16 | absl_base.lib; |
| 14 | absl_city.lib; | 17 | absl_city.lib; |
| @@ -9,6 +9,9 @@ | @@ -9,6 +9,9 @@ | ||
| 9 | sherpa-onnx-portaudio_static.lib; | 9 | sherpa-onnx-portaudio_static.lib; |
| 10 | sherpa-onnx-c-api.lib; | 10 | sherpa-onnx-c-api.lib; |
| 11 | sherpa-onnx-core.lib; | 11 | sherpa-onnx-core.lib; |
| 12 | + kaldi-decoder-core.lib; | ||
| 13 | + sherpa-onnx-kaldifst-core.lib; | ||
| 14 | + sherpa-onnx-fst.lib; | ||
| 12 | kaldi-native-fbank-core.lib; | 15 | kaldi-native-fbank-core.lib; |
| 13 | absl_base.lib; | 16 | absl_base.lib; |
| 14 | absl_city.lib; | 17 | absl_city.lib; |
| @@ -19,6 +19,8 @@ set(sources | @@ -19,6 +19,8 @@ set(sources | ||
| 19 | features.cc | 19 | features.cc |
| 20 | file-utils.cc | 20 | file-utils.cc |
| 21 | hypothesis.cc | 21 | hypothesis.cc |
| 22 | + offline-ctc-fst-decoder-config.cc | ||
| 23 | + offline-ctc-fst-decoder.cc | ||
| 22 | offline-ctc-greedy-search-decoder.cc | 24 | offline-ctc-greedy-search-decoder.cc |
| 23 | offline-ctc-model.cc | 25 | offline-ctc-model.cc |
| 24 | offline-lm-config.cc | 26 | offline-lm-config.cc |
| @@ -42,6 +44,8 @@ set(sources | @@ -42,6 +44,8 @@ set(sources | ||
| 42 | offline-whisper-greedy-search-decoder.cc | 44 | offline-whisper-greedy-search-decoder.cc |
| 43 | offline-whisper-model-config.cc | 45 | offline-whisper-model-config.cc |
| 44 | offline-whisper-model.cc | 46 | offline-whisper-model.cc |
| 47 | + offline-zipformer-ctc-model-config.cc | ||
| 48 | + offline-zipformer-ctc-model.cc | ||
| 45 | online-conformer-transducer-model.cc | 49 | online-conformer-transducer-model.cc |
| 46 | online-lm-config.cc | 50 | online-lm-config.cc |
| 47 | online-lm.cc | 51 | online-lm.cc |
| @@ -97,6 +101,8 @@ endif() | @@ -97,6 +101,8 @@ endif() | ||
| 97 | 101 | ||
| 98 | target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core) | 102 | target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core) |
| 99 | 103 | ||
| 104 | +target_link_libraries(sherpa-onnx-core kaldi-decoder-core) | ||
| 105 | + | ||
| 100 | if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm) | 106 | if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm) |
| 101 | target_link_libraries(sherpa-onnx-core onnxruntime) | 107 | target_link_libraries(sherpa-onnx-core onnxruntime) |
| 102 | else() | 108 | else() |
| 1 | +// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" | ||
| 6 | + | ||
| 7 | +#include <sstream> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +std::string OfflineCtcFstDecoderConfig::ToString() const { | ||
| 13 | + std::ostringstream os; | ||
| 14 | + | ||
| 15 | + os << "OfflineCtcFstDecoderConfig("; | ||
| 16 | + os << "graph=\"" << graph << "\", "; | ||
| 17 | + os << "max_active=" << max_active << ")"; | ||
| 18 | + | ||
| 19 | + return os.str(); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { | ||
| 23 | + std::string prefix = "ctc"; | ||
| 24 | + ParseOptions p(prefix, po); | ||
| 25 | + | ||
| 26 | + p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); | ||
| 27 | + | ||
| 28 | + p.Register("max-active", &max_active, | ||
| 29 | + "Decoder max active states. Larger->slower; more accurate"); | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +} // namespace sherpa_onnx |
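The `ParseOptions p(prefix, po)` wrapper above is what turns the registered options into dotted command-line flags, which is why the test script at the top of this PR passes `--ctc.graph=...`. Below is a minimal, hypothetical sketch of that wiring (the `demo` binary name is made up); it assumes sherpa-onnx's ParseOptions keeps the Kaldi-style `Read(argc, argv)` API used by the existing binaries.

```cpp
// Minimal sketch (not part of this commit) of how the "ctc" prefix is
// consumed from the command line.
#include <cstdio>

#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::OfflineCtcFstDecoderConfig config;

  sherpa_onnx::ParseOptions po(
      "Hypothetical usage: ./demo --ctc.graph=HLG.fst [more options]");
  config.Register(&po);  // exposes --ctc.graph and --ctc.max-active
  po.Read(argc, argv);   // e.g. --ctc.graph=HLG.fst --ctc.max-active=5000

  std::fprintf(stderr, "%s\n", config.ToString().c_str());
  return 0;
}
```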
| 1 | +// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineCtcFstDecoderConfig { | ||
| 15 | + // Path to H.fst, HL.fst or HLG.fst | ||
| 16 | + std::string graph; | ||
| 17 | + int32_t max_active = 3000; | ||
| 18 | + | ||
| 19 | + OfflineCtcFstDecoderConfig() = default; | ||
| 20 | + | ||
| 21 | + OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active) | ||
| 22 | + : graph(graph), max_active(max_active) {} | ||
| 23 | + | ||
| 24 | + std::string ToString() const; | ||
| 25 | + | ||
| 26 | + void Register(ParseOptions *po); | ||
| 27 | +}; | ||
| 28 | + | ||
| 29 | +} // namespace sherpa_onnx | ||
| 30 | + | ||
| 31 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ |
sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ctc-fst-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | + | ||
| 10 | +#include "fst/fstlib.h" | ||
| 11 | +#include "kaldi-decoder/csrc/decodable-ctc.h" | ||
| 12 | +#include "kaldi-decoder/csrc/eigen.h" | ||
| 13 | +#include "kaldi-decoder/csrc/faster-decoder.h" | ||
| 14 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +// This function is copied from kaldi. | ||
| 19 | +// | ||
| 20 | +// @param filename Path to a StdVectorFst or StdConstFst graph | ||
| 21 | +// @return The caller should free the returned pointer using `delete` to | ||
| 22 | +// avoid memory leak. | ||
| 23 | +static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) { | ||
| 24 | + // read decoding network FST | ||
| 25 | + std::ifstream is(filename, std::ios::binary); | ||
| 26 | + if (!is.good()) { | ||
| 27 | + SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str()); | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + fst::FstHeader hdr; | ||
| 31 | + if (!hdr.Read(is, "<unknown>")) { | ||
| 32 | + SHERPA_ONNX_LOGE("Reading FST: error reading FST header."); | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + if (hdr.ArcType() != fst::StdArc::Type()) { | ||
| 36 | + SHERPA_ONNX_LOGE("FST with arc type %s not supported", | ||
| 37 | + hdr.ArcType().c_str()); | ||
| 38 | + } | ||
| 39 | + fst::FstReadOptions ropts("<unspecified>", &hdr); | ||
| 40 | + | ||
| 41 | + fst::Fst<fst::StdArc> *decode_fst = nullptr; | ||
| 42 | + | ||
| 43 | + if (hdr.FstType() == "vector") { | ||
| 44 | + decode_fst = fst::VectorFst<fst::StdArc>::Read(is, ropts); | ||
| 45 | + } else if (hdr.FstType() == "const") { | ||
| 46 | + decode_fst = fst::ConstFst<fst::StdArc>::Read(is, ropts); | ||
| 47 | + } else { | ||
| 48 | + SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s", | ||
| 49 | + hdr.FstType().c_str()); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + if (decode_fst == nullptr) { // fst code will warn. | ||
| 53 | + SHERPA_ONNX_LOGE("Error reading FST (after reading header)."); | ||
| 54 | + return nullptr; | ||
| 55 | + } else { | ||
| 56 | + return decode_fst; | ||
| 57 | + } | ||
| 58 | +} | ||
| 59 | + | ||
| 60 | +/** | ||
| 61 | + * @param decoder | ||
| 62 | + * @param p Pointer to a 2-d array of shape (num_frames, vocab_size) | ||
| 63 | + * @param num_frames Number of rows in the 2-d array. | ||
| 64 | + * @param vocab_size Number of columns in the 2-d array. | ||
| 65 | + * @return Return the decoded result. | ||
| 66 | + */ | ||
| 67 | +static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, | ||
| 68 | + const float *p, int32_t num_frames, | ||
| 69 | + int32_t vocab_size) { | ||
| 70 | + OfflineCtcDecoderResult r; | ||
| 71 | + kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size); | ||
| 72 | + | ||
| 73 | + decoder->Decode(&decodable); | ||
| 74 | + | ||
| 75 | + if (!decoder->ReachedFinal()) { | ||
| 76 | + SHERPA_ONNX_LOGE("Not reached final!"); | ||
| 77 | + return r; | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + fst::VectorFst<fst::LatticeArc> decoded; // linear FST. | ||
| 81 | + decoder->GetBestPath(&decoded); | ||
| 82 | + | ||
| 83 | + if (decoded.NumStates() == 0) { | ||
| 84 | + SHERPA_ONNX_LOGE("Empty best path!"); | ||
| 85 | + return r; | ||
| 86 | + } | ||
| 87 | + | ||
| 88 | + auto cur_state = decoded.Start(); | ||
| 89 | + | ||
| 90 | + int32_t blank_id = 0; | ||
| 91 | + | ||
| 92 | + for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) { | ||
| 93 | + fst::ArcIterator<fst::Fst<fst::LatticeArc>> iter(decoded, cur_state); | ||
| 94 | + const auto &arc = iter.Value(); | ||
| 95 | + | ||
| 96 | + cur_state = arc.nextstate; | ||
| 97 | + | ||
| 98 | + if (arc.ilabel == prev) { | ||
| 99 | + continue; | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + // 0 is epsilon here | ||
| 103 | + if (arc.ilabel == 0 || arc.ilabel == blank_id + 1) { | ||
| 104 | + prev = arc.ilabel; | ||
| 105 | + continue; | ||
| 106 | + } | ||
| 107 | + | ||
| 108 | + // -1 here since the input labels are incremented during graph | ||
| 109 | + // construction | ||
| 110 | + r.tokens.push_back(arc.ilabel - 1); | ||
| 111 | + | ||
| 112 | + r.timestamps.push_back(t); | ||
| 113 | + prev = arc.ilabel; | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + return r; | ||
| 117 | +} | ||
| 118 | + | ||
| 119 | +OfflineCtcFstDecoder::OfflineCtcFstDecoder( | ||
| 120 | + const OfflineCtcFstDecoderConfig &config) | ||
| 121 | + : config_(config), fst_(ReadGraph(config_.graph)) {} | ||
| 122 | + | ||
| 123 | +std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode( | ||
| 124 | + Ort::Value log_probs, Ort::Value log_probs_length) { | ||
| 125 | + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 126 | + | ||
| 127 | + assert(static_cast<int32_t>(shape.size()) == 3); | ||
| 128 | + int32_t batch_size = shape[0]; | ||
| 129 | + int32_t T = shape[1]; | ||
| 130 | + int32_t vocab_size = shape[2]; | ||
| 131 | + | ||
| 132 | + std::vector<int64_t> length_shape = | ||
| 133 | + log_probs_length.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 134 | + assert(static_cast<int32_t>(length_shape.size()) == 1); | ||
| 135 | + | ||
| 136 | + assert(shape[0] == length_shape[0]); | ||
| 137 | + | ||
| 138 | + kaldi_decoder::FasterDecoderOptions opts; | ||
| 139 | + opts.max_active = config_.max_active; | ||
| 140 | + kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts); | ||
| 141 | + | ||
| 142 | + const float *start = log_probs.GetTensorData<float>(); | ||
| 143 | + | ||
| 144 | + std::vector<OfflineCtcDecoderResult> ans; | ||
| 145 | + ans.reserve(batch_size); | ||
| 146 | + | ||
| 147 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 148 | + const float *p = start + i * T * vocab_size; | ||
| 149 | + int32_t num_frames = log_probs_length.GetTensorData<int64_t>()[i]; | ||
| 150 | + auto r = DecodeOne(&faster_decoder, p, num_frames, vocab_size); | ||
| 151 | + ans.push_back(std::move(r)); | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + return ans; | ||
| 155 | +} | ||
| 156 | + | ||
| 157 | +} // namespace sherpa_onnx |
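The loop in `DecodeOne()` walks the linear best path and applies the usual CTC collapsing rule: merge runs of identical input labels, drop epsilon (0) and the shifted blank (`blank_id + 1`), and subtract 1 to undo the offset added when H/HL/HLG was built. The following is a standalone sketch of just that rule, for illustration only (it is not code from this commit):

```cpp
// Sketch of the label-collapsing rule used by DecodeOne():
// merge repeats, skip epsilon/blank, undo the +1 label offset.
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int32_t, int32_t>> CollapseCtcLabels(
    const std::vector<int32_t> &ilabels, int32_t blank_id = 0) {
  std::vector<std::pair<int32_t, int32_t>> tokens;  // (token, frame index)
  int32_t prev = -1;
  for (int32_t t = 0; t != static_cast<int32_t>(ilabels.size()); ++t) {
    int32_t label = ilabels[t];
    if (label == prev) {
      continue;  // merge repeated labels
    }
    if (label == 0 || label == blank_id + 1) {  // epsilon or shifted blank
      prev = label;
      continue;
    }
    tokens.emplace_back(label - 1, t);  // undo the +1 offset from graph building
    prev = label;
  }
  return tokens;
}
```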
sherpa-onnx/csrc/offline-ctc-fst-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-ctc-fst-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "fst/fst.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" | ||
| 14 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class OfflineCtcFstDecoder : public OfflineCtcDecoder { | ||
| 19 | + public: | ||
| 20 | + explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config); | ||
| 21 | + | ||
| 22 | + std::vector<OfflineCtcDecoderResult> Decode( | ||
| 23 | + Ort::Value log_probs, Ort::Value log_probs_length) override; | ||
| 24 | + | ||
| 25 | + private: | ||
| 26 | + OfflineCtcFstDecoderConfig config_; | ||
| 27 | + | ||
| 28 | + std::unique_ptr<fst::Fst<fst::StdArc>> fst_; | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ |
| @@ -12,6 +12,7 @@ | @@ -12,6 +12,7 @@ | ||
| 12 | #include "sherpa-onnx/csrc/macros.h" | 12 | #include "sherpa-onnx/csrc/macros.h" |
| 13 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" | 13 | #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" |
| 14 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" | 14 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" |
| 15 | +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" | ||
| 15 | #include "sherpa-onnx/csrc/onnx-utils.h" | 16 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 16 | 17 | ||
| 17 | namespace { | 18 | namespace { |
| @@ -19,6 +20,7 @@ namespace { | @@ -19,6 +20,7 @@ namespace { | ||
| 19 | enum class ModelType { | 20 | enum class ModelType { |
| 20 | kEncDecCTCModelBPE, | 21 | kEncDecCTCModelBPE, |
| 21 | kTdnn, | 22 | kTdnn, |
| 23 | + kZipformerCtc, | ||
| 22 | kUnkown, | 24 | kUnkown, |
| 23 | }; | 25 | }; |
| 24 | 26 | ||
| @@ -59,6 +61,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | @@ -59,6 +61,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, | ||
| 59 | return ModelType::kEncDecCTCModelBPE; | 61 | return ModelType::kEncDecCTCModelBPE; |
| 60 | } else if (model_type.get() == std::string("tdnn")) { | 62 | } else if (model_type.get() == std::string("tdnn")) { |
| 61 | return ModelType::kTdnn; | 63 | return ModelType::kTdnn; |
| 64 | + } else if (model_type.get() == std::string("zipformer2_ctc")) { | ||
| 65 | + return ModelType::kZipformerCtc; | ||
| 62 | } else { | 66 | } else { |
| 63 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); | 67 | SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); |
| 64 | return ModelType::kUnkown; | 68 | return ModelType::kUnkown; |
| @@ -74,6 +78,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -74,6 +78,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 74 | filename = config.nemo_ctc.model; | 78 | filename = config.nemo_ctc.model; |
| 75 | } else if (!config.tdnn.model.empty()) { | 79 | } else if (!config.tdnn.model.empty()) { |
| 76 | filename = config.tdnn.model; | 80 | filename = config.tdnn.model; |
| 81 | + } else if (!config.zipformer_ctc.model.empty()) { | ||
| 82 | + filename = config.zipformer_ctc.model; | ||
| 77 | } else { | 83 | } else { |
| 78 | SHERPA_ONNX_LOGE("Please specify a CTC model"); | 84 | SHERPA_ONNX_LOGE("Please specify a CTC model"); |
| 79 | exit(-1); | 85 | exit(-1); |
| @@ -92,6 +98,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -92,6 +98,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 92 | case ModelType::kTdnn: | 98 | case ModelType::kTdnn: |
| 93 | return std::make_unique<OfflineTdnnCtcModel>(config); | 99 | return std::make_unique<OfflineTdnnCtcModel>(config); |
| 94 | break; | 100 | break; |
| 101 | + case ModelType::kZipformerCtc: | ||
| 102 | + return std::make_unique<OfflineZipformerCtcModel>(config); | ||
| 103 | + break; | ||
| 95 | case ModelType::kUnkown: | 104 | case ModelType::kUnkown: |
| 96 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 105 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 97 | return nullptr; | 106 | return nullptr; |
| @@ -111,6 +120,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -111,6 +120,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 111 | filename = config.nemo_ctc.model; | 120 | filename = config.nemo_ctc.model; |
| 112 | } else if (!config.tdnn.model.empty()) { | 121 | } else if (!config.tdnn.model.empty()) { |
| 113 | filename = config.tdnn.model; | 122 | filename = config.tdnn.model; |
| 123 | + } else if (!config.zipformer_ctc.model.empty()) { | ||
| 124 | + filename = config.zipformer_ctc.model; | ||
| 114 | } else { | 125 | } else { |
| 115 | SHERPA_ONNX_LOGE("Please specify a CTC model"); | 126 | SHERPA_ONNX_LOGE("Please specify a CTC model"); |
| 116 | exit(-1); | 127 | exit(-1); |
| @@ -129,6 +140,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | @@ -129,6 +140,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create( | ||
| 129 | case ModelType::kTdnn: | 140 | case ModelType::kTdnn: |
| 130 | return std::make_unique<OfflineTdnnCtcModel>(mgr, config); | 141 | return std::make_unique<OfflineTdnnCtcModel>(mgr, config); |
| 131 | break; | 142 | break; |
| 143 | + case ModelType::kZipformerCtc: | ||
| 144 | + return std::make_unique<OfflineZipformerCtcModel>(mgr, config); | ||
| 145 | + break; | ||
| 132 | case ModelType::kUnkown: | 146 | case ModelType::kUnkown: |
| 133 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); | 147 | SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); |
| 134 | return nullptr; | 148 | return nullptr; |
| @@ -6,7 +6,7 @@ | @@ -6,7 +6,7 @@ | ||
| 6 | 6 | ||
| 7 | #include <memory> | 7 | #include <memory> |
| 8 | #include <string> | 8 | #include <string> |
| 9 | -#include <utility> | 9 | +#include <vector> |
| 10 | 10 | ||
| 11 | #if __ANDROID_API__ >= 9 | 11 | #if __ANDROID_API__ >= 9 |
| 12 | #include "android/asset_manager.h" | 12 | #include "android/asset_manager.h" |
| @@ -32,17 +32,17 @@ class OfflineCtcModel { | @@ -32,17 +32,17 @@ class OfflineCtcModel { | ||
| 32 | 32 | ||
| 33 | /** Run the forward method of the model. | 33 | /** Run the forward method of the model. |
| 34 | * | 34 | * |
| 35 | - * @param features A tensor of shape (N, T, C). It is changed in-place. | 35 | + * @param features A tensor of shape (N, T, C). |
| 36 | * @param features_length A 1-D tensor of shape (N,) containing number of | 36 | * @param features_length A 1-D tensor of shape (N,) containing number of |
| 37 | * valid frames in `features` before padding. | 37 | * valid frames in `features` before padding. |
| 38 | * Its dtype is int64_t. | 38 | * Its dtype is int64_t. |
| 39 | * | 39 | * |
| 40 | - * @return Return a pair containing: | 40 | + * @return Return a vector containing: |
| 41 | * - log_probs: A 3-D tensor of shape (N, T', vocab_size). | 41 | * - log_probs: A 3-D tensor of shape (N, T', vocab_size). |
| 42 | * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | 42 | * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t |
| 43 | */ | 43 | */ |
| 44 | - virtual std::pair<Ort::Value, Ort::Value> Forward( | ||
| 45 | - Ort::Value features, Ort::Value features_length) = 0; | 44 | + virtual std::vector<Ort::Value> Forward(Ort::Value features, |
| 45 | + Ort::Value features_length) = 0; | ||
| 46 | 46 | ||
| 47 | /** Return the vocabulary size of the model | 47 | /** Return the vocabulary size of the model |
| 48 | */ | 48 | */ |
| @@ -16,6 +16,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -16,6 +16,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 16 | nemo_ctc.Register(po); | 16 | nemo_ctc.Register(po); |
| 17 | whisper.Register(po); | 17 | whisper.Register(po); |
| 18 | tdnn.Register(po); | 18 | tdnn.Register(po); |
| 19 | + zipformer_ctc.Register(po); | ||
| 19 | 20 | ||
| 20 | po->Register("tokens", &tokens, "Path to tokens.txt"); | 21 | po->Register("tokens", &tokens, "Path to tokens.txt"); |
| 21 | 22 | ||
| @@ -31,7 +32,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | @@ -31,7 +32,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 31 | po->Register("model-type", &model_type, | 32 | po->Register("model-type", &model_type, |
| 32 | "Specify it to reduce model initialization time. " | 33 | "Specify it to reduce model initialization time. " |
| 33 | "Valid values are: transducer, paraformer, nemo_ctc, whisper, " | 34 | "Valid values are: transducer, paraformer, nemo_ctc, whisper, " |
| 34 | - "tdnn." | 35 | + "tdnn, zipformer2_ctc" |
| 35 | "All other values lead to loading the model twice."); | 36 | "All other values lead to loading the model twice."); |
| 36 | } | 37 | } |
| 37 | 38 | ||
| @@ -62,6 +63,10 @@ bool OfflineModelConfig::Validate() const { | @@ -62,6 +63,10 @@ bool OfflineModelConfig::Validate() const { | ||
| 62 | return tdnn.Validate(); | 63 | return tdnn.Validate(); |
| 63 | } | 64 | } |
| 64 | 65 | ||
| 66 | + if (!zipformer_ctc.model.empty()) { | ||
| 67 | + return zipformer_ctc.Validate(); | ||
| 68 | + } | ||
| 69 | + | ||
| 65 | return transducer.Validate(); | 70 | return transducer.Validate(); |
| 66 | } | 71 | } |
| 67 | 72 | ||
| @@ -74,6 +79,7 @@ std::string OfflineModelConfig::ToString() const { | @@ -74,6 +79,7 @@ std::string OfflineModelConfig::ToString() const { | ||
| 74 | os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; | 79 | os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; |
| 75 | os << "whisper=" << whisper.ToString() << ", "; | 80 | os << "whisper=" << whisper.ToString() << ", "; |
| 76 | os << "tdnn=" << tdnn.ToString() << ", "; | 81 | os << "tdnn=" << tdnn.ToString() << ", "; |
| 82 | + os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; | ||
| 77 | os << "tokens=\"" << tokens << "\", "; | 83 | os << "tokens=\"" << tokens << "\", "; |
| 78 | os << "num_threads=" << num_threads << ", "; | 84 | os << "num_threads=" << num_threads << ", "; |
| 79 | os << "debug=" << (debug ? "True" : "False") << ", "; | 85 | os << "debug=" << (debug ? "True" : "False") << ", "; |
| @@ -11,6 +11,7 @@ | @@ -11,6 +11,7 @@ | ||
| 11 | #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" | 11 | #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" |
| 12 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 12 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| 13 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" | 13 | #include "sherpa-onnx/csrc/offline-whisper-model-config.h" |
| 14 | +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" | ||
| 14 | 15 | ||
| 15 | namespace sherpa_onnx { | 16 | namespace sherpa_onnx { |
| 16 | 17 | ||
| @@ -20,6 +21,7 @@ struct OfflineModelConfig { | @@ -20,6 +21,7 @@ struct OfflineModelConfig { | ||
| 20 | OfflineNemoEncDecCtcModelConfig nemo_ctc; | 21 | OfflineNemoEncDecCtcModelConfig nemo_ctc; |
| 21 | OfflineWhisperModelConfig whisper; | 22 | OfflineWhisperModelConfig whisper; |
| 22 | OfflineTdnnModelConfig tdnn; | 23 | OfflineTdnnModelConfig tdnn; |
| 24 | + OfflineZipformerCtcModelConfig zipformer_ctc; | ||
| 23 | 25 | ||
| 24 | std::string tokens; | 26 | std::string tokens; |
| 25 | int32_t num_threads = 2; | 27 | int32_t num_threads = 2; |
| @@ -43,6 +45,7 @@ struct OfflineModelConfig { | @@ -43,6 +45,7 @@ struct OfflineModelConfig { | ||
| 43 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, | 45 | const OfflineNemoEncDecCtcModelConfig &nemo_ctc, |
| 44 | const OfflineWhisperModelConfig &whisper, | 46 | const OfflineWhisperModelConfig &whisper, |
| 45 | const OfflineTdnnModelConfig &tdnn, | 47 | const OfflineTdnnModelConfig &tdnn, |
| 48 | + const OfflineZipformerCtcModelConfig &zipformer_ctc, | ||
| 46 | const std::string &tokens, int32_t num_threads, bool debug, | 49 | const std::string &tokens, int32_t num_threads, bool debug, |
| 47 | const std::string &provider, const std::string &model_type) | 50 | const std::string &provider, const std::string &model_type) |
| 48 | : transducer(transducer), | 51 | : transducer(transducer), |
| @@ -50,6 +53,7 @@ struct OfflineModelConfig { | @@ -50,6 +53,7 @@ struct OfflineModelConfig { | ||
| 50 | nemo_ctc(nemo_ctc), | 53 | nemo_ctc(nemo_ctc), |
| 51 | whisper(whisper), | 54 | whisper(whisper), |
| 52 | tdnn(tdnn), | 55 | tdnn(tdnn), |
| 56 | + zipformer_ctc(zipformer_ctc), | ||
| 53 | tokens(tokens), | 57 | tokens(tokens), |
| 54 | num_threads(num_threads), | 58 | num_threads(num_threads), |
| 55 | debug(debug), | 59 | debug(debug), |
| @@ -34,7 +34,7 @@ class OfflineNemoEncDecCtcModel::Impl { | @@ -34,7 +34,7 @@ class OfflineNemoEncDecCtcModel::Impl { | ||
| 34 | } | 34 | } |
| 35 | #endif | 35 | #endif |
| 36 | 36 | ||
| 37 | - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features, | 37 | + std::vector<Ort::Value> Forward(Ort::Value features, |
| 38 | Ort::Value features_length) { | 38 | Ort::Value features_length) { |
| 39 | std::vector<int64_t> shape = | 39 | std::vector<int64_t> shape = |
| 40 | features_length.GetTensorTypeAndShapeInfo().GetShape(); | 40 | features_length.GetTensorTypeAndShapeInfo().GetShape(); |
| @@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl { | @@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl { | ||
| 57 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | 57 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), |
| 58 | output_names_ptr_.data(), output_names_ptr_.size()); | 58 | output_names_ptr_.data(), output_names_ptr_.size()); |
| 59 | 59 | ||
| 60 | - return {std::move(out[0]), std::move(out_features_length)}; | 60 | + std::vector<Ort::Value> ans; |
| 61 | + ans.reserve(2); | ||
| 62 | + ans.push_back(std::move(out[0])); | ||
| 63 | + ans.push_back(std::move(out_features_length)); | ||
| 64 | + return ans; | ||
| 61 | } | 65 | } |
| 62 | 66 | ||
| 63 | int32_t VocabSize() const { return vocab_size_; } | 67 | int32_t VocabSize() const { return vocab_size_; } |
| @@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( | @@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( | ||
| 122 | 126 | ||
| 123 | OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default; | 127 | OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default; |
| 124 | 128 | ||
| 125 | -std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward( | 129 | +std::vector<Ort::Value> OfflineNemoEncDecCtcModel::Forward( |
| 126 | Ort::Value features, Ort::Value features_length) { | 130 | Ort::Value features, Ort::Value features_length) { |
| 127 | return impl_->Forward(std::move(features), std::move(features_length)); | 131 | return impl_->Forward(std::move(features), std::move(features_length)); |
| 128 | } | 132 | } |
| @@ -38,17 +38,17 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { | @@ -38,17 +38,17 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { | ||
| 38 | 38 | ||
| 39 | /** Run the forward method of the model. | 39 | /** Run the forward method of the model. |
| 40 | * | 40 | * |
| 41 | - * @param features A tensor of shape (N, T, C). It is changed in-place. | 41 | + * @param features A tensor of shape (N, T, C). |
| 42 | * @param features_length A 1-D tensor of shape (N,) containing number of | 42 | * @param features_length A 1-D tensor of shape (N,) containing number of |
| 43 | * valid frames in `features` before padding. | 43 | * valid frames in `features` before padding. |
| 44 | * Its dtype is int64_t. | 44 | * Its dtype is int64_t. |
| 45 | * | 45 | * |
| 46 | - * @return Return a pair containing: | 46 | + * @return Return a vector containing: |
| 47 | * - log_probs: A 3-D tensor of shape (N, T', vocab_size). | 47 | * - log_probs: A 3-D tensor of shape (N, T', vocab_size). |
| 48 | * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | 48 | * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t |
| 49 | */ | 49 | */ |
| 50 | - std::pair<Ort::Value, Ort::Value> Forward( | ||
| 51 | - Ort::Value features, Ort::Value features_length) override; | 50 | + std::vector<Ort::Value> Forward(Ort::Value features, |
| 51 | + Ort::Value features_length) override; | ||
| 52 | 52 | ||
| 53 | /** Return the vocabulary size of the model | 53 | /** Return the vocabulary size of the model |
| 54 | */ | 54 | */ |
| @@ -16,6 +16,7 @@ | @@ -16,6 +16,7 @@ | ||
| 16 | #endif | 16 | #endif |
| 17 | 17 | ||
| 18 | #include "sherpa-onnx/csrc/offline-ctc-decoder.h" | 18 | #include "sherpa-onnx/csrc/offline-ctc-decoder.h" |
| 19 | +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h" | ||
| 19 | #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" | 20 | #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" |
| 20 | #include "sherpa-onnx/csrc/offline-ctc-model.h" | 21 | #include "sherpa-onnx/csrc/offline-ctc-model.h" |
| 21 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" | 22 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" |
| @@ -25,9 +26,12 @@ | @@ -25,9 +26,12 @@ | ||
| 25 | namespace sherpa_onnx { | 26 | namespace sherpa_onnx { |
| 26 | 27 | ||
| 27 | static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | 28 | static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, |
| 28 | - const SymbolTable &sym_table) { | 29 | + const SymbolTable &sym_table, |
| 30 | + int32_t frame_shift_ms, | ||
| 31 | + int32_t subsampling_factor) { | ||
| 29 | OfflineRecognitionResult r; | 32 | OfflineRecognitionResult r; |
| 30 | r.tokens.reserve(src.tokens.size()); | 33 | r.tokens.reserve(src.tokens.size()); |
| 34 | + r.timestamps.reserve(src.timestamps.size()); | ||
| 31 | 35 | ||
| 32 | std::string text; | 36 | std::string text; |
| 33 | 37 | ||
| @@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | @@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, | ||
| 42 | } | 46 | } |
| 43 | r.text = std::move(text); | 47 | r.text = std::move(text); |
| 44 | 48 | ||
| 49 | + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; | ||
| 50 | + for (auto t : src.timestamps) { | ||
| 51 | + float time = frame_shift_s * t; | ||
| 52 | + r.timestamps.push_back(time); | ||
| 53 | + } | ||
| 54 | + | ||
| 45 | return r; | 55 | return r; |
| 46 | } | 56 | } |
| 47 | 57 | ||
| @@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 68 | config_.feat_config.nemo_normalize_type = | 78 | config_.feat_config.nemo_normalize_type = |
| 69 | model_->FeatureNormalizationMethod(); | 79 | model_->FeatureNormalizationMethod(); |
| 70 | 80 | ||
| 71 | - if (config_.decoding_method == "greedy_search") { | 81 | + if (!config_.ctc_fst_decoder_config.graph.empty()) { |
| 82 | + // TODO(fangjun): Support android to read the graph from | ||
| 83 | + // asset_manager | ||
| 84 | + decoder_ = std::make_unique<OfflineCtcFstDecoder>( | ||
| 85 | + config_.ctc_fst_decoder_config); | ||
| 86 | + } else if (config_.decoding_method == "greedy_search") { | ||
| 72 | if (!symbol_table_.contains("<blk>") && | 87 | if (!symbol_table_.contains("<blk>") && |
| 73 | !symbol_table_.contains("<eps>")) { | 88 | !symbol_table_.contains("<eps>")) { |
| 74 | SHERPA_ONNX_LOGE( | 89 | SHERPA_ONNX_LOGE( |
| @@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | @@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { | ||
| 139 | -23.025850929940457f); | 154 | -23.025850929940457f); |
| 140 | auto t = model_->Forward(std::move(x), std::move(x_length)); | 155 | auto t = model_->Forward(std::move(x), std::move(x_length)); |
| 141 | 156 | ||
| 142 | - auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); | 157 | + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); |
| 143 | 158 | ||
| 159 | + int32_t frame_shift_ms = 10; | ||
| 144 | for (int32_t i = 0; i != n; ++i) { | 160 | for (int32_t i = 0; i != n; ++i) { |
| 145 | - auto r = Convert(results[i], symbol_table_); | 161 | + auto r = Convert(results[i], symbol_table_, frame_shift_ms, |
| 162 | + model_->SubsamplingFactor()); | ||
| 146 | ss[i]->SetResult(r); | 163 | ss[i]->SetResult(r); |
| 147 | } | 164 | } |
| 148 | } | 165 | } |
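The new `Convert()` arguments turn output-frame indices into seconds: one output frame spans `frame_shift_ms / 1000 * subsampling_factor` seconds. A small numeric check of that formula, assuming the 10 ms frame shift hard-coded above and a subsampling factor of 4 (an assumption for this example; in the PR it comes from the model):

```cpp
// Illustrative only: reproduces the timestamp formula from Convert().
#include <cstdint>
#include <cstdio>

int main() {
  int32_t frame_shift_ms = 10;     // hard-coded in DecodeStreams()
  int32_t subsampling_factor = 4;  // assumed here; model-provided in the PR
  float frame_shift_s = frame_shift_ms / 1000.f * subsampling_factor;  // 0.04 s

  const int32_t frames[] = {0, 25, 50};  // example output-frame indices
  for (int32_t t : frames) {
    std::printf("frame %3d -> %.2f s\n", t, frame_shift_s * t);
  }
  return 0;
}
```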
| @@ -25,9 +25,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -25,9 +25,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 25 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); | 25 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); |
| 26 | } else if (model_type == "paraformer") { | 26 | } else if (model_type == "paraformer") { |
| 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 27 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 28 | - } else if (model_type == "nemo_ctc") { | ||
| 29 | - return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 30 | - } else if (model_type == "tdnn") { | 28 | + } else if (model_type == "nemo_ctc" || model_type == "tdnn" || |
| 29 | + model_type == "zipformer2_ctc") { | ||
| 31 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 30 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 32 | } else if (model_type == "whisper") { | 31 | } else if (model_type == "whisper") { |
| 33 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); | 32 | return std::make_unique<OfflineRecognizerWhisperImpl>(config); |
| @@ -50,6 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -50,6 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 50 | model_filename = config.model_config.nemo_ctc.model; | 49 | model_filename = config.model_config.nemo_ctc.model; |
| 51 | } else if (!config.model_config.tdnn.model.empty()) { | 50 | } else if (!config.model_config.tdnn.model.empty()) { |
| 52 | model_filename = config.model_config.tdnn.model; | 51 | model_filename = config.model_config.tdnn.model; |
| 52 | + } else if (!config.model_config.zipformer_ctc.model.empty()) { | ||
| 53 | + model_filename = config.model_config.zipformer_ctc.model; | ||
| 53 | } else if (!config.model_config.whisper.encoder.empty()) { | 54 | } else if (!config.model_config.whisper.encoder.empty()) { |
| 54 | model_filename = config.model_config.whisper.encoder; | 55 | model_filename = config.model_config.whisper.encoder; |
| 55 | } else { | 56 | } else { |
| @@ -93,6 +94,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -93,6 +94,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 93 | "\n " | 94 | "\n " |
| 94 | "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" | 95 | "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" |
| 95 | "\n" | 96 | "\n" |
| 97 | + "(5) Zipformer CTC models from icefall" | ||
| 98 | + "\n " | ||
| 99 | + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" | ||
| 100 | + "zipformer/export-onnx-ctc.py" | ||
| 101 | + "\n" | ||
| 96 | "\n"); | 102 | "\n"); |
| 97 | exit(-1); | 103 | exit(-1); |
| 98 | } | 104 | } |
| @@ -107,11 +113,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -107,11 +113,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 107 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); | 113 | return std::make_unique<OfflineRecognizerParaformerImpl>(config); |
| 108 | } | 114 | } |
| 109 | 115 | ||
| 110 | - if (model_type == "EncDecCTCModelBPE") { | ||
| 111 | - return std::make_unique<OfflineRecognizerCtcImpl>(config); | ||
| 112 | - } | ||
| 113 | - | ||
| 114 | - if (model_type == "tdnn") { | 116 | + if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || |
| 117 | + model_type == "zipformer2_ctc") { | ||
| 115 | return std::make_unique<OfflineRecognizerCtcImpl>(config); | 118 | return std::make_unique<OfflineRecognizerCtcImpl>(config); |
| 116 | } | 119 | } |
| 117 | 120 | ||
| @@ -126,7 +129,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -126,7 +129,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 126 | " - Non-streaming Paraformer models from FunASR\n" | 129 | " - Non-streaming Paraformer models from FunASR\n" |
| 127 | " - EncDecCTCModelBPE models from NeMo\n" | 130 | " - EncDecCTCModelBPE models from NeMo\n" |
| 128 | " - Whisper models\n" | 131 | " - Whisper models\n" |
| 129 | - " - Tdnn models\n", | 132 | + " - Tdnn models\n" |
| 133 | + " - Zipformer CTC models\n", | ||
| 130 | model_type.c_str()); | 134 | model_type.c_str()); |
| 131 | 135 | ||
| 132 | exit(-1); | 136 | exit(-1); |
| @@ -141,9 +145,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -141,9 +145,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 141 | return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config); | 145 | return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config); |
| 142 | } else if (model_type == "paraformer") { | 146 | } else if (model_type == "paraformer") { |
| 143 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); | 147 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); |
| 144 | - } else if (model_type == "nemo_ctc") { | ||
| 145 | - return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | ||
| 146 | - } else if (model_type == "tdnn") { | 148 | + } else if (model_type == "nemo_ctc" || model_type == "tdnn" || |
| 149 | + model_type == "zipformer2_ctc") { | ||
| 147 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | 150 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); |
| 148 | } else if (model_type == "whisper") { | 151 | } else if (model_type == "whisper") { |
| 149 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); | 152 | return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config); |
| @@ -166,6 +169,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -166,6 +169,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 166 | model_filename = config.model_config.nemo_ctc.model; | 169 | model_filename = config.model_config.nemo_ctc.model; |
| 167 | } else if (!config.model_config.tdnn.model.empty()) { | 170 | } else if (!config.model_config.tdnn.model.empty()) { |
| 168 | model_filename = config.model_config.tdnn.model; | 171 | model_filename = config.model_config.tdnn.model; |
| 172 | + } else if (!config.model_config.zipformer_ctc.model.empty()) { | ||
| 173 | + model_filename = config.model_config.zipformer_ctc.model; | ||
| 169 | } else if (!config.model_config.whisper.encoder.empty()) { | 174 | } else if (!config.model_config.whisper.encoder.empty()) { |
| 170 | model_filename = config.model_config.whisper.encoder; | 175 | model_filename = config.model_config.whisper.encoder; |
| 171 | } else { | 176 | } else { |
| @@ -209,6 +214,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -209,6 +214,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 209 | "\n " | 214 | "\n " |
| 210 | "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" | 215 | "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" |
| 211 | "\n" | 216 | "\n" |
| 217 | + "(5) Zipformer CTC models from icefall" | ||
| 218 | + "\n " | ||
| 219 | + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" | ||
| 220 | + "zipformer/export-onnx-ctc.py" | ||
| 221 | + "\n" | ||
| 212 | "\n"); | 222 | "\n"); |
| 213 | exit(-1); | 223 | exit(-1); |
| 214 | } | 224 | } |
| @@ -223,11 +233,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -223,11 +233,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 223 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); | 233 | return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config); |
| 224 | } | 234 | } |
| 225 | 235 | ||
| 226 | - if (model_type == "EncDecCTCModelBPE") { | ||
| 227 | - return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | ||
| 228 | - } | ||
| 229 | - | ||
| 230 | - if (model_type == "tdnn") { | 236 | + if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || |
| 237 | + model_type == "zipformer2_ctc") { | ||
| 231 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); | 238 | return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config); |
| 232 | } | 239 | } |
| 233 | 240 | ||
| @@ -242,7 +249,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -242,7 +249,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 242 | " - Non-streaming Paraformer models from FunASR\n" | 249 | " - Non-streaming Paraformer models from FunASR\n" |
| 243 | " - EncDecCTCModelBPE models from NeMo\n" | 250 | " - EncDecCTCModelBPE models from NeMo\n" |
| 244 | " - Whisper models\n" | 251 | " - Whisper models\n" |
| 245 | - " - Tdnn models\n", | 252 | + " - Tdnn models\n" |
| 253 | + " - Zipformer CTC models\n", | ||
| 246 | model_type.c_str()); | 254 | model_type.c_str()); |
| 247 | 255 | ||
| 248 | exit(-1); | 256 | exit(-1); |
| @@ -17,6 +17,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | @@ -17,6 +17,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { | ||
| 17 | feat_config.Register(po); | 17 | feat_config.Register(po); |
| 18 | model_config.Register(po); | 18 | model_config.Register(po); |
| 19 | lm_config.Register(po); | 19 | lm_config.Register(po); |
| 20 | + ctc_fst_decoder_config.Register(po); | ||
| 20 | 21 | ||
| 21 | po->Register( | 22 | po->Register( |
| 22 | "decoding-method", &decoding_method, | 23 | "decoding-method", &decoding_method, |
| @@ -69,6 +70,7 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -69,6 +70,7 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 69 | os << "feat_config=" << feat_config.ToString() << ", "; | 70 | os << "feat_config=" << feat_config.ToString() << ", "; |
| 70 | os << "model_config=" << model_config.ToString() << ", "; | 71 | os << "model_config=" << model_config.ToString() << ", "; |
| 71 | os << "lm_config=" << lm_config.ToString() << ", "; | 72 | os << "lm_config=" << lm_config.ToString() << ", "; |
| 73 | + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", "; | ||
| 72 | os << "decoding_method=\"" << decoding_method << "\", "; | 74 | os << "decoding_method=\"" << decoding_method << "\", "; |
| 73 | os << "max_active_paths=" << max_active_paths << ", "; | 75 | os << "max_active_paths=" << max_active_paths << ", "; |
| 74 | os << "hotwords_file=\"" << hotwords_file << "\", "; | 76 | os << "hotwords_file=\"" << hotwords_file << "\", "; |
| @@ -14,6 +14,7 @@ | @@ -14,6 +14,7 @@ | ||
| 14 | #include "android/asset_manager_jni.h" | 14 | #include "android/asset_manager_jni.h" |
| 15 | #endif | 15 | #endif |
| 16 | 16 | ||
| 17 | +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" | ||
| 17 | #include "sherpa-onnx/csrc/offline-lm-config.h" | 18 | #include "sherpa-onnx/csrc/offline-lm-config.h" |
| 18 | #include "sherpa-onnx/csrc/offline-model-config.h" | 19 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 19 | #include "sherpa-onnx/csrc/offline-stream.h" | 20 | #include "sherpa-onnx/csrc/offline-stream.h" |
| @@ -28,6 +29,7 @@ struct OfflineRecognizerConfig { | @@ -28,6 +29,7 @@ struct OfflineRecognizerConfig { | ||
| 28 | OfflineFeatureExtractorConfig feat_config; | 29 | OfflineFeatureExtractorConfig feat_config; |
| 29 | OfflineModelConfig model_config; | 30 | OfflineModelConfig model_config; |
| 30 | OfflineLMConfig lm_config; | 31 | OfflineLMConfig lm_config; |
| 32 | + OfflineCtcFstDecoderConfig ctc_fst_decoder_config; | ||
| 31 | 33 | ||
| 32 | std::string decoding_method = "greedy_search"; | 34 | std::string decoding_method = "greedy_search"; |
| 33 | int32_t max_active_paths = 4; | 35 | int32_t max_active_paths = 4; |
| @@ -39,16 +41,16 @@ struct OfflineRecognizerConfig { | @@ -39,16 +41,16 @@ struct OfflineRecognizerConfig { | ||
| 39 | // TODO(fangjun): Implement modified_beam_search | 41 | // TODO(fangjun): Implement modified_beam_search |
| 40 | 42 | ||
| 41 | OfflineRecognizerConfig() = default; | 43 | OfflineRecognizerConfig() = default; |
| 42 | - OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, | ||
| 43 | - const OfflineModelConfig &model_config, | ||
| 44 | - const OfflineLMConfig &lm_config, | ||
| 45 | - const std::string &decoding_method, | ||
| 46 | - int32_t max_active_paths, | ||
| 47 | - const std::string &hotwords_file, | ||
| 48 | - float hotwords_score) | 44 | + OfflineRecognizerConfig( |
| 45 | + const OfflineFeatureExtractorConfig &feat_config, | ||
| 46 | + const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, | ||
| 47 | + const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, | ||
| 48 | + const std::string &decoding_method, int32_t max_active_paths, | ||
| 49 | + const std::string &hotwords_file, float hotwords_score) | ||
| 49 | : feat_config(feat_config), | 50 | : feat_config(feat_config), |
| 50 | model_config(model_config), | 51 | model_config(model_config), |
| 51 | lm_config(lm_config), | 52 | lm_config(lm_config), |
| 53 | + ctc_fst_decoder_config(ctc_fst_decoder_config), | ||
| 52 | decoding_method(decoding_method), | 54 | decoding_method(decoding_method), |
| 53 | max_active_paths(max_active_paths), | 55 | max_active_paths(max_active_paths), |
| 54 | hotwords_file(hotwords_file), | 56 | hotwords_file(hotwords_file), |
| @@ -4,6 +4,8 @@ | @@ -4,6 +4,8 @@ | ||
| 4 | 4 | ||
| 5 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" | 5 | #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" |
| 6 | 6 | ||
| 7 | +#include <utility> | ||
| 8 | + | ||
| 7 | #include "sherpa-onnx/csrc/macros.h" | 9 | #include "sherpa-onnx/csrc/macros.h" |
| 8 | #include "sherpa-onnx/csrc/onnx-utils.h" | 10 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 9 | #include "sherpa-onnx/csrc/session.h" | 11 | #include "sherpa-onnx/csrc/session.h" |
| @@ -34,7 +36,7 @@ class OfflineTdnnCtcModel::Impl { | @@ -34,7 +36,7 @@ class OfflineTdnnCtcModel::Impl { | ||
| 34 | } | 36 | } |
| 35 | #endif | 37 | #endif |
| 36 | 38 | ||
| 37 | - std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) { | 39 | + std::vector<Ort::Value> Forward(Ort::Value features) { |
| 38 | auto nnet_out = | 40 | auto nnet_out = |
| 39 | sess_->Run({}, input_names_ptr_.data(), &features, 1, | 41 | sess_->Run({}, input_names_ptr_.data(), &features, 1, |
| 40 | output_names_ptr_.data(), output_names_ptr_.size()); | 42 | output_names_ptr_.data(), output_names_ptr_.size()); |
| @@ -52,7 +54,11 @@ class OfflineTdnnCtcModel::Impl { | @@ -52,7 +54,11 @@ class OfflineTdnnCtcModel::Impl { | ||
| 52 | memory_info, out_length_vec.data(), out_length_vec.size(), | 54 | memory_info, out_length_vec.data(), out_length_vec.size(), |
| 53 | out_length_shape.data(), out_length_shape.size()); | 55 | out_length_shape.data(), out_length_shape.size()); |
| 54 | 56 | ||
| 55 | - return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)}; | 57 | + std::vector<Ort::Value> ans; |
| 58 | + ans.reserve(2); | ||
| 59 | + ans.push_back(std::move(nnet_out[0])); | ||
| 60 | + ans.push_back(Clone(Allocator(), &nnet_out_length)); | ||
| 61 | + return ans; | ||
| 56 | } | 62 | } |
| 57 | 63 | ||
| 58 | int32_t VocabSize() const { return vocab_size_; } | 64 | int32_t VocabSize() const { return vocab_size_; } |
| @@ -108,7 +114,7 @@ OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr, | @@ -108,7 +114,7 @@ OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr, | ||
| 108 | 114 | ||
| 109 | OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default; | 115 | OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default; |
| 110 | 116 | ||
| 111 | -std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward( | 117 | +std::vector<Ort::Value> OfflineTdnnCtcModel::Forward( |
| 112 | Ort::Value features, Ort::Value /*features_length*/) { | 118 | Ort::Value features, Ort::Value /*features_length*/) { |
| 113 | return impl_->Forward(std::move(features)); | 119 | return impl_->Forward(std::move(features)); |
| 114 | } | 120 | } |
| @@ -5,7 +5,6 @@ | @@ -5,7 +5,6 @@ | ||
| 5 | #define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ | 5 | #define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ |
| 6 | #include <memory> | 6 | #include <memory> |
| 7 | #include <string> | 7 | #include <string> |
| 8 | -#include <utility> | ||
| 9 | #include <vector> | 8 | #include <vector> |
| 10 | 9 | ||
| 11 | #if __ANDROID_API__ >= 9 | 10 | #if __ANDROID_API__ >= 9 |
| @@ -36,7 +35,7 @@ class OfflineTdnnCtcModel : public OfflineCtcModel { | @@ -36,7 +35,7 @@ class OfflineTdnnCtcModel : public OfflineCtcModel { | ||
| 36 | 35 | ||
| 37 | /** Run the forward method of the model. | 36 | /** Run the forward method of the model. |
| 38 | * | 37 | * |
| 39 | - * @param features A tensor of shape (N, T, C). It is changed in-place. | 38 | + * @param features A tensor of shape (N, T, C). |
| 40 | * @param features_length A 1-D tensor of shape (N,) containing number of | 39 | * @param features_length A 1-D tensor of shape (N,) containing number of |
| 41 | * valid frames in `features` before padding. | 40 | * valid frames in `features` before padding. |
| 42 | * Its dtype is int64_t. | 41 | * Its dtype is int64_t. |
| @@ -45,8 +44,8 @@ class OfflineTdnnCtcModel : public OfflineCtcModel { | @@ -45,8 +44,8 @@ class OfflineTdnnCtcModel : public OfflineCtcModel { | ||
| 45 | * - log_probs: A 3-D tensor of shape (N, T', vocab_size). | 44 | * - log_probs: A 3-D tensor of shape (N, T', vocab_size). |
| 46 | * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | 45 | * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t |
| 47 | */ | 46 | */ |
| 48 | - std::pair<Ort::Value, Ort::Value> Forward( | ||
| 49 | - Ort::Value features, Ort::Value /*features_length*/) override; | 47 | + std::vector<Ort::Value> Forward(Ort::Value features, |
| 48 | + Ort::Value /*features_length*/) override; | ||
| 50 | 49 | ||
| 51 | /** Return the vocabulary size of the model | 50 | /** Return the vocabulary size of the model |
| 52 | */ | 51 | */ |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("zipformer-ctc-model", &model, "Path to zipformer CTC model"); | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +bool OfflineZipformerCtcModelConfig::Validate() const { | ||
| 17 | + if (!FileExists(model)) { | ||
| 18 | + SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist", | ||
| 19 | + model.c_str()); | ||
| 20 | + return false; | ||
| 21 | + } | ||
| 22 | + | ||
| 23 | + return true; | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +std::string OfflineZipformerCtcModelConfig::ToString() const { | ||
| 27 | + std::ostringstream os; | ||
| 28 | + | ||
| 29 | + os << "OfflineZipformerCtcModelConfig("; | ||
| 30 | + os << "model=\"" << model << "\")"; | ||
| 31 | + | ||
| 32 | + return os.str(); | ||
| 33 | +} | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +// for | ||
| 14 | +// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py | ||
| 15 | +struct OfflineZipformerCtcModelConfig { | ||
| 16 | + std::string model; | ||
| 17 | + | ||
| 18 | + OfflineZipformerCtcModelConfig() = default; | ||
| 19 | + | ||
| 20 | + explicit OfflineZipformerCtcModelConfig(const std::string &model) | ||
| 21 | + : model(model) {} | ||
| 22 | + | ||
| 23 | + void Register(ParseOptions *po); | ||
| 24 | + | ||
| 25 | + bool Validate() const; | ||
| 26 | + | ||
| 27 | + std::string ToString() const; | ||
| 28 | +}; | ||
| 29 | + | ||
| 30 | +} // namespace sherpa_onnx | ||
| 31 | + | ||
| 32 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-ctc-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 8 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/session.h" | ||
| 10 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 11 | +#include "sherpa-onnx/csrc/transpose.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OfflineZipformerCtcModel::Impl { | ||
| 16 | + public: | ||
| 17 | + explicit Impl(const OfflineModelConfig &config) | ||
| 18 | + : config_(config), | ||
| 19 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 20 | + sess_opts_(GetSessionOptions(config)), | ||
| 21 | + allocator_{} { | ||
| 22 | + auto buf = ReadFile(config_.zipformer_ctc.model); | ||
| 23 | + Init(buf.data(), buf.size()); | ||
| 24 | + } | ||
| 25 | + | ||
| 26 | +#if __ANDROID_API__ >= 9 | ||
| 27 | + Impl(AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 28 | + : config_(config), | ||
| 29 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 30 | + sess_opts_(GetSessionOptions(config)), | ||
| 31 | + allocator_{} { | ||
| 32 | + auto buf = ReadFile(mgr, config_.zipformer_ctc.model); | ||
| 33 | + Init(buf.data(), buf.size()); | ||
| 34 | + } | ||
| 35 | +#endif | ||
| 36 | + | ||
| 37 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 38 | + Ort::Value features_length) { | ||
| 39 | + std::array<Ort::Value, 2> inputs = {std::move(features), | ||
| 40 | + std::move(features_length)}; | ||
| 41 | + | ||
| 42 | + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 43 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 47 | + int32_t SubsamplingFactor() const { return 4; } | ||
| 48 | + | ||
| 49 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 50 | + | ||
| 51 | + private: | ||
| 52 | + void Init(void *model_data, size_t model_data_length) { | ||
| 53 | + sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length, | ||
| 54 | + sess_opts_); | ||
| 55 | + | ||
| 56 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 57 | + | ||
| 58 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 59 | + | ||
| 60 | + // get meta data | ||
| 61 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 62 | + if (config_.debug) { | ||
| 63 | + std::ostringstream os; | ||
| 64 | + PrintModelMetadata(os, meta_data); | ||
| 65 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 66 | + } | ||
| 67 | + | ||
| 68 | + // get vocab size from the output[0].shape, which is (N, T, vocab_size) | ||
| 69 | + vocab_size_ = | ||
| 70 | + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2]; | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + private: | ||
| 74 | + OfflineModelConfig config_; | ||
| 75 | + Ort::Env env_; | ||
| 76 | + Ort::SessionOptions sess_opts_; | ||
| 77 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 78 | + | ||
| 79 | + std::unique_ptr<Ort::Session> sess_; | ||
| 80 | + | ||
| 81 | + std::vector<std::string> input_names_; | ||
| 82 | + std::vector<const char *> input_names_ptr_; | ||
| 83 | + | ||
| 84 | + std::vector<std::string> output_names_; | ||
| 85 | + std::vector<const char *> output_names_ptr_; | ||
| 86 | + | ||
| 87 | + int32_t vocab_size_ = 0; | ||
| 88 | +}; | ||
| 89 | + | ||
| 90 | +OfflineZipformerCtcModel::OfflineZipformerCtcModel( | ||
| 91 | + const OfflineModelConfig &config) | ||
| 92 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 93 | + | ||
| 94 | +#if __ANDROID_API__ >= 9 | ||
| 95 | +OfflineZipformerCtcModel::OfflineZipformerCtcModel( | ||
| 96 | + AAssetManager *mgr, const OfflineModelConfig &config) | ||
| 97 | + : impl_(std::make_unique<Impl>(mgr, config)) {} | ||
| 98 | +#endif | ||
| 99 | + | ||
| 100 | +OfflineZipformerCtcModel::~OfflineZipformerCtcModel() = default; | ||
| 101 | + | ||
| 102 | +std::vector<Ort::Value> OfflineZipformerCtcModel::Forward( | ||
| 103 | + Ort::Value features, Ort::Value features_length) { | ||
| 104 | + return impl_->Forward(std::move(features), std::move(features_length)); | ||
| 105 | +} | ||
| 106 | + | ||
| 107 | +int32_t OfflineZipformerCtcModel::VocabSize() const { | ||
| 108 | + return impl_->VocabSize(); | ||
| 109 | +} | ||
| 110 | + | ||
| 111 | +OrtAllocator *OfflineZipformerCtcModel::Allocator() const { | ||
| 112 | + return impl_->Allocator(); | ||
| 113 | +} | ||
| 114 | + | ||
| 115 | +int32_t OfflineZipformerCtcModel::SubsamplingFactor() const { | ||
| 116 | + return impl_->SubsamplingFactor(); | ||
| 117 | +} | ||
| 118 | + | ||
| 119 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-zipformer-ctc-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ | ||
| 6 | +#include <memory> | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#if __ANDROID_API__ >= 9 | ||
| 12 | +#include "android/asset_manager.h" | ||
| 13 | +#include "android/asset_manager_jni.h" | ||
| 14 | +#endif | ||
| 15 | + | ||
| 16 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 17 | +#include "sherpa-onnx/csrc/offline-ctc-model.h" | ||
| 18 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 19 | + | ||
| 20 | +namespace sherpa_onnx { | ||
| 21 | + | ||
| 22 | +/** This class implements the zipformer CTC model of the librispeech recipe | ||
| 23 | + * from icefall. | ||
| 24 | + * | ||
| 25 | + * See | ||
| 26 | + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py | ||
| 27 | + */ | ||
| 28 | +class OfflineZipformerCtcModel : public OfflineCtcModel { | ||
| 29 | + public: | ||
| 30 | + explicit OfflineZipformerCtcModel(const OfflineModelConfig &config); | ||
| 31 | + | ||
| 32 | +#if __ANDROID_API__ >= 9 | ||
| 33 | + OfflineZipformerCtcModel(AAssetManager *mgr, | ||
| 34 | + const OfflineModelConfig &config); | ||
| 35 | +#endif | ||
| 36 | + | ||
| 37 | + ~OfflineZipformerCtcModel() override; | ||
| 38 | + | ||
| 39 | + /** Run the forward method of the model. | ||
| 40 | + * | ||
| 41 | + * @param features A tensor of shape (N, T, C). | ||
| 42 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 43 | + * valid frames in `features` before padding. | ||
| 44 | + * Its dtype is int64_t. | ||
| 45 | + * | ||
| 46 | + * @return Return a vector containing: | ||
| 47 | + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). | ||
| 48 | + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t | ||
| 49 | + */ | ||
| 50 | + std::vector<Ort::Value> Forward(Ort::Value features, | ||
| 51 | + Ort::Value features_length) override; | ||
| 52 | + | ||
| 53 | + /** Return the vocabulary size of the model | ||
| 54 | + */ | ||
| 55 | + int32_t VocabSize() const override; | ||
| 56 | + | ||
| 57 | + /** Return an allocator for allocating memory | ||
| 58 | + */ | ||
| 59 | + OrtAllocator *Allocator() const override; | ||
| 60 | + | ||
| 61 | + int32_t SubsamplingFactor() const override; | ||
| 62 | + | ||
| 63 | + private: | ||
| 64 | + class Impl; | ||
| 65 | + std::unique_ptr<Impl> impl_; | ||
| 66 | +}; | ||
| 67 | + | ||
| 68 | +} // namespace sherpa_onnx | ||
| 69 | + | ||
| 70 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ |
| @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx | @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 5 | display.cc | 5 | display.cc |
| 6 | endpoint.cc | 6 | endpoint.cc |
| 7 | features.cc | 7 | features.cc |
| 8 | + offline-ctc-fst-decoder-config.cc | ||
| 8 | offline-lm-config.cc | 9 | offline-lm-config.cc |
| 9 | offline-model-config.cc | 10 | offline-model-config.cc |
| 10 | offline-nemo-enc-dec-ctc-model-config.cc | 11 | offline-nemo-enc-dec-ctc-model-config.cc |
| @@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx | @@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 14 | offline-tdnn-model-config.cc | 15 | offline-tdnn-model-config.cc |
| 15 | offline-transducer-model-config.cc | 16 | offline-transducer-model-config.cc |
| 16 | offline-whisper-model-config.cc | 17 | offline-whisper-model-config.cc |
| 18 | + offline-zipformer-ctc-model-config.cc | ||
| 17 | online-lm-config.cc | 19 | online-lm-config.cc |
| 18 | online-model-config.cc | 20 | online-model-config.cc |
| 19 | online-paraformer-model-config.cc | 21 | online-paraformer-model-config.cc |
| 1 | +// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindOfflineCtcFstDecoderConfig(py::module *m) { | ||
| 14 | + using PyClass = OfflineCtcFstDecoderConfig; | ||
| 15 | + py::class_<PyClass>(*m, "OfflineCtcFstDecoderConfig") | ||
| 16 | + .def(py::init<const std::string &, int32_t>(), py::arg("graph") = "", | ||
| 17 | + py::arg("max_active") = 3000) | ||
| 18 | + .def_readwrite("graph", &PyClass::graph) | ||
| 19 | + .def_readwrite("max_active", &PyClass::max_active) | ||
| 20 | + .def("__str__", &PyClass::ToString); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +} // namespace sherpa_onnx |
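A minimal usage sketch for the binding above, assuming the compiled `_sherpa_onnx` extension is importable; the graph path is a placeholder, not a file added by this commit:

from _sherpa_onnx import OfflineCtcFstDecoderConfig

# Both fields come from the binding above; max_active defaults to 3000.
fst_decoder_config = OfflineCtcFstDecoderConfig(
    graph="/path/to/HLG.fst",  # an H.fst, HL.fst, or HLG.fst decoding graph
    max_active=3000,
)
print(fst_decoder_config)  # __str__ is bound to OfflineCtcFstDecoderConfig::ToString()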
| 1 | +// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineCtcFstDecoderConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ |
| @@ -13,6 +13,7 @@ | @@ -13,6 +13,7 @@ | ||
| 13 | #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" |
| 14 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" | 14 | #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" |
| 15 | #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" | 15 | #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" |
| 16 | +#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" | ||
| 16 | 17 | ||
| 17 | namespace sherpa_onnx { | 18 | namespace sherpa_onnx { |
| 18 | 19 | ||
| @@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 22 | PybindOfflineNemoEncDecCtcModelConfig(m); | 23 | PybindOfflineNemoEncDecCtcModelConfig(m); |
| 23 | PybindOfflineWhisperModelConfig(m); | 24 | PybindOfflineWhisperModelConfig(m); |
| 24 | PybindOfflineTdnnModelConfig(m); | 25 | PybindOfflineTdnnModelConfig(m); |
| 26 | + PybindOfflineZipformerCtcModelConfig(m); | ||
| 25 | 27 | ||
| 26 | using PyClass = OfflineModelConfig; | 28 | using PyClass = OfflineModelConfig; |
| 27 | py::class_<PyClass>(*m, "OfflineModelConfig") | 29 | py::class_<PyClass>(*m, "OfflineModelConfig") |
| @@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) { | @@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) { | ||
| 29 | const OfflineParaformerModelConfig &, | 31 | const OfflineParaformerModelConfig &, |
| 30 | const OfflineNemoEncDecCtcModelConfig &, | 32 | const OfflineNemoEncDecCtcModelConfig &, |
| 31 | const OfflineWhisperModelConfig &, | 33 | const OfflineWhisperModelConfig &, |
| 32 | - const OfflineTdnnModelConfig &, const std::string &, | 34 | + const OfflineTdnnModelConfig &, |
| 35 | + const OfflineZipformerCtcModelConfig &, const std::string &, | ||
| 33 | int32_t, bool, const std::string &, const std::string &>(), | 36 | int32_t, bool, const std::string &, const std::string &>(), |
| 34 | py::arg("transducer") = OfflineTransducerModelConfig(), | 37 | py::arg("transducer") = OfflineTransducerModelConfig(), |
| 35 | py::arg("paraformer") = OfflineParaformerModelConfig(), | 38 | py::arg("paraformer") = OfflineParaformerModelConfig(), |
| 36 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), | 39 | py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), |
| 37 | py::arg("whisper") = OfflineWhisperModelConfig(), | 40 | py::arg("whisper") = OfflineWhisperModelConfig(), |
| 38 | - py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"), | ||
| 39 | - py::arg("num_threads"), py::arg("debug") = false, | 41 | + py::arg("tdnn") = OfflineTdnnModelConfig(), |
| 42 | + py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), | ||
| 43 | + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, | ||
| 40 | py::arg("provider") = "cpu", py::arg("model_type") = "") | 44 | py::arg("provider") = "cpu", py::arg("model_type") = "") |
| 41 | .def_readwrite("transducer", &PyClass::transducer) | 45 | .def_readwrite("transducer", &PyClass::transducer) |
| 42 | .def_readwrite("paraformer", &PyClass::paraformer) | 46 | .def_readwrite("paraformer", &PyClass::paraformer) |
| 43 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) | 47 | .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) |
| 44 | .def_readwrite("whisper", &PyClass::whisper) | 48 | .def_readwrite("whisper", &PyClass::whisper) |
| 45 | .def_readwrite("tdnn", &PyClass::tdnn) | 49 | .def_readwrite("tdnn", &PyClass::tdnn) |
| 50 | + .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) | ||
| 46 | .def_readwrite("tokens", &PyClass::tokens) | 51 | .def_readwrite("tokens", &PyClass::tokens) |
| 47 | .def_readwrite("num_threads", &PyClass::num_threads) | 52 | .def_readwrite("num_threads", &PyClass::num_threads) |
| 48 | .def_readwrite("debug", &PyClass::debug) | 53 | .def_readwrite("debug", &PyClass::debug) |
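A hedged sketch of filling in the extended `OfflineModelConfig` from Python; the file names are placeholders, and only keyword arguments shown in the binding above are used (`tokens` and `num_threads` have no defaults, so they must be supplied):

from _sherpa_onnx import OfflineModelConfig, OfflineZipformerCtcModelConfig

model_config = OfflineModelConfig(
    zipformer_ctc=OfflineZipformerCtcModelConfig(model="model.onnx"),  # placeholder path
    tokens="tokens.txt",  # required: no default in the binding
    num_threads=2,        # required: no default in the binding
    debug=False,
    provider="cpu",
)
# __str__ of the nested config is bound to ToString(), so this prints
# OfflineZipformerCtcModelConfig(model="model.onnx")
print(model_config.zipformer_ctc)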
| @@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") | 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") |
| 17 | .def(py::init<const OfflineFeatureExtractorConfig &, | 17 | .def(py::init<const OfflineFeatureExtractorConfig &, |
| 18 | const OfflineModelConfig &, const OfflineLMConfig &, | 18 | const OfflineModelConfig &, const OfflineLMConfig &, |
| 19 | - const std::string &, int32_t, const std::string &, float>(), | 19 | + const OfflineCtcFstDecoderConfig &, const std::string &, |
| 20 | + int32_t, const std::string &, float>(), | ||
| 20 | py::arg("feat_config"), py::arg("model_config"), | 21 | py::arg("feat_config"), py::arg("model_config"), |
| 21 | py::arg("lm_config") = OfflineLMConfig(), | 22 | py::arg("lm_config") = OfflineLMConfig(), |
| 23 | + py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), | ||
| 22 | py::arg("decoding_method") = "greedy_search", | 24 | py::arg("decoding_method") = "greedy_search", |
| 23 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", | 25 | py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", |
| 24 | py::arg("hotwords_score") = 1.5) | 26 | py::arg("hotwords_score") = 1.5) |
| 25 | .def_readwrite("feat_config", &PyClass::feat_config) | 27 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 26 | .def_readwrite("model_config", &PyClass::model_config) | 28 | .def_readwrite("model_config", &PyClass::model_config) |
| 27 | .def_readwrite("lm_config", &PyClass::lm_config) | 29 | .def_readwrite("lm_config", &PyClass::lm_config) |
| 30 | + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config) | ||
| 28 | .def_readwrite("decoding_method", &PyClass::decoding_method) | 31 | .def_readwrite("decoding_method", &PyClass::decoding_method) |
| 29 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) | 32 | .def_readwrite("max_active_paths", &PyClass::max_active_paths) |
| 30 | .def_readwrite("hotwords_file", &PyClass::hotwords_file) | 33 | .def_readwrite("hotwords_file", &PyClass::hotwords_file) |
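Tying the pieces together, a sketch of an `OfflineRecognizerConfig` that carries an FST decoding graph for CTC decoding; paths are placeholders, and the no-argument `OfflineFeatureExtractorConfig()` construction is an assumption (its defaults are not part of this diff):

from _sherpa_onnx import (
    OfflineCtcFstDecoderConfig,
    OfflineFeatureExtractorConfig,
    OfflineModelConfig,
    OfflineRecognizerConfig,
    OfflineZipformerCtcModelConfig,
)

config = OfflineRecognizerConfig(
    feat_config=OfflineFeatureExtractorConfig(),  # assumed default construction
    model_config=OfflineModelConfig(
        zipformer_ctc=OfflineZipformerCtcModelConfig(model="model.int8.onnx"),
        tokens="tokens.txt",
        num_threads=1,
    ),
    ctc_fst_decoder_config=OfflineCtcFstDecoderConfig(graph="HLG.fst"),
)
# The new field is exposed via def_readwrite above.
print(config.ctc_fst_decoder_config)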
| 1 | +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindOfflineZipformerCtcModelConfig(py::module *m) { | ||
| 14 | + using PyClass = OfflineZipformerCtcModelConfig; | ||
| 15 | + py::class_<PyClass>(*m, "OfflineZipformerCtcModelConfig") | ||
| 16 | + .def(py::init<>()) | ||
| 17 | + .def(py::init<const std::string &>(), py::arg("model")) | ||
| 18 | + .def_readwrite("model", &PyClass::model) | ||
| 19 | + .def("__str__", &PyClass::ToString); | ||
| 20 | +} | ||
| 21 | + | ||
| 22 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineZipformerCtcModelConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | #include "sherpa-onnx/python/csrc/display.h" | 8 | #include "sherpa-onnx/python/csrc/display.h" |
| 9 | #include "sherpa-onnx/python/csrc/endpoint.h" | 9 | #include "sherpa-onnx/python/csrc/endpoint.h" |
| 10 | #include "sherpa-onnx/python/csrc/features.h" | 10 | #include "sherpa-onnx/python/csrc/features.h" |
| 11 | +#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" | ||
| 11 | #include "sherpa-onnx/python/csrc/offline-lm-config.h" | 12 | #include "sherpa-onnx/python/csrc/offline-lm-config.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-model-config.h" | 13 | #include "sherpa-onnx/python/csrc/offline-model-config.h" |
| 13 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" | 14 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" |
| @@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 37 | PybindOfflineStream(&m); | 38 | PybindOfflineStream(&m); |
| 38 | PybindOfflineLMConfig(&m); | 39 | PybindOfflineLMConfig(&m); |
| 39 | PybindOfflineModelConfig(&m); | 40 | PybindOfflineModelConfig(&m); |
| 41 | + PybindOfflineCtcFstDecoderConfig(&m); | ||
| 40 | PybindOfflineRecognizer(&m); | 42 | PybindOfflineRecognizer(&m); |
| 41 | 43 | ||
| 42 | PybindVadModelConfig(&m); | 44 | PybindVadModelConfig(&m); |
| @@ -4,12 +4,14 @@ from pathlib import Path | @@ -4,12 +4,14 @@ from pathlib import Path | ||
| 4 | from typing import List, Optional | 4 | from typing import List, Optional |
| 5 | 5 | ||
| 6 | from _sherpa_onnx import ( | 6 | from _sherpa_onnx import ( |
| 7 | + OfflineCtcFstDecoderConfig, | ||
| 7 | OfflineFeatureExtractorConfig, | 8 | OfflineFeatureExtractorConfig, |
| 8 | OfflineModelConfig, | 9 | OfflineModelConfig, |
| 9 | OfflineNemoEncDecCtcModelConfig, | 10 | OfflineNemoEncDecCtcModelConfig, |
| 10 | OfflineParaformerModelConfig, | 11 | OfflineParaformerModelConfig, |
| 11 | OfflineTdnnModelConfig, | 12 | OfflineTdnnModelConfig, |
| 12 | OfflineWhisperModelConfig, | 13 | OfflineWhisperModelConfig, |
| 14 | + OfflineZipformerCtcModelConfig, | ||
| 13 | ) | 15 | ) |
| 14 | from _sherpa_onnx import OfflineRecognizer as _Recognizer | 16 | from _sherpa_onnx import OfflineRecognizer as _Recognizer |
| 15 | from _sherpa_onnx import ( | 17 | from _sherpa_onnx import ( |