Fangjun Kuang
Committed by GitHub

Support paraformer. (#95)

@@ -25,36 +25,59 @@ log "Download pretrained model and test-data from $repo_url"
25 25 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
26 26 pushd $repo
27 27 git lfs pull --include "*.onnx"
28 -cd test_wavs
29 28 popd
30 29
31 -waves=(  
32 -$repo/test_wavs/0.wav  
33 -$repo/test_wavs/1.wav  
34 -$repo/test_wavs/2.wav  
35 -)  
36 -  
37 -for wave in ${waves[@]}; do  
38 - time $EXE \  
39 - $repo/tokens.txt \  
40 - $repo/encoder-epoch-99-avg-1.onnx \  
41 - $repo/decoder-epoch-99-avg-1.onnx \  
42 - $repo/joiner-epoch-99-avg-1.onnx \  
43 - $wave \  
44 - 2  
45 -done
30 +time $EXE \
  31 + --tokens=$repo/tokens.txt \
  32 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  33 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  34 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  35 + --num-threads=2 \
  36 + $repo/test_wavs/0.wav \
  37 + $repo/test_wavs/1.wav \
  38 + $repo/test_wavs/2.wav
46 39
47 40
48 41 if command -v sox &> /dev/null; then
49 42 echo "test 8kHz"
50 43 sox $repo/test_wavs/0.wav -r 8000 8k.wav
44 +
51 45 time $EXE \
52 - $repo/tokens.txt \  
53 - $repo/encoder-epoch-99-avg-1.onnx \  
54 - $repo/decoder-epoch-99-avg-1.onnx \  
55 - $repo/joiner-epoch-99-avg-1.onnx \  
56 - 8k.wav \  
57 - 2
46 + --tokens=$repo/tokens.txt \
  47 + --encoder=$repo/encoder-epoch-99-avg-1.onnx \
  48 + --decoder=$repo/decoder-epoch-99-avg-1.onnx \
  49 + --joiner=$repo/joiner-epoch-99-avg-1.onnx \
  50 + --num-threads=2 \
  51 + $repo/test_wavs/0.wav \
  52 + $repo/test_wavs/1.wav \
  53 + $repo/test_wavs/2.wav \
  54 + 8k.wav
58 55 fi
59 56
60 57 rm -rf $repo
  58 +
  59 +log "------------------------------------------------------------"
  60 +log "Run Paraformer (Chinese)"
  61 +log "------------------------------------------------------------"
  62 +
  63 +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
  64 +log "Start testing ${repo_url}"
  65 +repo=$(basename $repo_url)
  66 +log "Download pretrained model and test-data from $repo_url"
  67 +
  68 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
  69 +pushd $repo
  70 +git lfs pull --include "*.onnx"
  71 +popd
  72 +
  73 +time $EXE \
  74 + --tokens=$repo/tokens.txt \
  75 + --paraformer=$repo/model.onnx \
  76 + --num-threads=2 \
  77 + --decoding-method=greedy_search \
  78 + $repo/test_wavs/0.wav \
  79 + $repo/test_wavs/1.wav \
  80 + $repo/test_wavs/2.wav \
  81 + $repo/test_wavs/8k.wav
  82 +
  83 +rm -rf $repo
@@ -71,7 +71,15 @@ jobs:
71 71
72 72 ls -lh ./bin/Release/sherpa-onnx.exe
73 73
74 - - name: Test sherpa-onnx for Windows x64
74 + - name: Test offline transducer for Windows x64
  75 + shell: bash
  76 + run: |
  77 + export PATH=$PWD/build/bin/Release:$PATH
  78 + export EXE=sherpa-onnx-offline.exe
  79 +
  80 + .github/scripts/test-offline-transducer.sh
  81 +
  82 + - name: Test online transducer for Windows x64
75 83 shell: bash
76 84 run: |
77 85 export PATH=$PWD/build/bin/Release:$PATH
@@ -71,7 +71,15 @@ jobs:
71 71
72 72 ls -lh ./bin/Release/sherpa-onnx.exe
73 73
74 - - name: Test sherpa-onnx for Windows x86
74 + - name: Test offline transducer for Windows x86
  75 + shell: bash
  76 + run: |
  77 + export PATH=$PWD/build/bin/Release:$PATH
  78 + export EXE=sherpa-onnx-offline.exe
  79 +
  80 + .github/scripts/test-offline-transducer.sh
  81 +
  82 + - name: Test online transducer for Windows x86
75 83 shell: bash
76 84 run: |
77 85 export PATH=$PWD/build/bin/Release:$PATH
@@ -41,3 +41,7 @@ android/SherpaOnnx/app/src/main/assets/
41 41 *.ncnn.*
42 42 run-sherpa-onnx-offline.sh
43 43 sherpa-onnx-conformer-en-2023-03-18
  44 +paraformer-onnxruntime-python-example
  45 +run-sherpa-onnx-offline-paraformer.sh
  46 +run-sherpa-onnx-offline-transducer.sh
  47 +sherpa-onnx-paraformer-zh-2023-03-28
@@ -6,6 +6,10 @@ set(sources
6 6 features.cc
7 7 file-utils.cc
8 8 hypothesis.cc
  9 + offline-model-config.cc
  10 + offline-paraformer-greedy-search-decoder.cc
  11 + offline-paraformer-model-config.cc
  12 + offline-paraformer-model.cc
9 13 offline-recognizer-impl.cc
10 14 offline-recognizer.cc
11 15 offline-stream.cc
@@ -57,6 +57,23 @@
57 57 } \
58 58 } while (0)
59 59
  60 +// read a vector of floats
  61 +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \
  62 + do { \
  63 + auto value = \
  64 + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \
  65 + if (!value) { \
  66 + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \
  67 + exit(-1); \
  68 + } \
  69 + \
  70 + bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \
  71 + if (!ret) { \
  72 + SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \
  73 + exit(-1); \
  74 + } \
  75 + } while (0)
  76 +
60 77 // Read a string
61 78 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
62 79 do { \
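For context: the new SHERPA_ONNX_READ_META_DATA_VEC_FLOAT macro above expects metadata entries such as neg_mean and inv_stddev to be stored in the ONNX model as comma-separated float strings, which it parses with SplitStringToFloats. The snippet below is a minimal standalone sketch of that kind of parsing; it is not part of this PR and is only a simplified stand-in for the library helper declared in text-utils.h, not its actual implementation.

// Standalone sketch: parse "a,b,c" into a std::vector<float>.
// Illustration only; sherpa-onnx itself uses SplitStringToFloats.
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

static bool ParseCommaSeparatedFloats(const std::string &s,
                                      std::vector<float> *out) {
  std::istringstream iss(s);
  std::string field;
  while (std::getline(iss, field, ',')) {
    if (field.empty()) return false;
    out->push_back(std::stof(field));  // throws std::invalid_argument on junk
  }
  return !out->empty();
}

int main() {
  std::vector<float> v;
  // e.g., a "neg_mean" metadata value exported with the model (toy numbers)
  if (ParseCommaSeparatedFloats("-8.3,-8.2,-8.1", &v)) {
    std::printf("parsed %zu floats\n", v.size());
  }
  return 0;
}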
  1 +// sherpa-onnx/csrc/offline-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#include "sherpa-onnx/csrc/offline-model-config.h"
  5 +
  6 +#include <string>
  7 +
  8 +#include "sherpa-onnx/csrc/file-utils.h"
  9 +#include "sherpa-onnx/csrc/macros.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +void OfflineModelConfig::Register(ParseOptions *po) {
  14 + transducer.Register(po);
  15 + paraformer.Register(po);
  16 +
  17 + po->Register("tokens", &tokens, "Path to tokens.txt");
  18 +
  19 + po->Register("num-threads", &num_threads,
  20 + "Number of threads to run the neural network");
  21 +
  22 + po->Register("debug", &debug,
  23 + "true to print model information while loading it.");
  24 +}
  25 +
  26 +bool OfflineModelConfig::Validate() const {
  27 + if (num_threads < 1) {
  28 + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
  29 + return false;
  30 + }
  31 +
  32 + if (!FileExists(tokens)) {
  33 + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
  34 + return false;
  35 + }
  36 +
  37 + if (!paraformer.model.empty()) {
  38 + return paraformer.Validate();
  39 + }
  40 +
  41 + return transducer.Validate();
  42 +}
  43 +
  44 +std::string OfflineModelConfig::ToString() const {
  45 + std::ostringstream os;
  46 +
  47 + os << "OfflineModelConfig(";
  48 + os << "transducer=" << transducer.ToString() << ", ";
  49 + os << "paraformer=" << paraformer.ToString() << ", ";
  50 + os << "tokens=\"" << tokens << "\", ";
  51 + os << "num_threads=" << num_threads << ", ";
  52 + os << "debug=" << (debug ? "True" : "False") << ")";
  53 +
  54 + return os.str();
  55 +}
  56 +
  57 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
  10 +#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineModelConfig {
  15 + OfflineTransducerModelConfig transducer;
  16 + OfflineParaformerModelConfig paraformer;
  17 +
  18 + std::string tokens;
  19 + int32_t num_threads = 2;
  20 + bool debug = false;
  21 +
  22 + OfflineModelConfig() = default;
  23 + OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
  24 + const OfflineParaformerModelConfig &paraformer,
  25 + const std::string &tokens, int32_t num_threads, bool debug)
  26 + : transducer(transducer),
  27 + paraformer(paraformer),
  28 + tokens(tokens),
  29 + num_threads(num_threads),
  30 + debug(debug) {}
  31 +
  32 + void Register(ParseOptions *po);
  33 + bool Validate() const;
  34 +
  35 + std::string ToString() const;
  36 +};
  37 +
  38 +} // namespace sherpa_onnx
  39 +
  40 +#endif // SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-paraformer-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "onnxruntime_cxx_api.h" // NOLINT
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +struct OfflineParaformerDecoderResult {
  15 + /// The decoded token IDs
  16 + std::vector<int64_t> tokens;
  17 +};
  18 +
  19 +class OfflineParaformerDecoder {
  20 + public:
  21 + virtual ~OfflineParaformerDecoder() = default;
  22 +
23 + /** Run decoding given the output from the paraformer model.
  24 + *
  25 + * @param log_probs A 3-D tensor of shape (N, T, vocab_size)
  26 + * @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t.
27 + * log_probs[i].argmax(axis=-1) equals token_num[i]
  28 + *
  29 + * @return Return a vector of size `N` containing the decoded results.
  30 + */
  31 + virtual std::vector<OfflineParaformerDecoderResult> Decode(
  32 + Ort::Value log_probs, Ort::Value token_num) = 0;
  33 +};
  34 +
  35 +} // namespace sherpa_onnx
  36 +
  37 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
  6 +
  7 +#include <vector>
  8 +
  9 +namespace sherpa_onnx {
  10 +
  11 +std::vector<OfflineParaformerDecoderResult>
  12 +OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/,
  13 + Ort::Value token_num) {
  14 + std::vector<int64_t> shape = token_num.GetTensorTypeAndShapeInfo().GetShape();
  15 + int32_t batch_size = shape[0];
  16 + int32_t num_tokens = shape[1];
  17 +
  18 + std::vector<OfflineParaformerDecoderResult> results(batch_size);
  19 +
  20 + const int64_t *p = token_num.GetTensorData<int64_t>();
  21 + for (int32_t i = 0; i != batch_size; ++i) {
  22 + for (int32_t k = 0; k != num_tokens; ++k) {
  23 + if (p[k] == eos_id_) break;
  24 +
  25 + results[i].tokens.push_back(p[k]);
  26 + }
  27 +
  28 + p += num_tokens;
  29 + }
  30 +
  31 + return results;
  32 +}
  33 +
  34 +} // namespace sherpa_onnx
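To make the decoding rule above concrete, here is a standalone sketch (not part of this PR) of the same row-major walk over a plain int64_t array with toy token IDs, stopping at the eos symbol exactly as OfflineParaformerGreedySearchDecoder::Decode does with the (N, T) token_num tensor.

// Standalone sketch of the greedy extraction loop over plain arrays.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int32_t kBatch = 2, kNumTokens = 4, kEos = 2;  // toy values
  // token IDs for two utterances, padded after the eos symbol (id 2)
  const int64_t token_num[kBatch * kNumTokens] = {5, 7, 2, 0,
                                                  9, 2, 0, 0};
  const int64_t *p = token_num;
  for (int32_t i = 0; i != kBatch; ++i) {
    std::vector<int64_t> tokens;
    for (int32_t k = 0; k != kNumTokens; ++k) {
      if (p[k] == kEos) break;
      tokens.push_back(p[k]);
    }
    p += kNumTokens;
    std::printf("utterance %d: %zu tokens\n", i, tokens.size());
  }
  return 0;
}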
  1 +// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_
  7 +
  8 +#include <vector>
  9 +
  10 +#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
  11 +
  12 +namespace sherpa_onnx {
  13 +
  14 +class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
  15 + public:
  16 + explicit OfflineParaformerGreedySearchDecoder(int32_t eos_id)
  17 + : eos_id_(eos_id) {}
  18 +
  19 + std::vector<OfflineParaformerDecoderResult> Decode(
  20 + Ort::Value /*log_probs*/, Ort::Value token_num) override;
  21 +
  22 + private:
  23 + int32_t eos_id_;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_
  1 +// sherpa-onnx/csrc/offline-paraformer-model-config.cc
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
  6 +
  7 +#include "sherpa-onnx/csrc/file-utils.h"
  8 +#include "sherpa-onnx/csrc/macros.h"
  9 +
  10 +namespace sherpa_onnx {
  11 +
  12 +void OfflineParaformerModelConfig::Register(ParseOptions *po) {
  13 + po->Register("paraformer", &model, "Path to model.onnx of paraformer.");
  14 +}
  15 +
  16 +bool OfflineParaformerModelConfig::Validate() const {
  17 + if (!FileExists(model)) {
  18 + SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
  19 + return false;
  20 + }
  21 +
  22 + return true;
  23 +}
  24 +
  25 +std::string OfflineParaformerModelConfig::ToString() const {
  26 + std::ostringstream os;
  27 +
  28 + os << "OfflineParaformerModelConfig(";
  29 + os << "model=\"" << model << "\")";
  30 +
  31 + return os.str();
  32 +}
  33 +
  34 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-paraformer-model-config.h
  2 +//
  3 +// Copyright (c) 2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
  6 +
  7 +#include <string>
  8 +
  9 +#include "sherpa-onnx/csrc/parse-options.h"
  10 +
  11 +namespace sherpa_onnx {
  12 +
  13 +struct OfflineParaformerModelConfig {
  14 + std::string model;
  15 +
  16 + OfflineParaformerModelConfig() = default;
  17 + explicit OfflineParaformerModelConfig(const std::string &model)
  18 + : model(model) {}
  19 +
  20 + void Register(ParseOptions *po);
  21 + bool Validate() const;
  22 +
  23 + std::string ToString() const;
  24 +};
  25 +
  26 +} // namespace sherpa_onnx
  27 +
  28 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
  1 +// sherpa-onnx/csrc/offline-paraformer-model.cc
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#include "sherpa-onnx/csrc/offline-paraformer-model.h"
  6 +
  7 +#include <algorithm>
  8 +#include <string>
  9 +
  10 +#include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/onnx-utils.h"
  12 +#include "sherpa-onnx/csrc/text-utils.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OfflineParaformerModel::Impl {
  17 + public:
  18 + explicit Impl(const OfflineModelConfig &config)
  19 + : config_(config),
  20 + env_(ORT_LOGGING_LEVEL_ERROR),
  21 + sess_opts_{},
  22 + allocator_{} {
  23 + sess_opts_.SetIntraOpNumThreads(config_.num_threads);
  24 + sess_opts_.SetInterOpNumThreads(config_.num_threads);
  25 +
  26 + Init();
  27 + }
  28 +
  29 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
  30 + Ort::Value features_length) {
  31 + std::array<Ort::Value, 2> inputs = {std::move(features),
  32 + std::move(features_length)};
  33 +
  34 + auto out =
  35 + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
  36 + output_names_ptr_.data(), output_names_ptr_.size());
  37 +
  38 + return {std::move(out[0]), std::move(out[1])};
  39 + }
  40 +
  41 + int32_t VocabSize() const { return vocab_size_; }
  42 +
  43 + int32_t LfrWindowSize() const { return lfr_window_size_; }
  44 +
  45 + int32_t LfrWindowShift() const { return lfr_window_shift_; }
  46 +
  47 + const std::vector<float> &NegativeMean() const { return neg_mean_; }
  48 +
  49 + const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
  50 +
  51 + OrtAllocator *Allocator() const { return allocator_; }
  52 +
  53 + private:
  54 + void Init() {
  55 + auto buf = ReadFile(config_.paraformer.model);
  56 +
  57 + sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
  58 + sess_opts_);
  59 +
  60 + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
  61 +
  62 + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
  63 +
  64 + // get meta data
  65 + Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
  66 + if (config_.debug) {
  67 + std::ostringstream os;
  68 + PrintModelMetadata(os, meta_data);
  69 + SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
  70 + }
  71 +
  72 + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
  73 + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  74 + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
  75 + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");
  76 +
  77 + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
  78 + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
  79 + }
  80 +
  81 + private:
  82 + OfflineModelConfig config_;
  83 + Ort::Env env_;
  84 + Ort::SessionOptions sess_opts_;
  85 + Ort::AllocatorWithDefaultOptions allocator_;
  86 +
  87 + std::unique_ptr<Ort::Session> sess_;
  88 +
  89 + std::vector<std::string> input_names_;
  90 + std::vector<const char *> input_names_ptr_;
  91 +
  92 + std::vector<std::string> output_names_;
  93 + std::vector<const char *> output_names_ptr_;
  94 +
  95 + std::vector<float> neg_mean_;
  96 + std::vector<float> inv_stddev_;
  97 +
  98 + int32_t vocab_size_ = 0; // initialized in Init
  99 + int32_t lfr_window_size_ = 0;
  100 + int32_t lfr_window_shift_ = 0;
  101 +};
  102 +
  103 +OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config)
  104 + : impl_(std::make_unique<Impl>(config)) {}
  105 +
  106 +OfflineParaformerModel::~OfflineParaformerModel() = default;
  107 +
  108 +std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
  109 + Ort::Value features, Ort::Value features_length) {
  110 + return impl_->Forward(std::move(features), std::move(features_length));
  111 +}
  112 +
  113 +int32_t OfflineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
  114 +
  115 +int32_t OfflineParaformerModel::LfrWindowSize() const {
  116 + return impl_->LfrWindowSize();
  117 +}
  118 +int32_t OfflineParaformerModel::LfrWindowShift() const {
  119 + return impl_->LfrWindowShift();
  120 +}
  121 +const std::vector<float> &OfflineParaformerModel::NegativeMean() const {
  122 + return impl_->NegativeMean();
  123 +}
  124 +const std::vector<float> &OfflineParaformerModel::InverseStdDev() const {
  125 + return impl_->InverseStdDev();
  126 +}
  127 +
  128 +OrtAllocator *OfflineParaformerModel::Allocator() const {
  129 + return impl_->Allocator();
  130 +}
  131 +
  132 +} // namespace sherpa_onnx
  1 +// sherpa-onnx/csrc/offline-paraformer-model.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
  5 +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
  6 +
  7 +#include <memory>
  8 +#include <utility>
  9 +#include <vector>
  10 +
  11 +#include "onnxruntime_cxx_api.h" // NOLINT
  12 +#include "sherpa-onnx/csrc/offline-model-config.h"
  13 +
  14 +namespace sherpa_onnx {
  15 +
  16 +class OfflineParaformerModel {
  17 + public:
  18 + explicit OfflineParaformerModel(const OfflineModelConfig &config);
  19 + ~OfflineParaformerModel();
  20 +
  21 + /** Run the forward method of the model.
  22 + *
  23 + * @param features A tensor of shape (N, T, C). It is changed in-place.
  24 + * @param features_length A 1-D tensor of shape (N,) containing number of
  25 + * valid frames in `features` before padding.
  26 + * Its dtype is int32_t.
  27 + *
  28 + * @return Return a pair containing:
  29 + * - log_probs: A 3-D tensor of shape (N, T', vocab_size)
30 + * - token_num: A 2-D tensor of shape (N, T') containing the predicted
31 + * token IDs. Its dtype is int64_t.
  32 + */
  33 + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
  34 + Ort::Value features_length);
  35 +
  36 + /** Return the vocabulary size of the model
  37 + */
  38 + int32_t VocabSize() const;
  39 +
  40 + /** It is lfr_m in config.yaml
  41 + */
  42 + int32_t LfrWindowSize() const;
  43 +
  44 + /** It is lfr_n in config.yaml
  45 + */
  46 + int32_t LfrWindowShift() const;
  47 +
  48 + /** Return negative mean for CMVN
  49 + */
  50 + const std::vector<float> &NegativeMean() const;
  51 +
  52 + /** Return inverse stddev for CMVN
  53 + */
  54 + const std::vector<float> &InverseStdDev() const;
  55 +
  56 + /** Return an allocator for allocating memory
  57 + */
  58 + OrtAllocator *Allocator() const;
  59 +
  60 + private:
  61 + class Impl;
  62 + std::unique_ptr<Impl> impl_;
  63 +};
  64 +
  65 +} // namespace sherpa_onnx
  66 +
  67 +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
@@ -8,6 +8,7 @@
8 8
9 9 #include "onnxruntime_cxx_api.h" // NOLINT
10 10 #include "sherpa-onnx/csrc/macros.h"
  11 +#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
11 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" 12 #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
12 #include "sherpa-onnx/csrc/onnx-utils.h" 13 #include "sherpa-onnx/csrc/onnx-utils.h"
13 #include "sherpa-onnx/csrc/text-utils.h" 14 #include "sherpa-onnx/csrc/text-utils.h"
@@ -16,10 +17,20 @@ namespace sherpa_onnx {
16 17
17 18 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
18 19 const OfflineRecognizerConfig &config) {
19 - Ort::Env env;
20 + Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
20 21
21 22 Ort::SessionOptions sess_opts;
22 - auto buf = ReadFile(config.model_config.encoder_filename);
23 + std::string model_filename;
  24 + if (!config.model_config.transducer.encoder_filename.empty()) {
  25 + model_filename = config.model_config.transducer.encoder_filename;
  26 + } else if (!config.model_config.paraformer.model.empty()) {
  27 + model_filename = config.model_config.paraformer.model;
  28 + } else {
  29 + SHERPA_ONNX_LOGE("Please provide a model");
  30 + exit(-1);
  31 + }
  32 +
  33 + auto buf = ReadFile(model_filename);
23 34
24 35 auto encoder_sess =
25 36 std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
@@ -35,7 +46,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
35 46 return std::make_unique<OfflineRecognizerTransducerImpl>(config);
36 47 }
37 48
38 - SHERPA_ONNX_LOGE("Unsupported model_type: %s\n", model_type.c_str());
49 + if (model_type == "paraformer") {
  50 + return std::make_unique<OfflineRecognizerParaformerImpl>(config);
  51 + }
  52 +
  53 + SHERPA_ONNX_LOGE(
  54 + "\nUnsupported model_type: %s\n"
  55 + "We support only the following model types at present: \n"
  56 + " - transducer models from icefall\n"
  57 + " - Paraformer models from FunASR\n",
  58 + model_type.c_str());
39 59
40 60 exit(-1);
41 61 }
  1 +// sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
  2 +//
  3 +// Copyright (c) 2022-2023 Xiaomi Corporation
  4 +
  5 +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_
  6 +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_
  7 +
  8 +#include <algorithm>
  9 +#include <memory>
  10 +#include <string>
  11 +#include <utility>
  12 +#include <vector>
  13 +
  14 +#include "sherpa-onnx/csrc/offline-model-config.h"
  15 +#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
  16 +#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
  17 +#include "sherpa-onnx/csrc/offline-paraformer-model.h"
  18 +#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
  19 +#include "sherpa-onnx/csrc/offline-recognizer.h"
  20 +#include "sherpa-onnx/csrc/pad-sequence.h"
  21 +#include "sherpa-onnx/csrc/symbol-table.h"
  22 +
  23 +namespace sherpa_onnx {
  24 +
  25 +static OfflineRecognitionResult Convert(
  26 + const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
  27 + OfflineRecognitionResult r;
  28 + r.tokens.reserve(src.tokens.size());
  29 +
  30 + std::string text;
  31 + for (auto i : src.tokens) {
  32 + auto sym = sym_table[i];
  33 + text.append(sym);
  34 +
  35 + r.tokens.push_back(std::move(sym));
  36 + }
  37 + r.text = std::move(text);
  38 +
  39 + return r;
  40 +}
  41 +
  42 +class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
  43 + public:
  44 + explicit OfflineRecognizerParaformerImpl(
  45 + const OfflineRecognizerConfig &config)
  46 + : config_(config),
  47 + symbol_table_(config_.model_config.tokens),
  48 + model_(std::make_unique<OfflineParaformerModel>(config.model_config)) {
  49 + if (config.decoding_method == "greedy_search") {
  50 + int32_t eos_id = symbol_table_["</s>"];
  51 + decoder_ = std::make_unique<OfflineParaformerGreedySearchDecoder>(eos_id);
  52 + } else {
  53 + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
  54 + config.decoding_method.c_str());
  55 + exit(-1);
  56 + }
  57 +
  58 + // Paraformer models assume input samples are in the range
  59 + // [-32768, 32767], so we set normalize_samples to false
  60 + config_.feat_config.normalize_samples = false;
  61 + }
  62 +
  63 + std::unique_ptr<OfflineStream> CreateStream() const override {
  64 + return std::make_unique<OfflineStream>(config_.feat_config);
  65 + }
  66 +
  67 + void DecodeStreams(OfflineStream **ss, int32_t n) const override {
  68 + // 1. Apply LFR
  69 + // 2. Apply CMVN
  70 + //
  71 + // Please refer to
  72 + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf
  73 + // for what LFR means
  74 + //
  75 + // "Lower Frame Rate Neural Network Acoustic Models"
  76 + auto memory_info =
  77 + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
  78 +
  79 + std::vector<Ort::Value> features;
  80 + features.reserve(n);
  81 +
  82 + int32_t feat_dim =
  83 + config_.feat_config.feature_dim * model_->LfrWindowSize();
  84 +
  85 + std::vector<std::vector<float>> features_vec(n);
  86 + std::vector<int32_t> features_length_vec(n);
  87 + for (int32_t i = 0; i != n; ++i) {
  88 + std::vector<float> f = ss[i]->GetFrames();
  89 +
  90 + f = ApplyLFR(f);
  91 + ApplyCMVN(&f);
  92 +
  93 + int32_t num_frames = f.size() / feat_dim;
  94 + features_vec[i] = std::move(f);
  95 +
  96 + features_length_vec[i] = num_frames;
  97 +
  98 + std::array<int64_t, 2> shape = {num_frames, feat_dim};
  99 +
  100 + Ort::Value x = Ort::Value::CreateTensor(
  101 + memory_info, features_vec[i].data(), features_vec[i].size(),
  102 + shape.data(), shape.size());
  103 + features.push_back(std::move(x));
  104 + }
  105 +
  106 + std::vector<const Ort::Value *> features_pointer(n);
  107 + for (int32_t i = 0; i != n; ++i) {
  108 + features_pointer[i] = &features[i];
  109 + }
  110 +
  111 + std::array<int64_t, 1> features_length_shape = {n};
  112 + Ort::Value x_length = Ort::Value::CreateTensor(
  113 + memory_info, features_length_vec.data(), n,
  114 + features_length_shape.data(), features_length_shape.size());
  115 +
  116 + // Caution(fangjun): We cannot pad it with log(eps),
  117 + // i.e., -23.025850929940457f
  118 + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
  119 +
  120 + auto t = model_->Forward(std::move(x), std::move(x_length));
  121 +
  122 + auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
  123 +
  124 + for (int32_t i = 0; i != n; ++i) {
  125 + auto r = Convert(results[i], symbol_table_);
  126 + ss[i]->SetResult(r);
  127 + }
  128 + }
  129 +
  130 + private:
  131 + std::vector<float> ApplyLFR(const std::vector<float> &in) const {
  132 + int32_t lfr_window_size = model_->LfrWindowSize();
  133 + int32_t lfr_window_shift = model_->LfrWindowShift();
  134 + int32_t in_feat_dim = config_.feat_config.feature_dim;
  135 +
  136 + int32_t in_num_frames = in.size() / in_feat_dim;
  137 + int32_t out_num_frames =
  138 + (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
  139 + int32_t out_feat_dim = in_feat_dim * lfr_window_size;
  140 +
  141 + std::vector<float> out(out_num_frames * out_feat_dim);
  142 +
  143 + const float *p_in = in.data();
  144 + float *p_out = out.data();
  145 +
  146 + for (int32_t i = 0; i != out_num_frames; ++i) {
  147 + std::copy(p_in, p_in + out_feat_dim, p_out);
  148 +
  149 + p_out += out_feat_dim;
  150 + p_in += lfr_window_shift * in_feat_dim;
  151 + }
  152 +
  153 + return out;
  154 + }
  155 +
  156 + void ApplyCMVN(std::vector<float> *v) const {
  157 + const std::vector<float> &neg_mean = model_->NegativeMean();
  158 + const std::vector<float> &inv_stddev = model_->InverseStdDev();
  159 +
  160 + int32_t dim = neg_mean.size();
  161 + int32_t num_frames = v->size() / dim;
  162 +
  163 + float *p = v->data();
  164 +
  165 + for (int32_t i = 0; i != num_frames; ++i) {
  166 + for (int32_t k = 0; k != dim; ++k) {
  167 + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
  168 + }
  169 +
  170 + p += dim;
  171 + }
  172 + }
  173 +
  174 + OfflineRecognizerConfig config_;
  175 + SymbolTable symbol_table_;
  176 + std::unique_ptr<OfflineParaformerModel> model_;
  177 + std::unique_ptr<OfflineParaformerDecoder> decoder_;
  178 +};
  179 +
  180 +} // namespace sherpa_onnx
  181 +
  182 +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_
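DecodeStreams above first splices frames with LFR (low frame rate, see the linked paper) and then applies CMVN using the negative mean and inverse standard deviation read from the model's metadata. The following is a standalone sketch of those two transforms over plain vectors with toy dimensions; it is not part of this PR, and the real feature_dim is 80 with the LFR parameters coming from the model.

// Standalone sketch of LFR splicing followed by CMVN, with toy values.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t feat_dim = 4;          // toy; the real feature_dim is 80
  const int32_t lfr_window_size = 3;   // lfr_m in config.yaml
  const int32_t lfr_window_shift = 2;  // lfr_n in config.yaml

  // 7 input frames of dimension feat_dim, filled with a ramp for illustration
  const int32_t in_num_frames = 7;
  std::vector<float> in(in_num_frames * feat_dim);
  for (size_t i = 0; i != in.size(); ++i) in[i] = static_cast<float>(i);

  // LFR: concatenate lfr_window_size consecutive frames, then hop forward by
  // lfr_window_shift frames
  int32_t out_num_frames =
      (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
  int32_t out_feat_dim = feat_dim * lfr_window_size;
  std::vector<float> out(out_num_frames * out_feat_dim);
  for (int32_t i = 0; i != out_num_frames; ++i) {
    std::copy(in.begin() + i * lfr_window_shift * feat_dim,
              in.begin() + i * lfr_window_shift * feat_dim + out_feat_dim,
              out.begin() + i * out_feat_dim);
  }

  // CMVN: the model stores a negative mean and an inverse stddev per output
  // dimension, so normalization is (x + neg_mean) * inv_stddev
  std::vector<float> neg_mean(out_feat_dim, -5.0f);   // toy values
  std::vector<float> inv_stddev(out_feat_dim, 0.5f);  // toy values
  for (int32_t i = 0; i != out_num_frames; ++i) {
    for (int32_t k = 0; k != out_feat_dim; ++k) {
      out[i * out_feat_dim + k] =
          (out[i * out_feat_dim + k] + neg_mean[k]) * inv_stddev[k];
    }
  }

  std::printf("LFR: %d frames x %d dims -> %d frames x %d dims\n",
              in_num_frames, feat_dim, out_num_frames, out_feat_dim);
  return 0;
}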
1 1 // sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
2 2 //
3 -// Copyright (c) 2022 Xiaomi Corporation
3 +// Copyright (c) 2022-2023 Xiaomi Corporation
4 4
5 5 #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
6 6 #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
@@ -6,6 +6,8 @@
6 6
7 7 #include <memory>
8 8
  9 +#include "sherpa-onnx/csrc/file-utils.h"
  10 +#include "sherpa-onnx/csrc/macros.h"
9 #include "sherpa-onnx/csrc/offline-recognizer-impl.h" 11 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
10 12
11 namespace sherpa_onnx { 13 namespace sherpa_onnx {
@@ -9,6 +9,7 @@
9 9 #include <string>
10 10 #include <vector>
11 11
  12 +#include "sherpa-onnx/csrc/offline-model-config.h"
12 #include "sherpa-onnx/csrc/offline-stream.h" 13 #include "sherpa-onnx/csrc/offline-stream.h"
13 #include "sherpa-onnx/csrc/offline-transducer-model-config.h" 14 #include "sherpa-onnx/csrc/offline-transducer-model-config.h"
14 #include "sherpa-onnx/csrc/parse-options.h" 15 #include "sherpa-onnx/csrc/parse-options.h"
@@ -32,7 +33,7 @@ struct OfflineRecognitionResult {
32 33
33 34 struct OfflineRecognizerConfig {
34 35 OfflineFeatureExtractorConfig feat_config;
35 - OfflineTransducerModelConfig model_config;
36 + OfflineModelConfig model_config;
36 37
37 38 std::string decoding_method = "greedy_search";
38 39 // only greedy_search is implemented
@@ -40,7 +41,7 @@ struct OfflineRecognizerConfig {
40 41
41 42 OfflineRecognizerConfig() = default;
42 43 OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
43 - const OfflineTransducerModelConfig &model_config,
44 + const OfflineModelConfig &model_config,
44 45 const std::string &decoding_method)
45 46 : feat_config(feat_config),
46 47 model_config(model_config),
@@ -38,7 +38,7 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
38 38
39 39 class OfflineStream::Impl {
40 40 public:
41 - explicit Impl(const OfflineFeatureExtractorConfig &config) {
41 + explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
42 42 opts_.frame_opts.dither = 0;
43 43 opts_.frame_opts.snip_edges = false;
44 44 opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -48,6 +48,19 @@ class OfflineStream::Impl {
48 48 }
49 49
50 50 void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
  51 + if (config_.normalize_samples) {
  52 + AcceptWaveformImpl(sampling_rate, waveform, n);
  53 + } else {
  54 + std::vector<float> buf(n);
  55 + for (int32_t i = 0; i != n; ++i) {
  56 + buf[i] = waveform[i] * 32768;
  57 + }
  58 + AcceptWaveformImpl(sampling_rate, buf.data(), n);
  59 + }
  60 + }
  61 +
  62 + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
  63 + int32_t n) {
51 64 if (sampling_rate != opts_.frame_opts.samp_freq) {
52 65 SHERPA_ONNX_LOGE(
53 66 "Creating a resampler:\n"
@@ -101,6 +114,7 @@ class OfflineStream::Impl {
101 114 const OfflineRecognitionResult &GetResult() const { return r_; }
102 115
103 116 private:
  117 + OfflineFeatureExtractorConfig config_;
104 118 std::unique_ptr<knf::OnlineFbank> fbank_;
105 119 knf::FbankOptions opts_;
106 120 OfflineRecognitionResult r_;
@@ -23,6 +23,13 @@ struct OfflineFeatureExtractorConfig {
23 23 // Feature dimension
24 24 int32_t feature_dim = 80;
25 25
  26 + // Set internally by some models, e.g., paraformer
  27 + // This parameter is not exposed to users from the commandline
  28 + // If true, the feature extractor expects inputs to be normalized to
  29 + // the range [-1, 1].
  30 + // If false, we will multiply the inputs by 32768
  31 + bool normalize_samples = true;
  32 +
26 33 std::string ToString() const;
27 34
28 35 void Register(ParseOptions *po);
@@ -14,20 +14,9 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
14 14 po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
15 15 po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
16 16 po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
17 - po->Register("tokens", &tokens, "Path to tokens.txt");  
18 - po->Register("num_threads", &num_threads,  
19 - "Number of threads to run the neural network");  
20 -  
21 - po->Register("debug", &debug,  
22 - "true to print model information while loading it.");  
23 17 }
24 18
25 19 bool OfflineTransducerModelConfig::Validate() const {
26 - if (!FileExists(tokens)) {  
27 - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());  
28 - return false;  
29 - }  
30 -  
31 20 if (!FileExists(encoder_filename)) {
32 21 SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
33 22 return false;
@@ -43,11 +32,6 @@ bool OfflineTransducerModelConfig::Validate() const {
43 32 return false;
44 33 }
45 34
46 - if (num_threads < 1) {  
47 - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);  
48 - return false;  
49 - }  
50 -  
51 35 return true;
52 36 }
53 37
@@ -57,10 +41,7 @@ std::string OfflineTransducerModelConfig::ToString() const {
57 41 os << "OfflineTransducerModelConfig(";
58 42 os << "encoder_filename=\"" << encoder_filename << "\", ";
59 43 os << "decoder_filename=\"" << decoder_filename << "\", ";
60 - os << "joiner_filename=\"" << joiner_filename << "\", ";  
61 - os << "tokens=\"" << tokens << "\", ";  
62 - os << "num_threads=" << num_threads << ", ";  
63 - os << "debug=" << (debug ? "True" : "False") << ")"; 44 + os << "joiner_filename=\"" << joiner_filename << "\")";
64 45
65 return os.str(); 46 return os.str();
66 } 47 }
@@ -14,22 +14,14 @@ struct OfflineTransducerModelConfig {
14 14 std::string encoder_filename;
15 15 std::string decoder_filename;
16 16 std::string joiner_filename;
17 - std::string tokens;  
18 - int32_t num_threads = 2;  
19 - bool debug = false;  
20 17
21 18 OfflineTransducerModelConfig() = default;
22 19 OfflineTransducerModelConfig(const std::string &encoder_filename,
23 20 const std::string &decoder_filename,
24 - const std::string &joiner_filename,
25 - const std::string &tokens, int32_t num_threads,
26 - bool debug)
21 + const std::string &joiner_filename)
27 22 : encoder_filename(encoder_filename),
28 23 decoder_filename(decoder_filename),
29 - joiner_filename(joiner_filename),  
30 - tokens(tokens),  
31 - num_threads(num_threads),  
32 - debug(debug) {}
24 + joiner_filename(joiner_filename) {}
33 25
34 26 void Register(ParseOptions *po);
35 27 bool Validate() const;
@@ -16,7 +16,7 @@ namespace sherpa_onnx {
16 16
17 17 class OfflineTransducerModel::Impl {
18 18 public:
19 - explicit Impl(const OfflineTransducerModelConfig &config)
19 + explicit Impl(const OfflineModelConfig &config)
20 20 : config_(config),
21 21 env_(ORT_LOGGING_LEVEL_WARNING),
22 22 sess_opts_{},
@@ -24,17 +24,17 @@ class OfflineTransducerModel::Impl {
24 24 sess_opts_.SetIntraOpNumThreads(config.num_threads);
25 25 sess_opts_.SetInterOpNumThreads(config.num_threads);
26 26 {
27 - auto buf = ReadFile(config.encoder_filename);
27 + auto buf = ReadFile(config.transducer.encoder_filename);
28 28 InitEncoder(buf.data(), buf.size());
29 29 }
30 30
31 31 {
32 - auto buf = ReadFile(config.decoder_filename);
32 + auto buf = ReadFile(config.transducer.decoder_filename);
33 33 InitDecoder(buf.data(), buf.size());
34 34 }
35 35
36 36 {
37 - auto buf = ReadFile(config.joiner_filename);
37 + auto buf = ReadFile(config.transducer.joiner_filename);
38 38 InitJoiner(buf.data(), buf.size());
39 39 }
40 40 }
@@ -164,7 +164,7 @@ class OfflineTransducerModel::Impl {
164 164 }
165 165
166 166 private:
167 - OfflineTransducerModelConfig config_;
167 + OfflineModelConfig config_;
168 168 Ort::Env env_;
169 169 Ort::SessionOptions sess_opts_;
170 170 Ort::AllocatorWithDefaultOptions allocator_;
@@ -195,8 +195,7 @@ class OfflineTransducerModel::Impl {
195 195 int32_t context_size_ = 0; // initialized in InitDecoder
196 196 };
197 197
198 -OfflineTransducerModel::OfflineTransducerModel(  
199 - const OfflineTransducerModelConfig &config)
198 +OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config)
200 199 : impl_(std::make_unique<Impl>(config)) {}
201 200
202 201 OfflineTransducerModel::~OfflineTransducerModel() = default;
@@ -9,7 +9,7 @@
9 9 #include <vector>
10 10
11 11 #include "onnxruntime_cxx_api.h" // NOLINT
12 -#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
12 +#include "sherpa-onnx/csrc/offline-model-config.h"
13 13
14 14 namespace sherpa_onnx {
15 15
@@ -17,7 +17,7 @@ struct OfflineTransducerDecoderResult;
17 17
18 18 class OfflineTransducerModel {
19 19 public:
20 - explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config);
20 + explicit OfflineTransducerModel(const OfflineModelConfig &config);
21 21 ~OfflineTransducerModel();
22 22
23 23 /** Run the encoder.
@@ -25,6 +25,7 @@ class OfflineTransducerModel {
25 25 * @param features A tensor of shape (N, T, C). It is changed in-place.
26 26 * @param features_length A 1-D tensor of shape (N,) containing number of
27 27 * valid frames in `features` before padding.
  28 + * Its dtype is int64_t.
28 29 *
29 30 * @return Return a pair containing:
30 31 * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
@@ -5,6 +5,7 @@
5 5
6 6 #include <algorithm>
7 7 #include <fstream>
  8 +#include <sstream>
8 9 #include <string>
9 10 #include <vector>
10 11
@@ -133,19 +134,24 @@ void Print1D(Ort::Value *v) {
133 134 fprintf(stderr, "\n");
134 135 }
135 136
  137 +template <typename T /*= float*/>
136 138 void Print2D(Ort::Value *v) {
137 139 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
138 - const float *d = v->GetTensorData<float>();
140 + const T *d = v->GetTensorData<T>();
139 141
142 + std::ostringstream os;
140 143 for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) {
141 144 for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) {
142 - fprintf(stderr, "%.3f ", *d);
145 + os << *d << " ";
143 146 }
144 - fprintf(stderr, "\n");
147 + os << "\n";
145 148 }
146 - fprintf(stderr, "\n");
149 + fprintf(stderr, "%s\n", os.str().c_str());
147 150 }
148 151
  152 +template void Print2D<int64_t>(Ort::Value *v);
  153 +template void Print2D<float>(Ort::Value *v);
  154 +
149 155 void Print3D(Ort::Value *v) {
150 156 std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
151 157 const float *d = v->GetTensorData<float>();
@@ -24,18 +24,6 @@
24 24
25 25 namespace sherpa_onnx {
26 26
27 -#ifdef _MSC_VER  
28 -// See  
29 -// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t  
30 -static std::wstring ToWide(const std::string &s) {  
31 - std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;  
32 - return converter.from_bytes(s);  
33 -}  
34 -#define SHERPA_MAYBE_WIDE(s) ToWide(s)  
35 -#else  
36 -#define SHERPA_MAYBE_WIDE(s) s  
37 -#endif  
38 -  
39 27 /**
40 28 * Get the input names of a model.
41 29 *
@@ -79,6 +67,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
79 67 void Print1D(Ort::Value *v);
80 68
81 69 // Print a 2-D tensor to stderr
  70 +template <typename T = float>
82 71 void Print2D(Ort::Value *v);
83 72
84 73 // Print a 3-D tensor to stderr
@@ -9,24 +9,35 @@
9 9 #include <vector>
10 10
11 11 #include "sherpa-onnx/csrc/offline-recognizer.h"
12 -#include "sherpa-onnx/csrc/offline-stream.h"  
13 -#include "sherpa-onnx/csrc/offline-transducer-decoder.h"  
14 -#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"  
15 -#include "sherpa-onnx/csrc/offline-transducer-model.h"  
16 -#include "sherpa-onnx/csrc/pad-sequence.h"  
17 -#include "sherpa-onnx/csrc/symbol-table.h" 12 +#include "sherpa-onnx/csrc/parse-options.h"
18 #include "sherpa-onnx/csrc/wave-reader.h" 13 #include "sherpa-onnx/csrc/wave-reader.h"
19 14
20 int main(int32_t argc, char *argv[]) { 15 int main(int32_t argc, char *argv[]) {
21 - if (argc < 6 || argc > 8) {  
22 - const char *usage = R"usage( 16 + const char *kUsageMessage = R"usage(
23 Usage: 17 Usage:
  18 +
  19 +(1) Transducer from icefall
  20 +
  21 + ./bin/sherpa-onnx-offline \
  22 + --tokens=/path/to/tokens.txt \
  23 + --encoder=/path/to/encoder.onnx \
  24 + --decoder=/path/to/decoder.onnx \
  25 + --joiner=/path/to/joiner.onnx \
  26 + --num-threads=2 \
  27 + --decoding-method=greedy_search \
  28 + /path/to/foo.wav [bar.wav foobar.wav ...]
  29 +
  30 +
  31 +(2) Paraformer from FunASR
  32 +
24 33 ./bin/sherpa-onnx-offline \
25 - /path/to/tokens.txt \  
26 - /path/to/encoder.onnx \  
27 - /path/to/decoder.onnx \  
28 - /path/to/joiner.onnx \  
29 - /path/to/foo.wav [num_threads [decoding_method]]
34 + --tokens=/path/to/tokens.txt \
  35 + --paraformer=/path/to/model.onnx \
  36 + --num-threads=2 \
  37 + --decoding-method=greedy_search \
  38 + /path/to/foo.wav [bar.wav foobar.wav ...]
  39 +
  40 +Note: It supports decoding multiple files in batches
30 41
31 42 Default value for num_threads is 2.
32 43 Valid values for decoding_method: greedy_search.
@@ -37,29 +48,15 @@ Please refer to
37 48 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
38 49 for a list of pre-trained models to download.
39 50 )usage";
40 - fprintf(stderr, "%s\n", usage);  
41 -  
42 - return 0;  
43 - }  
44 51
  52 + sherpa_onnx::ParseOptions po(kUsageMessage);
45 53 sherpa_onnx::OfflineRecognizerConfig config;
  54 + config.Register(&po);
46 55
47 - config.model_config.tokens = argv[1];  
48 -  
49 - config.model_config.debug = false;  
50 - config.model_config.encoder_filename = argv[2];  
51 - config.model_config.decoder_filename = argv[3];  
52 - config.model_config.joiner_filename = argv[4];  
53 -  
54 - std::string wav_filename = argv[5];  
55 -  
56 - config.model_config.num_threads = 2;  
57 - if (argc == 7 && atoi(argv[6]) > 0) {  
58 - config.model_config.num_threads = atoi(argv[6]);  
59 - }  
60 -  
61 - if (argc == 8) {  
62 - config.decoding_method = argv[7];
56 + po.Read(argc, argv);
  57 + if (po.NumArgs() < 1) {
  58 + po.PrintUsage();
  59 + exit(EXIT_FAILURE);
63 60 }
64 61
65 62 fprintf(stderr, "%s\n", config.ToString().c_str());
@@ -69,8 +66,17 @@ for a list of pre-trained models to download.
69 66 return -1;
70 67 }
71 68
72 - int32_t sampling_rate = -1;
69 + sherpa_onnx::OfflineRecognizer recognizer(config);
  70 +
  71 + auto begin = std::chrono::steady_clock::now();
  72 + fprintf(stderr, "Started\n");
73 73
  74 + std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
  75 + std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
  76 + float duration = 0;
  77 + for (int32_t i = 1; i <= po.NumArgs(); ++i) {
  78 + std::string wav_filename = po.GetArg(i);
  79 + int32_t sampling_rate = -1;
74 80 bool is_ok = false;
75 81 std::vector<float> samples =
76 82 sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
@@ -78,26 +84,25 @@ for a list of pre-trained models to download.
78 84 fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
79 85 return -1;
80 86 }
81 - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);  
82 -  
83 - float duration = samples.size() / static_cast<float>(sampling_rate);
87 + duration += samples.size() / static_cast<float>(sampling_rate);
84 88
85 - sherpa_onnx::OfflineRecognizer recognizer(config);  
86 89 auto s = recognizer.CreateStream();
87 -  
88 - auto begin = std::chrono::steady_clock::now();  
89 - fprintf(stderr, "Started\n");  
90 -  
91 90 s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
92 91
93 - recognizer.DecodeStream(s.get());  
94 -  
95 - fprintf(stderr, "Done!\n"); 92 + ss.push_back(std::move(s));
  93 + ss_pointers.push_back(ss.back().get());
  94 + }
96 95
97 - fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),  
98 - s->GetResult().text.c_str());
96 + recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size());
99 97
100 98 auto end = std::chrono::steady_clock::now();
  99 +
  100 + fprintf(stderr, "Done!\n\n");
  101 + for (int32_t i = 1; i <= po.NumArgs(); ++i) {
  102 + fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(),
  103 + ss[i - 1]->GetResult().text.c_str());
  104 + }
  105 +
101 106 float elapsed_seconds =
102 107 std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
103 108 .count() /