Fangjun Kuang
Committed by GitHub

Add CTC HLG decoding using OpenFst (#349)

Showing 39 changed files with 964 additions and 56 deletions
... ... @@ -89,3 +89,48 @@ time $EXE \
$repo/test_wavs/8k.wav
rm -rf $repo
log "------------------------------------------------------------"
log "Run Librispeech zipformer CTC H/HL/HLG decoding (English) "
log "------------------------------------------------------------"
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
log "Start testing ${repo_url}"
repo=$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
pushd $repo
git lfs pull --include "*.onnx"
git lfs pull --include "*.fst"
ls -lh
popd
graphs=(
$repo/H.fst
$repo/HL.fst
$repo/HLG.fst
)
for graph in ${graphs[@]}; do
log "test float32 models with $graph"
time $EXE \
--model-type=zipformer2_ctc \
--ctc.graph=$graph \
--zipformer-ctc-model=$repo/model.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
log "test int8 models with $graph"
time $EXE \
--model-type=zipformer2_ctc \
--ctc.graph=$graph \
--zipformer-ctc-model=$repo/model.int8.onnx \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/2.wav
done
rm -rf $repo
... ...
... ... @@ -18,7 +18,7 @@ permissions:
jobs:
python_online_websocket_server:
runs-on: ${{ matrix.os }}
name: ${{ matrix.os }} ${{ matrix.python-version }}
name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
strategy:
fail-fast: false
matrix:
... ...
... ... @@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
endif()
include(kaldi-native-fbank)
include(kaldi-decoder)
include(onnxruntime)
if(SHERPA_ONNX_ENABLE_PORTAUDIO)
... ...
function(download_eigen)
include(FetchContent)
set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz")
set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz")
set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72")
# If you don't have access to the Internet,
# please pre-download eigen
set(possible_file_locations
$ENV{HOME}/Downloads/eigen-3.4.0.tar.gz
${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz
${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz
/tmp/eigen-3.4.0.tar.gz
/star-fj/fangjun/download/github/eigen-3.4.0.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(eigen_URL "${f}")
file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL)
message(STATUS "Found local downloaded eigen: ${eigen_URL}")
set(eigen_URL2)
break()
endif()
endforeach()
set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE)
FetchContent_Declare(eigen
URL ${eigen_URL}
URL_HASH ${eigen_HASH}
)
FetchContent_GetProperties(eigen)
if(NOT eigen_POPULATED)
message(STATUS "Downloading eigen from ${eigen_URL}")
FetchContent_Populate(eigen)
endif()
message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}")
message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}")
add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()
download_eigen()
... ...
function(download_kaldi_decoder)
include(FetchContent)
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.3.tar.gz")
set(kaldi_decoder_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.3.tar.gz")
set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
# If you don't have access to the Internet,
# please pre-download kaldi-decoder
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-decoder-0.2.3.tar.gz
${PROJECT_SOURCE_DIR}/kaldi-decoder-0.2.3.tar.gz
${PROJECT_BINARY_DIR}/kaldi-decoder-0.2.3.tar.gz
/tmp/kaldi-decoder-0.2.3.tar.gz
/star-fj/fangjun/download/github/kaldi-decoder-0.2.3.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(kaldi_decoder_URL "${f}")
file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL)
message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}")
set(kaldi_decoder_URL2 )
break()
endif()
endforeach()
FetchContent_Declare(kaldi_decoder
URL
${kaldi_decoder_URL}
${kaldi_decoder_URL2}
URL_HASH ${kaldi_decoder_HASH}
)
FetchContent_GetProperties(kaldi_decoder)
if(NOT kaldi_decoder_POPULATED)
message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}")
FetchContent_Populate(kaldi_decoder)
endif()
message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}")
message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}")
include_directories(${kaldi_decoder_SOURCE_DIR})
add_subdirectory(${kaldi_decoder_SOURCE_DIR} ${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL)
target_include_directories(kaldi-decoder-core
INTERFACE
${kaldi_decoder_SOURCE_DIR}/
)
if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
install(TARGETS
kaldi-decoder-core
kaldifst_core
fst
DESTINATION ..)
else()
install(TARGETS
kaldi-decoder-core
kaldifst_core
fst
DESTINATION lib)
endif()
if(WIN32 AND BUILD_SHARED_LIBS)
install(TARGETS
kaldi-decoder-core
kaldifst_core
fst
DESTINATION bin)
endif()
endfunction()
download_kaldi_decoder()
... ...
function(download_kaldifst)
include(FetchContent)
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz")
set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2")
# If you don't have access to the Internet,
# please pre-download kaldifst
set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz
/tmp/kaldifst-1.7.6.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz
)
foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(kaldifst_URL "${f}")
file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL)
message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}")
set(kaldifst_URL2)
break()
endif()
endforeach()
set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
FetchContent_Declare(kaldifst
URL ${kaldifst_URL}
URL_HASH ${kaldifst_HASH}
)
FetchContent_GetProperties(kaldifst)
if(NOT kaldifst_POPULATED)
message(STATUS "Downloading kaldifst from ${kaldifst_URL}")
FetchContent_Populate(kaldifst)
endif()
message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}")
message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}")
list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake)
add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL)
target_include_directories(kaldifst_core
PUBLIC
${kaldifst_SOURCE_DIR}/
)
target_include_directories(fst
PUBLIC
${openfst_SOURCE_DIR}/src/include
)
set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-onnx-kaldifst-core")
set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst")
endfunction()
download_kaldifst()
... ...
... ... @@ -13,4 +13,4 @@ Cflags: -I"${includedir}"
# Note: -lcargs is required only for the following file
# https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c
# We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c
Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
... ...
... ... @@ -9,6 +9,9 @@
sherpa-onnx-portaudio_static.lib;
sherpa-onnx-c-api.lib;
sherpa-onnx-core.lib;
kaldi-decoder-core.lib;
sherpa-onnx-kaldifst-core.lib;
sherpa-onnx-fst.lib;
kaldi-native-fbank-core.lib;
absl_base.lib;
absl_city.lib;
... ...
... ... @@ -9,6 +9,9 @@
sherpa-onnx-portaudio_static.lib;
sherpa-onnx-c-api.lib;
sherpa-onnx-core.lib;
kaldi-decoder-core.lib;
sherpa-onnx-kaldifst-core.lib;
sherpa-onnx-fst.lib;
kaldi-native-fbank-core.lib;
absl_base.lib;
absl_city.lib;
... ...
... ... @@ -19,6 +19,8 @@ set(sources
features.cc
file-utils.cc
hypothesis.cc
offline-ctc-fst-decoder-config.cc
offline-ctc-fst-decoder.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-lm-config.cc
... ... @@ -42,6 +44,8 @@ set(sources
offline-whisper-greedy-search-decoder.cc
offline-whisper-model-config.cc
offline-whisper-model.cc
offline-zipformer-ctc-model-config.cc
offline-zipformer-ctc-model.cc
online-conformer-transducer-model.cc
online-lm-config.cc
online-lm.cc
... ... @@ -97,6 +101,8 @@ endif()
target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core)
target_link_libraries(sherpa-onnx-core kaldi-decoder-core)
if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm)
target_link_libraries(sherpa-onnx-core onnxruntime)
else()
... ...
// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include <sstream>
#include <string>
namespace sherpa_onnx {
std::string OfflineCtcFstDecoderConfig::ToString() const {
std::ostringstream os;
os << "OfflineCtcFstDecoderConfig(";
os << "graph=\"" << graph << "\", ";
os << "max_active=" << max_active << ")";
return os.str();
}
void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
std::string prefix = "ctc";
ParseOptions p(prefix, po);
p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
p.Register("max-active", &max_active,
"Decoder max active states. Larger->slower; more accurate");
}
} // namespace sherpa_onnx
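For reference, here is a minimal sketch (not part of this commit) of how the "ctc" prefix used in Register() above surfaces on the command line: registering the config through a Kaldi-style ParseOptions yields the --ctc.graph and --ctc.max-active flags that the CI script at the top of this commit passes to the test binaries. The standalone main() below is illustrative and assumes the ParseOptions::Read(argc, argv) interface.

// Illustrative sketch: shows how OfflineCtcFstDecoderConfig::Register()
// exposes --ctc.graph and --ctc.max-active via ParseOptions.
#include <cstdio>

#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::ParseOptions po("Demo of the CTC FST decoder options");

  sherpa_onnx::OfflineCtcFstDecoderConfig config;
  config.Register(&po);  // registers --ctc.graph and --ctc.max-active

  po.Read(argc, argv);  // e.g. ./demo --ctc.graph=HLG.fst --ctc.max-active=3000
  fprintf(stderr, "%s\n", config.ToString().c_str());
  return 0;
}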
... ...
// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineCtcFstDecoderConfig {
// Path to H.fst, HL.fst or HLG.fst
std::string graph;
int32_t max_active = 3000;
OfflineCtcFstDecoderConfig() = default;
OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
: graph(graph), max_active(max_active) {}
std::string ToString() const;
void Register(ParseOptions *po);
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
#include <cassert>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
#include "kaldi-decoder/csrc/decodable-ctc.h"
#include "kaldi-decoder/csrc/eigen.h"
#include "kaldi-decoder/csrc/faster-decoder.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// This function is copied from kaldi.
//
// @param filename Path to a StdVectorFst or StdConstFst graph
// @return The caller should free the returned pointer using `delete` to
// avoid a memory leak.
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
// read decoding network FST
std::ifstream is(filename, std::ios::binary);
if (!is.good()) {
SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str());
return nullptr;
}
fst::FstHeader hdr;
if (!hdr.Read(is, "<unknown>")) {
SHERPA_ONNX_LOGE("Reading FST: error reading FST header.");
return nullptr;
}
if (hdr.ArcType() != fst::StdArc::Type()) {
SHERPA_ONNX_LOGE("FST with arc type %s not supported",
hdr.ArcType().c_str());
return nullptr;
}
fst::FstReadOptions ropts("<unspecified>", &hdr);
fst::Fst<fst::StdArc> *decode_fst = nullptr;
if (hdr.FstType() == "vector") {
decode_fst = fst::VectorFst<fst::StdArc>::Read(is, ropts);
} else if (hdr.FstType() == "const") {
decode_fst = fst::ConstFst<fst::StdArc>::Read(is, ropts);
} else {
SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s",
hdr.FstType().c_str());
}
if (decode_fst == nullptr) { // fst code will warn.
SHERPA_ONNX_LOGE("Error reading FST (after reading header).");
return nullptr;
} else {
return decode_fst;
}
}
/**
* @param decoder
* @param p Pointer to a 2-d array of shape (num_frames, vocab_size)
* @param num_frames Number of rows in the 2-d array.
* @param vocab_size Number of columns in the 2-d array.
* @return Return the decoded result.
*/
static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
const float *p, int32_t num_frames,
int32_t vocab_size) {
OfflineCtcDecoderResult r;
kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size);
decoder->Decode(&decodable);
if (!decoder->ReachedFinal()) {
SHERPA_ONNX_LOGE("Not reached final!");
return r;
}
fst::VectorFst<fst::LatticeArc> decoded; // linear FST.
decoder->GetBestPath(&decoded);
if (decoded.NumStates() == 0) {
SHERPA_ONNX_LOGE("Empty best path!");
return r;
}
auto cur_state = decoded.Start();
int32_t blank_id = 0;
for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) {
fst::ArcIterator<fst::Fst<fst::LatticeArc>> iter(decoded, cur_state);
const auto &arc = iter.Value();
cur_state = arc.nextstate;
if (arc.ilabel == prev) {
continue;
}
// 0 is epsilon here
if (arc.ilabel == 0 || arc.ilabel == blank_id + 1) {
prev = arc.ilabel;
continue;
}
// -1 here since the input labels are incremented during graph
// construction
r.tokens.push_back(arc.ilabel - 1);
r.timestamps.push_back(t);
prev = arc.ilabel;
}
return r;
}
OfflineCtcFstDecoder::OfflineCtcFstDecoder(
const OfflineCtcFstDecoderConfig &config)
: config_(config), fst_(ReadGraph(config_.graph)) {}
std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode(
Ort::Value log_probs, Ort::Value log_probs_length) {
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
assert(static_cast<int32_t>(shape.size()) == 3);
int32_t batch_size = shape[0];
int32_t T = shape[1];
int32_t vocab_size = shape[2];
std::vector<int64_t> length_shape =
log_probs_length.GetTensorTypeAndShapeInfo().GetShape();
assert(static_cast<int32_t>(length_shape.size()) == 1);
assert(shape[0] == length_shape[0]);
kaldi_decoder::FasterDecoderOptions opts;
opts.max_active = config_.max_active;
kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts);
const float *start = log_probs.GetTensorData<float>();
std::vector<OfflineCtcDecoderResult> ans;
ans.reserve(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
const float *p = start + i * T * vocab_size;
int32_t num_frames = log_probs_length.GetTensorData<int64_t>()[i];
auto r = DecodeOne(&faster_decoder, p, num_frames, vocab_size);
ans.push_back(std::move(r));
}
return ans;
}
} // namespace sherpa_onnx
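To spell out the label bookkeeping inside DecodeOne(): on the best path, ilabel 0 is epsilon, ilabel blank_id + 1 (i.e. 1) is the CTC blank, and any other ilabel maps to token ilabel - 1, since the input labels are shifted by one when H/HL/HLG is built. The helper below is a hypothetical, self-contained restatement of that collapse rule, not code from this commit.

// Hypothetical helper mirroring the loop in DecodeOne(): drop repeats,
// epsilons (0) and blanks (blank_id + 1), then undo the +1 label shift.
#include <cstdint>
#include <vector>

std::vector<int32_t> CollapseCtcPath(const std::vector<int32_t> &ilabels,
                                     int32_t blank_id = 0) {
  std::vector<int32_t> tokens;
  int32_t prev = -1;
  for (int32_t ilabel : ilabels) {
    if (ilabel != prev && ilabel != 0 && ilabel != blank_id + 1) {
      tokens.push_back(ilabel - 1);  // undo the shift applied at graph-build time
    }
    prev = ilabel;
  }
  return tokens;
}

// Example: ilabels {0, 2, 2, 1, 2, 0, 3} collapse to tokens {1, 1, 2}.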
... ...
// sherpa-onnx/csrc/offline-ctc-fst-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
#include <memory>
#include <vector>
#include "fst/fst.h"
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
class OfflineCtcFstDecoder : public OfflineCtcDecoder {
public:
explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config);
std::vector<OfflineCtcDecoderResult> Decode(
Ort::Value log_probs, Ort::Value log_probs_length) override;
private:
OfflineCtcFstDecoderConfig config_;
std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
... ...
... ... @@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace {
... ... @@ -19,6 +20,7 @@ namespace {
enum class ModelType {
kEncDecCTCModelBPE,
kTdnn,
kZipformerCtc,
kUnkown,
};
... ... @@ -59,6 +61,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kEncDecCTCModelBPE;
} else if (model_type.get() == std::string("tdnn")) {
return ModelType::kTdnn;
} else if (model_type.get() == std::string("zipformer2_ctc")) {
return ModelType::kZipformerCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
... ... @@ -74,6 +78,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -92,6 +98,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(config);
break;
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ... @@ -111,6 +120,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
filename = config.nemo_ctc.model;
} else if (!config.tdnn.model.empty()) {
filename = config.tdnn.model;
} else if (!config.zipformer_ctc.model.empty()) {
filename = config.zipformer_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
... ... @@ -129,6 +140,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kTdnn:
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
break;
case ModelType::kZipformerCtc:
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
... ...
... ... @@ -6,7 +6,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ... @@ -32,17 +32,17 @@ class OfflineCtcModel {
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
*
* @return Return a pair containing:
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
*/
virtual std::pair<Ort::Value, Ort::Value> Forward(
Ort::Value features, Ort::Value features_length) = 0;
virtual std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) = 0;
/** Return the vocabulary size of the model
*/
... ...
... ... @@ -16,6 +16,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
nemo_ctc.Register(po);
whisper.Register(po);
tdnn.Register(po);
zipformer_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -31,7 +32,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
"tdnn."
"tdnn, zipformer2_ctc"
"All other values lead to loading the model twice.");
}
... ... @@ -62,6 +63,10 @@ bool OfflineModelConfig::Validate() const {
return tdnn.Validate();
}
if (!zipformer_ctc.model.empty()) {
return zipformer_ctc.Validate();
}
return transducer.Validate();
}
... ... @@ -74,6 +79,7 @@ std::string OfflineModelConfig::ToString() const {
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "whisper=" << whisper.ToString() << ", ";
os << "tdnn=" << tdnn.ToString() << ", ";
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
namespace sherpa_onnx {
... ... @@ -20,6 +21,7 @@ struct OfflineModelConfig {
OfflineNemoEncDecCtcModelConfig nemo_ctc;
OfflineWhisperModelConfig whisper;
OfflineTdnnModelConfig tdnn;
OfflineZipformerCtcModelConfig zipformer_ctc;
std::string tokens;
int32_t num_threads = 2;
... ... @@ -43,6 +45,7 @@ struct OfflineModelConfig {
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const OfflineWhisperModelConfig &whisper,
const OfflineTdnnModelConfig &tdnn,
const OfflineZipformerCtcModelConfig &zipformer_ctc,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
... ... @@ -50,6 +53,7 @@ struct OfflineModelConfig {
nemo_ctc(nemo_ctc),
whisper(whisper),
tdnn(tdnn),
zipformer_ctc(zipformer_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug),
... ...
... ... @@ -34,8 +34,8 @@ class OfflineNemoEncDecCtcModel::Impl {
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::vector<int64_t> shape =
features_length.GetTensorTypeAndShapeInfo().GetShape();
... ... @@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl {
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return {std::move(out[0]), std::move(out_features_length)};
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(out[0]));
ans.push_back(std::move(out_features_length));
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
... ... @@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
std::vector<Ort::Value> OfflineNemoEncDecCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
... ...
... ... @@ -38,17 +38,17 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
*
* @return Return a pair containing:
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
*/
std::pair<Ort::Value, Ort::Value> Forward(
Ort::Value features, Ort::Value features_length) override;
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) override;
/** Return the vocabulary size of the model
*/
... ...
... ... @@ -16,6 +16,7 @@
#endif
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
... ... @@ -25,9 +26,12 @@
namespace sherpa_onnx {
static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
const SymbolTable &sym_table) {
const SymbolTable &sym_table,
int32_t frame_shift_ms,
int32_t subsampling_factor) {
OfflineRecognitionResult r;
r.tokens.reserve(src.tokens.size());
r.timestamps.reserve(src.timestamps.size());
std::string text;
... ... @@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
}
r.text = std::move(text);
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
for (auto t : src.timestamps) {
float time = frame_shift_s * t;
r.timestamps.push_back(time);
}
return r;
}
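About the new timestamps: with the values used in this recognizer (frame_shift_ms = 10 and, for the zipformer CTC model, a subsampling factor of 4), one output frame covers 0.04 s, so a token emitted at output frame t is stamped at 0.04 * t seconds. A tiny worked sketch, using those assumed defaults:

// Assumed defaults from the code in this commit: 10 ms frames, subsampling factor 4.
float frame_shift_s = 10 / 1000.0f * 4;  // 0.04 s per output frame
float timestamp = frame_shift_s * 25;    // a token at output frame 25 -> 1.0 s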
... ... @@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
if (config_.decoding_method == "greedy_search") {
if (!config_.ctc_fst_decoder_config.graph.empty()) {
// TODO(fangjun): Support android to read the graph from
// asset_manager
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
SHERPA_ONNX_LOGE(
... ... @@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
-23.025850929940457f);
auto t = model_->Forward(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
int32_t frame_shift_ms = 10;
for (int32_t i = 0; i != n; ++i) {
auto r = Convert(results[i], symbol_table_);
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
ss[i]->SetResult(r);
}
}
... ...
... ... @@ -25,9 +25,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "tdnn") {
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
... ... @@ -50,6 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
... ... @@ -93,6 +94,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
"\n"
"(5) Zipformer CTC models from icefall"
"\n "
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"zipformer/export-onnx-ctc.py"
"\n"
"\n");
exit(-1);
}
... ... @@ -107,11 +113,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
}
if (model_type == "EncDecCTCModelBPE") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
if (model_type == "tdnn") {
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
... ... @@ -126,7 +129,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n",
" - Tdnn models\n"
" - Zipformer CTC models\n",
model_type.c_str());
exit(-1);
... ... @@ -141,9 +145,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
} else if (model_type == "paraformer") {
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
} else if (model_type == "nemo_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "tdnn") {
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
} else if (model_type == "whisper") {
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
... ... @@ -166,6 +169,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.nemo_ctc.model;
} else if (!config.model_config.tdnn.model.empty()) {
model_filename = config.model_config.tdnn.model;
} else if (!config.model_config.zipformer_ctc.model.empty()) {
model_filename = config.model_config.zipformer_ctc.model;
} else if (!config.model_config.whisper.encoder.empty()) {
model_filename = config.model_config.whisper.encoder;
} else {
... ... @@ -209,6 +214,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
"\n "
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
"\n"
"(5) Zipformer CTC models from icefall"
"\n "
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"zipformer/export-onnx-ctc.py"
"\n"
"\n");
exit(-1);
}
... ... @@ -223,11 +233,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
}
if (model_type == "EncDecCTCModelBPE") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
if (model_type == "tdnn") {
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
model_type == "zipformer2_ctc") {
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
}
... ... @@ -242,7 +249,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n"
" - Whisper models\n"
" - Tdnn models\n",
" - Tdnn models\n"
" - Zipformer CTC models\n",
model_type.c_str());
exit(-1);
... ...
... ... @@ -17,6 +17,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
feat_config.Register(po);
model_config.Register(po);
lm_config.Register(po);
ctc_fst_decoder_config.Register(po);
po->Register(
"decoding-method", &decoding_method,
... ... @@ -69,6 +70,7 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "lm_config=" << lm_config.ToString() << ", ";
os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
... ...
... ... @@ -14,6 +14,7 @@
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-stream.h"
... ... @@ -28,6 +29,7 @@ struct OfflineRecognizerConfig {
OfflineFeatureExtractorConfig feat_config;
OfflineModelConfig model_config;
OfflineLMConfig lm_config;
OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4;
... ... @@ -39,16 +41,16 @@ struct OfflineRecognizerConfig {
// TODO(fangjun): Implement modified_beam_search
OfflineRecognizerConfig() = default;
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
const OfflineModelConfig &model_config,
const OfflineLMConfig &lm_config,
const std::string &decoding_method,
int32_t max_active_paths,
const std::string &hotwords_file,
float hotwords_score)
OfflineRecognizerConfig(
const OfflineFeatureExtractorConfig &feat_config,
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
ctc_fst_decoder_config(ctc_fst_decoder_config),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
... ...
... ... @@ -4,6 +4,8 @@
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
#include <utility>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
... ... @@ -34,7 +36,7 @@ class OfflineTdnnCtcModel::Impl {
}
#endif
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
std::vector<Ort::Value> Forward(Ort::Value features) {
auto nnet_out =
sess_->Run({}, input_names_ptr_.data(), &features, 1,
output_names_ptr_.data(), output_names_ptr_.size());
... ... @@ -52,7 +54,11 @@ class OfflineTdnnCtcModel::Impl {
memory_info, out_length_vec.data(), out_length_vec.size(),
out_length_shape.data(), out_length_shape.size());
return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
std::vector<Ort::Value> ans;
ans.reserve(2);
ans.push_back(std::move(nnet_out[0]));
ans.push_back(Clone(Allocator(), &nnet_out_length));
return ans;
}
int32_t VocabSize() const { return vocab_size_; }
... ... @@ -108,7 +114,7 @@ OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr,
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
std::vector<Ort::Value> OfflineTdnnCtcModel::Forward(
Ort::Value features, Ort::Value /*features_length*/) {
return impl_->Forward(std::move(features));
}
... ...
... ... @@ -5,7 +5,6 @@
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
... ... @@ -36,7 +35,7 @@ class OfflineTdnnCtcModel : public OfflineCtcModel {
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
... ... @@ -45,8 +44,8 @@ class OfflineTdnnCtcModel : public OfflineCtcModel {
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
*/
std::pair<Ort::Value, Ort::Value> Forward(
Ort::Value features, Ort::Value /*features_length*/) override;
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value /*features_length*/) override;
/** Return the vocabulary size of the model
*/
... ...
// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) {
po->Register("zipformer-ctc-model", &model, "Path to zipformer CTC model");
}
bool OfflineZipformerCtcModelConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist",
model.c_str());
return false;
}
return true;
}
std::string OfflineZipformerCtcModelConfig::ToString() const {
std::ostringstream os;
os << "OfflineZipformerCtcModelConfig(";
os << "model=\"" << model << "\")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// for
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
struct OfflineZipformerCtcModelConfig {
std::string model;
OfflineZipformerCtcModelConfig() = default;
explicit OfflineZipformerCtcModelConfig(const std::string &model)
: model(model) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-zipformer-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
class OfflineZipformerCtcModel::Impl {
public:
explicit Impl(const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(config_.zipformer_ctc.model);
Init(buf.data(), buf.size());
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config_.zipformer_ctc.model);
Init(buf.data(), buf.size());
}
#endif
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
std::move(features_length)};
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
}
int32_t VocabSize() const { return vocab_size_; }
int32_t SubsamplingFactor() const { return 4; }
OrtAllocator *Allocator() const { return allocator_; }
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
// get vocab size from the output[0].shape, which is (N, T, vocab_size)
vocab_size_ =
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2];
}
private:
OfflineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
int32_t vocab_size_ = 0;
};
OfflineZipformerCtcModel::OfflineZipformerCtcModel(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineZipformerCtcModel::OfflineZipformerCtcModel(
AAssetManager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineZipformerCtcModel::~OfflineZipformerCtcModel() = default;
std::vector<Ort::Value> OfflineZipformerCtcModel::Forward(
Ort::Value features, Ort::Value features_length) {
return impl_->Forward(std::move(features), std::move(features_length));
}
int32_t OfflineZipformerCtcModel::VocabSize() const {
return impl_->VocabSize();
}
OrtAllocator *OfflineZipformerCtcModel::Allocator() const {
return impl_->Allocator();
}
int32_t OfflineZipformerCtcModel::SubsamplingFactor() const {
return impl_->SubsamplingFactor();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-zipformer-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
/** This class implements the zipformer CTC model of the librispeech recipe
* from icefall.
*
* See
* https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
*/
class OfflineZipformerCtcModel : public OfflineCtcModel {
public:
explicit OfflineZipformerCtcModel(const OfflineModelConfig &config);
#if __ANDROID_API__ >= 9
OfflineZipformerCtcModel(AAssetManager *mgr,
const OfflineModelConfig &config);
#endif
~OfflineZipformerCtcModel() override;
/** Run the forward method of the model.
*
* @param features A tensor of shape (N, T, C).
* @param features_length A 1-D tensor of shape (N,) containing number of
* valid frames in `features` before padding.
* Its dtype is int64_t.
*
* @return Return a vector containing:
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
*/
std::vector<Ort::Value> Forward(Ort::Value features,
Ort::Value features_length) override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
int32_t SubsamplingFactor() const override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
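A minimal C++ sketch (not part of this commit) of wiring the new pieces together: it only fills in the config fields added in this commit, with placeholder paths, and assumes the existing sherpa_onnx::OfflineRecognizer constructor that takes an OfflineRecognizerConfig.

// Sketch only: placeholder paths; the OfflineRecognizer(config) constructor
// is assumed from the existing sherpa-onnx offline API.
#include "sherpa-onnx/csrc/offline-recognizer.h"

int main() {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.zipformer_ctc.model = "./model.onnx";  // placeholder
  config.model_config.tokens = "./tokens.txt";               // placeholder
  config.model_config.model_type = "zipformer2_ctc";  // skips probing the model type
  config.ctc_fst_decoder_config.graph = "./HLG.fst";   // H.fst or HL.fst also work
  config.ctc_fst_decoder_config.max_active = 3000;

  sherpa_onnx::OfflineRecognizer recognizer(config);
  // Create an offline stream, feed 16 kHz samples, and decode as in the
  // existing offline examples; results now include token timestamps.
  return 0;
}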
... ...
... ... @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc
offline-ctc-fst-decoder-config.cc
offline-lm-config.cc
offline-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc
... ... @@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
offline-tdnn-model-config.cc
offline-transducer-model-config.cc
offline-whisper-model-config.cc
offline-zipformer-ctc-model-config.cc
online-lm-config.cc
online-model-config.cc
online-paraformer-model-config.cc
... ...
// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
#include <string>
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
namespace sherpa_onnx {
void PybindOfflineCtcFstDecoderConfig(py::module *m) {
using PyClass = OfflineCtcFstDecoderConfig;
py::class_<PyClass>(*m, "OfflineCtcFstDecoderConfig")
.def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
py::arg("max_active") = 3000)
.def_readwrite("graph", &PyClass::graph)
.def_readwrite("max_active", &PyClass::max_active)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineCtcFstDecoderConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
... ...
... ... @@ -13,6 +13,7 @@
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
namespace sherpa_onnx {
... ... @@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) {
PybindOfflineNemoEncDecCtcModelConfig(m);
PybindOfflineWhisperModelConfig(m);
PybindOfflineTdnnModelConfig(m);
PybindOfflineZipformerCtcModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
... ... @@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) {
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const OfflineWhisperModelConfig &,
const OfflineTdnnModelConfig &, const std::string &,
const OfflineTdnnModelConfig &,
const OfflineZipformerCtcModelConfig &, const std::string &,
int32_t, bool, const std::string &, const std::string &>(),
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("whisper") = OfflineWhisperModelConfig(),
py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false,
py::arg("tdnn") = OfflineTdnnModelConfig(),
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
py::arg("provider") = "cpu", py::arg("model_type") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("whisper", &PyClass::whisper)
.def_readwrite("tdnn", &PyClass::tdnn)
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
... ... @@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &,
const std::string &, int32_t, const std::string &, float>(),
const OfflineCtcFstDecoderConfig &, const std::string &,
int32_t, const std::string &, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
... ...
// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
#include <string>
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
namespace sherpa_onnx {
void PybindOfflineZipformerCtcModelConfig(py::module *m) {
using PyClass = OfflineZipformerCtcModelConfig;
py::class_<PyClass>(*m, "OfflineZipformerCtcModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineZipformerCtcModelConfig(py::module *m);
}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
... ...
... ... @@ -8,6 +8,7 @@
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
#include "sherpa-onnx/python/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
... ... @@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineStream(&m);
PybindOfflineLMConfig(&m);
PybindOfflineModelConfig(&m);
PybindOfflineCtcFstDecoderConfig(&m);
PybindOfflineRecognizer(&m);
PybindVadModelConfig(&m);
... ...
... ... @@ -4,12 +4,14 @@ from pathlib import Path
from typing import List, Optional
from _sherpa_onnx import (
OfflineCtcFstDecoderConfig,
OfflineFeatureExtractorConfig,
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
OfflineTdnnModelConfig,
OfflineWhisperModelConfig,
OfflineZipformerCtcModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
... ...