Fangjun Kuang
Committed by GitHub

Begin to support CTC models (#119)

Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
Showing 40 changed files with 1,237 additions and 53 deletions
#!/usr/bin/env bash
# Test non-streaming (offline) NeMo CTC models.
# The executable to test is taken from the environment variable EXE
# (set by the CI workflow, e.g. sherpa-onnx-offline or sherpa-onnx-offline.exe).

set -e

# Log a message prefixed with a timestamp and the caller's file/line/function.
# This function is from espnet.
log() {
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

# Fail early (set -e) if $EXE cannot be found on PATH.
# `command -v` is POSIX and more reliable than `which`.
command -v "$EXE"

log "------------------------------------------------------------"
log "Run Citrinet (stt_en_citrinet_512, English)"
log "------------------------------------------------------------"

# Use https directly; the former http URL relied on a server-side redirect.
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
log "Start testing ${repo_url}"
repo=$(basename "$repo_url")
log "Download pretrained model and test-data from $repo_url"

# Clone without downloading LFS payloads, then fetch only the *.onnx files.
GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
pushd "$repo"
git lfs pull --include "*.onnx"
ls -lh *.onnx
popd

# Decode the test waves with the float32 model ...
time "$EXE" \
  --tokens="$repo/tokens.txt" \
  --nemo-ctc-model="$repo/model.onnx" \
  --num-threads=2 \
  "$repo/test_wavs/0.wav" \
  "$repo/test_wavs/1.wav" \
  "$repo/test_wavs/8k.wav"

# ... and with the int8 quantized model.
time "$EXE" \
  --tokens="$repo/tokens.txt" \
  --nemo-ctc-model="$repo/model.int8.onnx" \
  --num-threads=2 \
  "$repo/test_wavs/0.wav" \
  "$repo/test_wavs/1.wav" \
  "$repo/test_wavs/8k.wav"

rm -rf "$repo"
... ...
... ... @@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
rm -rf $repo
log "Test non-streaming paraformer models"
pushd $dir
... ... @@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \
$repo/test_wavs/8k.wav
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
rm -rf $repo
log "Test non-streaming NeMo CTC models"

pushd $dir

repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
log "Start testing ${repo_url}"
# $repo is an absolute path ($dir/...), so it stays valid after popd below.
repo=$dir/$(basename $repo_url)
log "Download pretrained model and test-data from $repo_url"

# Clone without downloading LFS payloads, then fetch only the *.onnx files.
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
cd $repo
git lfs pull --include "*.onnx"
# Pops the `pushd $dir` entry, returning to the original working directory
# (where python-api-examples/ lives).
popd

ls -lh $repo

# Decode with the float32 model ...
python3 ./python-api-examples/offline-decode-files.py \
--tokens=$repo/tokens.txt \
--nemo-ctc=$repo/model.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

# ... and with the int8 quantized model.
python3 ./python-api-examples/offline-decode-files.py \
--tokens=$repo/tokens.txt \
--nemo-ctc=$repo/model.int8.onnx \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose

rm -rf $repo
... ...
... ... @@ -8,6 +8,7 @@ on:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -20,6 +21,7 @@ on:
- '.github/workflows/linux.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -68,6 +70,14 @@ jobs:
file build/bin/sherpa-onnx
readelf -d build/bin/sherpa-onnx
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer
shell: bash
run: |
... ...
... ... @@ -8,6 +8,7 @@ on:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -18,6 +19,7 @@ on:
- '.github/workflows/macos.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -67,6 +69,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test offline CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer
shell: bash
run: |
... ...
... ... @@ -8,6 +8,7 @@ on:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -18,6 +19,7 @@ on:
- '.github/workflows/windows-x64.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -73,6 +75,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test offline CTC for windows x64
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer for Windows x64
shell: bash
run: |
... ...
... ... @@ -8,6 +8,7 @@ on:
- '.github/workflows/windows-x86.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -18,6 +19,7 @@ on:
- '.github/workflows/windows-x86.yaml'
- '.github/scripts/test-online-transducer.sh'
- '.github/scripts/test-offline-transducer.sh'
- '.github/scripts/test-offline-ctc.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
... ... @@ -31,6 +33,7 @@ permissions:
jobs:
windows_x86:
if: false # disable windows x86 CI for now
runs-on: ${{ matrix.os }}
name: ${{ matrix.vs-version }}
strategy:
... ... @@ -73,6 +76,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test offline CTC for windows x86
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline.exe
.github/scripts/test-offline-ctc.sh
- name: Test offline transducer for Windows x86
shell: bash
run: |
... ...
... ... @@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh
run-sherpa-onnx-*.sh
sherpa-onnx-zipformer-en-2023-03-30
sherpa-onnx-zipformer-en-2023-04-01
run-offline-decode-files.sh
sherpa-onnx-nemo-ctc-en-citrinet-512
run-offline-decode-files-nemo-ctc.sh
... ...
... ... @@ -6,7 +6,7 @@
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
paraformer Usage:
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
... ... @@ -18,7 +18,7 @@ paraformer Usage:
/path/to/0.wav \
/path/to/1.wav
transducer Usage:
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
... ... @@ -32,6 +32,8 @@ transducer Usage:
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
... ... @@ -83,7 +85,14 @@ def get_args():
"--paraformer",
default="",
type=str,
help="Path to the paraformer model",
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
... ... @@ -171,11 +180,14 @@ def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
if len(args.encoder) > 0:
if args.encoder:
assert len(args.paraformer) == 0, args.paraformer
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert_file_exists(args.encoder)
assert_file_exists(args.decoder)
assert_file_exists(args.joiner)
assert len(args.paraformer) == 0, args.paraformer
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
encoder=args.encoder,
decoder=args.decoder,
... ... @@ -187,8 +199,10 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
elif args.paraformer:
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
... ... @@ -198,6 +212,19 @@ def main():
decoding_method=args.decoding_method,
debug=args.debug,
)
elif args.nemo_ctc:
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=args.nemo_ctc,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
else:
print("Please specify at least one model")
return
print("Started!")
start_time = time.time()
... ... @@ -225,12 +252,14 @@ def main():
print("-" * 10)
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
if __name__ == "__main__":
... ...
... ... @@ -172,12 +172,14 @@ def main():
print("-" * 10)
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {duration:.3f} s")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
if __name__ == "__main__":
... ...
... ... @@ -16,7 +16,11 @@ set(sources
features.cc
file-utils.cc
hypothesis.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc
offline-nemo-enc-dec-ctc-model.cc
offline-paraformer-greedy-search-decoder.cc
offline-paraformer-model-config.cc
offline-paraformer-model.cc
... ...
... ... @@ -11,6 +11,8 @@
#include "android/log.h"
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
static_cast<int>(__LINE__)); \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
... ... @@ -18,6 +20,8 @@
#else
#define SHERPA_ONNX_LOGE(...) \
do { \
fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
static_cast<int>(__LINE__)); \
fprintf(stderr, ##__VA_ARGS__); \
fprintf(stderr, "\n"); \
} while (0)
... ...
// sherpa-onnx/csrc/offline-ctc-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h"  // NOLINT
namespace sherpa_onnx {

// Decoding result for a single utterance produced by an offline CTC decoder.
struct OfflineCtcDecoderResult {
  /// The decoded token IDs
  std::vector<int64_t> tokens;

  /// timestamps[i] contains the output frame index where tokens[i] is decoded.
  /// Note: The index is after subsampling
  std::vector<int32_t> timestamps;
};

// Abstract base class for offline (non-streaming) CTC decoders.
// Concrete subclasses (e.g., greedy search) implement Decode().
class OfflineCtcDecoder {
 public:
  virtual ~OfflineCtcDecoder() = default;

  /** Run CTC decoding given the output from the encoder model.
   *
   * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
   *                  log_probs.
   * @param log_probs_length A 1-D tensor of shape (N,) containing number
   *                         of valid frames in log_probs before padding.
   *
   * @return Return a vector of size `N` containing the decoded results.
   */
  virtual std::vector<OfflineCtcDecoderResult> Decode(
      Ort::Value log_probs, Ort::Value log_probs_length) = 0;
};
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Greedy-search (best-path) CTC decoding: for each frame take the arg-max
// token, collapse consecutive repeats, then drop blanks.
//
// @param log_probs 3-D tensor of shape (N, T, vocab_size).
// @param log_probs_length 1-D tensor of shape (N,) with the number of valid
//                         frames per utterance before padding.
// @return One OfflineCtcDecoderResult per utterance in the batch.
std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode(
    Ort::Value log_probs, Ort::Value log_probs_length) {
  std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
  int32_t batch_size = static_cast<int32_t>(shape[0]);
  int32_t num_frames = static_cast<int32_t>(shape[1]);
  int32_t vocab_size = static_cast<int32_t>(shape[2]);

  const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>();

  std::vector<OfflineCtcDecoderResult> ans;
  ans.reserve(batch_size);

  for (int32_t b = 0; b != batch_size; ++b) {
    const float *p_log_probs =
        log_probs.GetTensorData<float>() + b * num_frames * vocab_size;

    OfflineCtcDecoderResult r;
    int64_t prev_id = -1;

    for (int32_t t = 0; t != static_cast<int32_t>(p_log_probs_length[b]); ++t) {
      // Arg-max over the vocabulary for frame t.
      auto y = static_cast<int64_t>(std::distance(
          static_cast<const float *>(p_log_probs),
          std::max_element(
              static_cast<const float *>(p_log_probs),
              static_cast<const float *>(p_log_probs) + vocab_size)));
      p_log_probs += vocab_size;

      if (y != blank_id_ && y != prev_id) {
        r.tokens.push_back(y);
        r.timestamps.push_back(t);
      }

      // Fix: update prev_id on EVERY frame, including blanks. The previous
      // code updated it only when a token was emitted, so a frame sequence
      // like [a, blank, a] was wrongly collapsed to a single "a". In CTC,
      // a blank separates repeated tokens, so both must be kept.
      prev_id = y;
    }  // for (int32_t t = 0; ...)

    ans.push_back(std::move(r));
  }

  return ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
namespace sherpa_onnx {

// Greedy-search (best-path) CTC decoder: takes the arg-max token of every
// frame, collapses consecutive repeats, and removes blanks.
class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder {
 public:
  // @param blank_id ID of the blank symbol in the model's vocabulary.
  explicit OfflineCtcGreedySearchDecoder(int32_t blank_id)
      : blank_id_(blank_id) {}

  // See the base class OfflineCtcDecoder for the documentation of Decode().
  std::vector<OfflineCtcDecoderResult> Decode(
      Ort::Value log_probs, Ort::Value log_probs_length) override;

 private:
  int32_t blank_id_;  // ID of the blank token
};
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
... ...
// sherpa-onnx/csrc/offline-ctc-model.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace {
// Model types recognizable from the "model_type" entry of the ONNX model's
// metadata.
// NOTE(review): "kUnkown" is a typo for "kUnknown"; kept unchanged here
// because the name is used throughout this file.
enum class ModelType {
  kEncDecCTCModelBPE,
  kUnkown,
};
}  // namespace
namespace sherpa_onnx {
/** Determine the CTC model type from the ONNX model's metadata.
 *
 * @param model_data Pointer to the in-memory content of the model file.
 * @param model_data_length Number of bytes in model_data.
 * @param debug True to print the model metadata to the log.
 *
 * @return Return the detected model type, or ModelType::kUnkown if the
 *         metadata has no "model_type" entry or names an unsupported type.
 */
static ModelType GetModelType(char *model_data, size_t model_data_length,
                              bool debug) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions sess_opts;

  // A throwaway session used only to read the metadata.
  auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
                                             sess_opts);

  Ort::ModelMetadata meta_data = sess->GetModelMetadata();
  if (debug) {
    std::ostringstream os;
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;
  auto model_type =
      meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
  if (!model_type) {
    // Fix: "metadta" -> "metadata" in the user-facing error message.
    SHERPA_ONNX_LOGE(
        "No model_type in the metadata!\n"
        "If you are using models from NeMo, please refer to\n"
        "https://huggingface.co/csukuangfj/"
        "sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
        "\n"
        "for how to add metadata to model.onnx\n");
    return ModelType::kUnkown;
  }

  if (model_type.get() == std::string("EncDecCTCModelBPE")) {
    return ModelType::kEncDecCTCModelBPE;
  } else {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
    return ModelType::kUnkown;
  }
}
/** Factory: read the configured model file, probe its metadata, and
 * construct the matching OfflineCtcModel subclass.
 *
 * @return A concrete model instance, or nullptr if the type is unknown.
 */
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
    const OfflineModelConfig &config) {
  ModelType model_type = ModelType::kUnkown;
  {
    // Keep the file buffer alive only long enough to probe the metadata.
    auto buffer = ReadFile(config.nemo_ctc.model);
    model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
  }

  if (model_type == ModelType::kEncDecCTCModelBPE) {
    return std::make_unique<OfflineNemoEncDecCtcModel>(config);
  }

  SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
  return nullptr;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-ctc-model.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {

// Abstract base class for non-streaming CTC acoustic models,
// e.g., EncDecCTCModelBPE models exported from NeMo.
class OfflineCtcModel {
 public:
  virtual ~OfflineCtcModel() = default;

  // Factory method: reads the model given in `config`, inspects its
  // metadata, and returns an instance of the matching subclass
  // (or nullptr if the model type is not supported).
  static std::unique_ptr<OfflineCtcModel> Create(
      const OfflineModelConfig &config);

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  virtual std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value features_length) = 0;

  /** Return the vocabulary size of the model
   */
  virtual int32_t VocabSize() const = 0;

  /** SubsamplingFactor of the model
   *
   * For Citrinet, the subsampling factor is usually 4.
   * For Conformer CTC, the subsampling factor is usually 8.
   */
  virtual int32_t SubsamplingFactor() const = 0;

  /** Return an allocator for allocating memory
   */
  virtual OrtAllocator *Allocator() const = 0;

  /** For some models, e.g., those from NeMo, they require some preprocessing
   * for the features.
   *
   * The default implementation returns an empty string, meaning no special
   * normalization is required.
   */
  virtual std::string FeatureNormalizationMethod() const { return {}; }
};
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
... ...
... ... @@ -13,6 +13,7 @@ namespace sherpa_onnx {
void OfflineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
paraformer.Register(po);
nemo_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
... ... @@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const {
return paraformer.Validate();
}
if (!nemo_ctc.model.empty()) {
return nemo_ctc.Validate();
}
return transducer.Validate();
}
... ... @@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const {
os << "OfflineModelConfig(";
os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ")";
... ...
... ... @@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
... ... @@ -14,6 +15,7 @@ namespace sherpa_onnx {
struct OfflineModelConfig {
OfflineTransducerModelConfig transducer;
OfflineParaformerModelConfig paraformer;
OfflineNemoEncDecCtcModelConfig nemo_ctc;
std::string tokens;
int32_t num_threads = 2;
... ... @@ -22,9 +24,11 @@ struct OfflineModelConfig {
OfflineModelConfig() = default;
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
const OfflineParaformerModelConfig &paraformer,
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads, bool debug)
: transducer(transducer),
paraformer(paraformer),
nemo_ctc(nemo_ctc),
tokens(tokens),
num_threads(num_threads),
debug(debug) {}
... ...
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Register the command-line option of this config with the given parser.
void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
  po->Register("nemo-ctc-model", &model,
               "Path to model.onnx of Nemo EncDecCtcModel.");
}
// Check that the configured model file is usable.
//
// Returns true if `model` points to an existing file; otherwise logs a
// diagnostic and returns false.
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
  if (model.empty()) {
    // Give a clearer diagnostic than the confusing " does not exist"
    // that FileExists("") would produce.
    SHERPA_ONNX_LOGE("Please provide --nemo-ctc-model");
    return false;
  }

  if (!FileExists(model)) {
    SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
    return false;
  }

  return true;
}
// Return a human-readable representation of this config, e.g. for logging.
std::string OfflineNemoEncDecCtcModelConfig::ToString() const {
  std::ostringstream buf;
  buf << "OfflineNemoEncDecCtcModelConfig(";
  buf << "model=\"" << model << "\")";
  return buf.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {

// Configuration for non-streaming NeMo EncDecCTC models.
struct OfflineNemoEncDecCtcModelConfig {
  // Path to model.onnx (set via --nemo-ctc-model on the command line)
  std::string model;

  OfflineNemoEncDecCtcModelConfig() = default;
  explicit OfflineNemoEncDecCtcModelConfig(const std::string &model)
      : model(model) {}

  // Register the command-line option of this config with `po`.
  void Register(ParseOptions *po);

  // Return true if the configured model file exists.
  bool Validate() const;

  // Return a human-readable representation for logging.
  std::string ToString() const;
};
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
// Private implementation (pimpl) of OfflineNemoEncDecCtcModel.
// Owns the onnxruntime session and the model's metadata.
class OfflineNemoEncDecCtcModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_{},
        allocator_{} {
    sess_opts_.SetIntraOpNumThreads(config_.num_threads);
    sess_opts_.SetInterOpNumThreads(config_.num_threads);
    Init();
  }

  // Run the model. See OfflineCtcModel::Forward for the contract.
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
                                            Ort::Value features_length) {
    std::vector<int64_t> shape =
        features_length.GetTensorTypeAndShapeInfo().GetShape();

    // Output lengths are computed on the host as input_length /
    // subsampling_factor; the model's own length output (if any) is not used.
    Ort::Value out_features_length = Ort::Value::CreateTensor<int64_t>(
        allocator_, shape.data(), shape.size());

    // Read the input lengths BEFORE features_length is moved into `inputs`.
    const int64_t *src = features_length.GetTensorData<int64_t>();
    int64_t *dst = out_features_length.GetTensorMutableData<int64_t>();
    for (int64_t i = 0; i != shape[0]; ++i) {
      dst[i] = src[i] / subsampling_factor_;
    }

    // (B, T, C) -> (B, C, T)
    // NeMo models expect channel-major features.
    features = Transpose12(allocator_, &features);

    std::array<Ort::Value, 2> inputs = {std::move(features),
                                        std::move(features_length)};

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());
    return {std::move(out[0]), std::move(out_features_length)};
  }

  int32_t VocabSize() const { return vocab_size_; }

  int32_t SubsamplingFactor() const { return subsampling_factor_; }

  OrtAllocator *Allocator() const { return allocator_; }

  std::string FeatureNormalizationMethod() const { return normalize_type_; }

 private:
  // Load the model, cache input/output names, and read the required
  // metadata (vocab_size, subsampling_factor, normalize_type).
  void Init() {
    auto buf = ReadFile(config_.nemo_ctc.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
  }

 private:
  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // Read from the model metadata in Init()
  int32_t vocab_size_ = 0;
  int32_t subsampling_factor_ = 0;
  std::string normalize_type_;
};
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
    const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

// Out-of-line so Impl is a complete type at the point of destruction
// (pimpl idiom).
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;

// The methods below simply forward to the private Impl.
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
    Ort::Value features, Ort::Value features_length) {
  return impl_->Forward(std::move(features), std::move(features_length));
}

int32_t OfflineNemoEncDecCtcModel::VocabSize() const {
  return impl_->VocabSize();
}

int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const {
  return impl_->SubsamplingFactor();
}

OrtAllocator *OfflineNemoEncDecCtcModel::Allocator() const {
  return impl_->Allocator();
}

std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
  return impl_->FeatureNormalizationMethod();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {

/** This class implements the EncDecCTCModelBPE model from NeMo.
 *
 * See
 * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py
 * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py
 */
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
 public:
  explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
  ~OfflineNemoEncDecCtcModel() override;

  /** Run the forward method of the model.
   *
   * @param features A tensor of shape (N, T, C). It is changed in-place.
   * @param features_length A 1-D tensor of shape (N,) containing number of
   *                        valid frames in `features` before padding.
   *                        Its dtype is int64_t.
   *
   * @return Return a pair containing:
   *  - log_probs: A 3-D tensor of shape (N, T', vocab_size).
   *  - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
   */
  std::pair<Ort::Value, Ort::Value> Forward(
      Ort::Value features, Ort::Value features_length) override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** SubsamplingFactor of the model
   *
   * For Citrinet, the subsampling factor is usually 4.
   * For Conformer CTC, the subsampling factor is usually 8.
   */
  int32_t SubsamplingFactor() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

  // Feature normalization method read from the model's metadata.
  // Possible values:
  //  - per_feature
  //  - all_features (not implemented yet)
  //  - fixed_mean (not implemented)
  //  - fixed_std (not implemented)
  //  - or just leave it to empty
  // See
  // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
  // for details
  std::string FeatureNormalizationMethod() const override;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
}  // namespace sherpa_onnx
#endif  // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
... ...
// sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-ctc-model.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
// Map a CTC decoding result (token IDs) to a recognition result
// (token strings plus their concatenation) using the given symbol table.
static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
                                        const SymbolTable &sym_table) {
  OfflineRecognitionResult r;
  r.tokens.reserve(src.tokens.size());

  std::string text;
  for (auto token_id : src.tokens) {
    auto sym = sym_table[token_id];
    text.append(sym);
    r.tokens.push_back(std::move(sym));
  }
  r.text = std::move(text);

  return r;
}
// Offline recognizer implementation for CTC models (currently NeMo CTC).
// Pads a batch of feature matrices, runs the acoustic model once, and
// decodes all utterances with the configured CTC decoder.
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
      : config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(OfflineCtcModel::Create(config_.model_config)) {
    // Propagate the model's required feature normalization to the
    // feature extraction config.
    config_.feat_config.nemo_normalize_type =
        model_->FeatureNormalizationMethod();

    if (config.decoding_method == "greedy_search") {
      // Greedy search needs the ID of the blank symbol, which must be
      // listed as <blk> in tokens.txt.
      if (!symbol_table_.contains("<blk>")) {
        SHERPA_ONNX_LOGE(
            "We expect that tokens.txt contains "
            "the symbol <blk> and its ID.");
        exit(-1);
      }

      int32_t blank_id = symbol_table_["<blk>"];
      decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
    } else {
      SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                       config.decoding_method.c_str());
      exit(-1);
    }
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  // Decode a batch of n streams in one forward pass of the model.
  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    int32_t feat_dim = config_.feat_config.feature_dim;

    std::vector<Ort::Value> features;
    features.reserve(n);

    // features_vec owns the feature data; the Ort::Value tensors below
    // only reference it, so it must outlive them.
    std::vector<std::vector<float>> features_vec(n);
    std::vector<int64_t> features_length_vec(n);
    for (int32_t i = 0; i != n; ++i) {
      std::vector<float> f = ss[i]->GetFrames();
      int32_t num_frames = f.size() / feat_dim;

      features_vec[i] = std::move(f);
      features_length_vec[i] = num_frames;

      std::array<int64_t, 2> shape = {num_frames, feat_dim};

      Ort::Value x = Ort::Value::CreateTensor(
          memory_info, features_vec[i].data(), features_vec[i].size(),
          shape.data(), shape.size());
      features.push_back(std::move(x));
    }  // for (int32_t i = 0; i != n; ++i)

    std::vector<const Ort::Value *> features_pointer(n);
    for (int32_t i = 0; i != n; ++i) {
      features_pointer[i] = &features[i];
    }

    std::array<int64_t, 1> features_length_shape = {n};
    Ort::Value x_length = Ort::Value::CreateTensor(
        memory_info, features_length_vec.data(), n,
        features_length_shape.data(), features_length_shape.size());

    // Pad all utterances to the longest one.
    // -23.025850929940457 is log(1e-10), i.e., padding frames get the
    // log of a tiny energy value.
    Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
                               -23.025850929940457f);

    auto t = model_->Forward(std::move(x), std::move(x_length));

    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_);
      ss[i]->SetResult(r);
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineCtcModel> model_;
  std::unique_ptr<OfflineCtcDecoder> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
... ...
... ... @@ -8,6 +8,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
... ... @@ -25,6 +26,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
model_filename = config.model_config.transducer.encoder_filename;
} else if (!config.model_config.paraformer.model.empty()) {
model_filename = config.model_config.paraformer.model;
} else if (!config.model_config.nemo_ctc.model.empty()) {
model_filename = config.model_config.nemo_ctc.model;
} else {
SHERPA_ONNX_LOGE("Please provide a model");
exit(-1);
... ... @@ -39,8 +42,30 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
std::string model_type;
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
auto model_type_ptr =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type_ptr) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n\n"
"Please refer to the following URLs to add metadata"
"\n"
"(0) Transducer models from icefall"
"\n "
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
"pruned_transducer_stateless7/export-onnx.py#L303"
"\n"
"(1) Nemo CTC models\n "
"https://huggingface.co/csukuangfj/"
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
"\n"
"(2) Paraformer"
"\n "
"https://huggingface.co/csukuangfj/"
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
"\n");
exit(-1);
}
std::string model_type(model_type_ptr.get());
if (model_type == "conformer" || model_type == "zipformer") {
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
... ... @@ -50,11 +75,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
}
if (model_type == "EncDecCTCModelBPE") {
return std::make_unique<OfflineRecognizerCtcImpl>(config);
}
SHERPA_ONNX_LOGE(
"\nUnsupported model_type: %s\n"
"We support only the following model types at present: \n"
" - transducer models from icefall\n"
" - Paraformer models from FunASR\n",
" - Non-streaming transducer models from icefall\n"
" - Non-streaming Paraformer models from FunASR\n"
" - EncDecCTCModelBPE models from NeMo\n",
model_type.c_str());
exit(-1);
... ...
... ... @@ -7,6 +7,7 @@
#include <assert.h>
#include <algorithm>
#include <cmath>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/macros.h"
... ... @@ -15,6 +16,41 @@
namespace sherpa_onnx {
/* Compute mean and inverse stddev over rows.
 *
 * @param p A pointer to a 2-d array of shape (num_rows, num_cols)
 * @param num_rows Number of rows. Must be greater than 0.
 * @param num_cols Number of columns
 * @param mean On return, it contains p.mean(axis=0)
 * @param inv_stddev On return, it contains 1/p.std(axis=0)
 */
static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
                                 int32_t num_cols, std::vector<float> *mean,
                                 std::vector<float> *inv_stddev) {
  // Per-column sum and sum of squares, accumulated in a single pass.
  std::vector<float> sum(num_cols);
  std::vector<float> sum_sq(num_cols);

  for (int32_t i = 0; i != num_rows; ++i) {
    for (int32_t c = 0; c != num_cols; ++c) {
      auto t = p[c];
      sum[c] += t;
      sum_sq[c] += t * t;
    }
    p += num_cols;
  }

  mean->resize(num_cols);
  inv_stddev->resize(num_cols);

  for (int32_t i = 0; i != num_cols; ++i) {
    auto t = sum[i] / num_rows;
    (*mean)[i] = t;

    // var = E[x^2] - (E[x])^2 can come out slightly negative due to
    // floating-point cancellation; clamp it at 0 so sqrt() never yields NaN.
    float var = std::max(0.0f, sum_sq[i] / num_rows - t * t);

    // The 1e-5 term avoids division by zero for constant columns.
    (*inv_stddev)[i] = 1.0f / (std::sqrt(var) + 1e-5f);
  }
}
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("sample-rate", &sampling_rate,
"Sampling rate of the input waveform. "
... ... @@ -106,6 +142,8 @@ class OfflineStream::Impl {
p += feature_dim;
}
NemoNormalizeFeatures(features.data(), n, feature_dim);
return features;
}
... ... @@ -114,6 +152,38 @@ class OfflineStream::Impl {
const OfflineRecognitionResult &GetResult() const { return r_; }
private:
void NemoNormalizeFeatures(float *p, int32_t num_frames,
int32_t feature_dim) const {
if (config_.nemo_normalize_type.empty()) {
return;
}
if (config_.nemo_normalize_type != "per_feature") {
SHERPA_ONNX_LOGE(
"Only normalize_type=per_feature is implemented. Given: %s",
config_.nemo_normalize_type.c_str());
exit(-1);
}
NemoNormalizePerFeature(p, num_frames, feature_dim);
}
static void NemoNormalizePerFeature(float *p, int32_t num_frames,
int32_t feature_dim) {
std::vector<float> mean;
std::vector<float> inv_stddev;
ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev);
for (int32_t n = 0; n != num_frames; ++n) {
for (int32_t i = 0; i != feature_dim; ++i) {
p[i] = (p[i] - mean[i]) * inv_stddev[i];
}
p += feature_dim;
}
}
private:
OfflineFeatureExtractorConfig config_;
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
... ...
... ... @@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;
// Set internally by some models, e.g., paraformer
// Set internally by some models, e.g., paraformer sets it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
// the range [-1, 1].
// If false, we will multiply the inputs by 32768
bool normalize_samples = true;
// For models from NeMo
// This option is not exposed and is set internally when loading models.
// Possible values:
// - per_feature
// - all_features (not implemented yet)
// - fixed_mean (not implemented)
// - fixed_std (not implemented)
// - or just leave it to empty
// See
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
// for details
std::string nemo_normalize_type;
std::string ToString() const;
void Register(ParseOptions *po);
... ...
... ... @@ -14,10 +14,12 @@
#include <sstream>
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
namespace {
enum class ModelType {
kLstm,
... ... @@ -25,6 +27,10 @@ enum class ModelType {
kUnkown,
};
}
namespace sherpa_onnx {
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
... ... @@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
fprintf(stderr, "%s\n", os.str().c_str());
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
if (!model_type) {
fprintf(stderr, "No model_type in the metadata!\n");
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall "
"to export your transducer models");
return ModelType::kUnkown;
}
... ... @@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
} else if (model_type.get() == std::string("zipformer")) {
return ModelType::kZipformer;
} else {
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
}
}
... ... @@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
case ModelType::kZipformer:
return std::make_unique<OnlineZipformerTransducerModel>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}
... ... @@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
case ModelType::kZipformer:
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}
... ...
... ... @@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) {
}
}
// Transposing axes 1 and 2 twice must reproduce the original tensor.
TEST(Tranpose, Tranpose12) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> shape{3, 2, 5};
  int32_t numel = static_cast<int32_t>(shape[0] * shape[1] * shape[2]);

  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());

  // Fill with 0, 1, 2, ... so every element is distinct.
  float *p = v.GetTensorMutableData<float>();
  std::iota(p, p + numel, 0);

  auto ans = Transpose12(allocator, &v);
  auto v2 = Transpose12(allocator, &ans);

  Print3D(&v);
  Print3D(&ans);
  Print3D(&v2);

  // The round trip must be the identity.
  const float *q = v2.GetTensorData<float>();
  for (int32_t i = 0; i != numel; ++i) {
    EXPECT_EQ(p[i], q[i]);
  }
}
} // namespace sherpa_onnx
... ...
... ... @@ -17,7 +17,7 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
assert(shape.size() == 3);
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
ans_shape.size());
T *dst = ans.GetTensorMutableData<T>();
... ... @@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
return ans;
}
// Transpose a 3-D tensor from (B, T, C) to (B, C, T). See transpose.h
// for the full contract.
template <typename T /*= float*/>
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v) {
  std::vector<int64_t> in_shape = v->GetTensorTypeAndShapeInfo().GetShape();
  assert(in_shape.size() == 3);

  int64_t batch = in_shape[0];
  int64_t num_rows = in_shape[1];  // T
  int64_t num_cols = in_shape[2];  // C

  std::array<int64_t, 3> out_shape{batch, num_cols, num_rows};
  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, out_shape.data(),
                                               out_shape.size());

  const T *in = v->GetTensorData<T>();
  T *out = ans.GetTensorMutableData<T>();

  for (int64_t b = 0; b != batch; ++b) {
    const T *src = in + b * num_rows * num_cols;
    // out[b][c][t] = in[b][t][c]; out is written in row-major order.
    for (int64_t c = 0; c != num_cols; ++c) {
      for (int64_t t = 0; t != num_rows; ++t) {
        *out++ = src[t * num_cols + c];
      }
    }
  }

  return ans;
}
// Explicit instantiations for float, the only element type used so far.
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
                                       const Ort::Value *v);
template Ort::Value Transpose12<float>(OrtAllocator *allocator,
                                       const Ort::Value *v);
} // namespace sherpa_onnx
... ...
... ... @@ -10,13 +10,23 @@ namespace sherpa_onnx {
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
*
* @param allocator
* @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
 * @param v A 3-D tensor of shape (B, T, C). Its datatype is type.
*
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is type.
*/
template <typename T = float>
template <typename type = float>
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T).
*
* @param allocator
 * @param v A 3-D tensor of shape (B, T, C). Its datatype is type.
*
* @return Return a 3-D tensor of shape (B, C, T). Its datatype is type.
*/
template <typename type = float>
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
... ...
... ... @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
endpoint.cc
features.cc
offline-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc
offline-paraformer-model-config.cc
offline-recognizer.cc
offline-stream.cc
... ...
... ... @@ -7,26 +7,31 @@
#include <string>
#include <vector>
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
namespace sherpa_onnx {
void PybindOfflineModelConfig(py::module *m) {
PybindOfflineTransducerModelConfig(m);
PybindOfflineParaformerModelConfig(m);
PybindOfflineNemoEncDecCtcModelConfig(m);
using PyClass = OfflineModelConfig;
py::class_<PyClass>(*m, "OfflineModelConfig")
.def(py::init<OfflineTransducerModelConfig &,
OfflineParaformerModelConfig &,
.def(py::init<const OfflineTransducerModelConfig &,
const OfflineParaformerModelConfig &,
const OfflineNemoEncDecCtcModelConfig &,
const std::string &, int32_t, bool>(),
py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"),
py::arg("num_threads"), py::arg("debug") = false)
py::arg("transducer") = OfflineTransducerModelConfig(),
py::arg("paraformer") = OfflineParaformerModelConfig(),
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
.def_readwrite("tokens", &PyClass::tokens)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
... ...
// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
namespace sherpa_onnx {
// Expose OfflineNemoEncDecCtcModelConfig to Python as
// sherpa_onnx.OfflineNemoEncDecCtcModelConfig.
void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) {
  using PyClass = OfflineNemoEncDecCtcModelConfig;
  py::class_<PyClass>(*m, "OfflineNemoEncDecCtcModelConfig")
      // The constructor takes the path to the NeMo CTC onnx model.
      .def(py::init<const std::string &>(), py::arg("model"))
      .def_readwrite("model", &PyClass::model)
      .def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineNemoEncDecCtcModelConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
... ...
... ... @@ -4,7 +4,6 @@
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include <string>
#include <vector>
... ... @@ -15,8 +14,7 @@ namespace sherpa_onnx {
void PybindOfflineParaformerModelConfig(py::module *m) {
using PyClass = OfflineParaformerModelConfig;
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
.def(py::init<const std::string &>(),
py::arg("model"))
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -11,8 +11,6 @@
namespace sherpa_onnx {
static void PybindOfflineRecognizerConfig(py::module *m) {
using PyClass = OfflineRecognizerConfig;
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
... ...
... ... @@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
"timestamps", [](const PyClass &self) { return self.timestamps; });
}
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
using PyClass = OfflineFeatureExtractorConfig;
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
... ... @@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) {
.def("__str__", &PyClass::ToString);
}
void PybindOfflineStream(py::module *m) {
PybindOfflineFeatureExtractorConfig(m);
PybindOfflineRecognitionResult(m);
... ...
... ... @@ -7,15 +7,12 @@
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
#include "sherpa-onnx/python/csrc/offline-stream.h"
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
... ...
... ... @@ -4,12 +4,15 @@ from typing import List
from _sherpa_onnx import (
OfflineFeatureExtractorConfig,
OfflineRecognizer as _Recognizer,
OfflineModelConfig,
OfflineNemoEncDecCtcModelConfig,
OfflineParaformerModelConfig,
)
from _sherpa_onnx import OfflineRecognizer as _Recognizer
from _sherpa_onnx import (
OfflineRecognizerConfig,
OfflineStream,
OfflineModelConfig,
OfflineTransducerModelConfig,
OfflineParaformerModelConfig,
)
... ... @@ -75,7 +78,6 @@ class OfflineRecognizer(object):
decoder_filename=decoder,
joiner_filename=joiner,
),
paraformer=OfflineParaformerModelConfig(model=""),
tokens=tokens,
num_threads=num_threads,
debug=debug,
... ... @@ -119,7 +121,7 @@ class OfflineRecognizer(object):
symbol integer_id
paraformer:
Path to ``paraformer.onnx``.
Path to ``model.onnx``.
num_threads:
Number of threads for neural network computation.
sample_rate:
... ... @@ -133,9 +135,6 @@ class OfflineRecognizer(object):
"""
self = cls.__new__(cls)
model_config = OfflineModelConfig(
transducer=OfflineTransducerModelConfig(
encoder_filename="", decoder_filename="", joiner_filename=""
),
paraformer=OfflineParaformerModelConfig(model=paraformer),
tokens=tokens,
num_threads=num_threads,
... ... @@ -155,6 +154,64 @@ class OfflineRecognizer(object):
self.recognizer = _Recognizer(recognizer_config)
return self
    @classmethod
    def from_nemo_ctc(
        cls,
        model: str,
        tokens: str,
        num_threads: int,
        sample_rate: int = 16000,
        feature_dim: int = 80,
        decoding_method: str = "greedy_search",
        debug: bool = False,
    ) -> "OfflineRecognizer":
        """Create an offline recognizer from a NeMo CTC model.

        Please refer to
        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
        to download pre-trained models for different languages, e.g., Chinese,
        English, etc.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          model:
            Path to ``model.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          decoding_method:
            Valid values are greedy_search, modified_beam_search.
            NOTE(review): presumably only greedy_search is implemented for
            CTC models — confirm against the C++ decoder.
          debug:
            True to show debug messages.
        """
        # Bypass __init__; only the nemo_ctc model field is populated, so the
        # C++ side selects the NeMo CTC recognizer implementation.
        self = cls.__new__(cls)
        model_config = OfflineModelConfig(
            nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model),
            tokens=tokens,
            num_threads=num_threads,
            debug=debug,
        )

        feat_config = OfflineFeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        recognizer_config = OfflineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            decoding_method=decoding_method,
        )
        self.recognizer = _Recognizer(recognizer_config)
        return self
def create_stream(self):
return self.recognizer.create_stream()
... ...
... ... @@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase):
print(s2.result.text)
print(s3.result.text)
def test_nemo_ctc_single_file(self):
for use_int8 in [True, False]:
if use_int8:
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
else:
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
if not Path(model).is_file():
print("skipping test_nemo_ctc_single_file()")
return
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=model,
tokens=tokens,
num_threads=1,
)
s = recognizer.create_stream()
samples, sample_rate = read_wave(wave0)
s.accept_waveform(sample_rate, samples)
recognizer.decode_stream(s)
print(s.result.text)
def test_nemo_ctc_multiple_files(self):
for use_int8 in [True, False]:
if use_int8:
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
else:
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
wave1 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav"
wave2 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav"
if not Path(model).is_file():
print("skipping test_nemo_ctc_multiple_files()")
return
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
model=model,
tokens=tokens,
num_threads=1,
)
s0 = recognizer.create_stream()
samples0, sample_rate0 = read_wave(wave0)
s0.accept_waveform(sample_rate0, samples0)
s1 = recognizer.create_stream()
samples1, sample_rate1 = read_wave(wave1)
s1.accept_waveform(sample_rate1, samples1)
s2 = recognizer.create_stream()
samples2, sample_rate2 = read_wave(wave2)
s2.accept_waveform(sample_rate2, samples2)
recognizer.decode_streams([s0, s1, s2])
print(s0.result.text)
print(s1.result.text)
print(s2.result.text)
if __name__ == "__main__":
unittest.main()
... ...