Showing 29 changed files with 859 additions and 139 deletions
| @@ -25,36 +25,59 @@ log "Download pretrained model and test-data from $repo_url" | @@ -25,36 +25,59 @@ log "Download pretrained model and test-data from $repo_url" | ||
| 25 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | 25 | GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url |
| 26 | pushd $repo | 26 | pushd $repo |
| 27 | git lfs pull --include "*.onnx" | 27 | git lfs pull --include "*.onnx" |
| 28 | -cd test_wavs | ||
| 29 | popd | 28 | popd |
| 30 | 29 | ||
| 31 | -waves=( | ||
| 32 | -$repo/test_wavs/0.wav | ||
| 33 | -$repo/test_wavs/1.wav | ||
| 34 | -$repo/test_wavs/2.wav | ||
| 35 | -) | ||
| 36 | - | ||
| 37 | -for wave in ${waves[@]}; do | ||
| 38 | - time $EXE \ | ||
| 39 | - $repo/tokens.txt \ | ||
| 40 | - $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 41 | - $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 42 | - $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 43 | - $wave \ | ||
| 44 | - 2 | ||
| 45 | -done | 30 | +time $EXE \ |
| 31 | + --tokens=$repo/tokens.txt \ | ||
| 32 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 33 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 34 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 35 | + --num-threads=2 \ | ||
| 36 | + $repo/test_wavs/0.wav \ | ||
| 37 | + $repo/test_wavs/1.wav \ | ||
| 38 | + $repo/test_wavs/2.wav | ||
| 46 | 39 | ||
| 47 | 40 | ||
| 48 | if command -v sox &> /dev/null; then | 41 | if command -v sox &> /dev/null; then |
| 49 | echo "test 8kHz" | 42 | echo "test 8kHz" |
| 50 | sox $repo/test_wavs/0.wav -r 8000 8k.wav | 43 | sox $repo/test_wavs/0.wav -r 8000 8k.wav |
| 44 | + | ||
| 51 | time $EXE \ | 45 | time $EXE \ |
| 52 | - $repo/tokens.txt \ | ||
| 53 | - $repo/encoder-epoch-99-avg-1.onnx \ | ||
| 54 | - $repo/decoder-epoch-99-avg-1.onnx \ | ||
| 55 | - $repo/joiner-epoch-99-avg-1.onnx \ | ||
| 56 | - 8k.wav \ | ||
| 57 | - 2 | 46 | + --tokens=$repo/tokens.txt \ |
| 47 | + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ | ||
| 48 | + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ | ||
| 49 | + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ | ||
| 50 | + --num-threads=2 \ | ||
| 51 | + $repo/test_wavs/0.wav \ | ||
| 52 | + $repo/test_wavs/1.wav \ | ||
| 53 | + $repo/test_wavs/2.wav \ | ||
| 54 | + 8k.wav | ||
| 58 | fi | 55 | fi |
| 59 | 56 | ||
| 60 | rm -rf $repo | 57 | rm -rf $repo |
| 58 | + | ||
| 59 | +log "------------------------------------------------------------" | ||
| 60 | +log "Run Paraformer (Chinese)" | ||
| 61 | +log "------------------------------------------------------------" | ||
| 62 | + | ||
| 63 | +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 | ||
| 64 | +log "Start testing ${repo_url}" | ||
| 65 | +repo=$(basename $repo_url) | ||
| 66 | +log "Download pretrained model and test-data from $repo_url" | ||
| 67 | + | ||
| 68 | +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
| 69 | +pushd $repo | ||
| 70 | +git lfs pull --include "*.onnx" | ||
| 71 | +popd | ||
| 72 | + | ||
| 73 | +time $EXE \ | ||
| 74 | + --tokens=$repo/tokens.txt \ | ||
| 75 | + --paraformer=$repo/model.onnx \ | ||
| 76 | + --num-threads=2 \ | ||
| 77 | + --decoding-method=greedy_search \ | ||
| 78 | + $repo/test_wavs/0.wav \ | ||
| 79 | + $repo/test_wavs/1.wav \ | ||
| 80 | + $repo/test_wavs/2.wav \ | ||
| 81 | + $repo/test_wavs/8k.wav | ||
| 82 | + | ||
| 83 | +rm -rf $repo |
| @@ -71,7 +71,15 @@ jobs: | @@ -71,7 +71,15 @@ jobs: | ||
| 71 | 71 | ||
| 72 | ls -lh ./bin/Release/sherpa-onnx.exe | 72 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 73 | 73 | ||
| 74 | - - name: Test sherpa-onnx for Windows x64 | 74 | + - name: Test offline transducer for Windows x64 |
| 75 | + shell: bash | ||
| 76 | + run: | | ||
| 77 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 78 | + export EXE=sherpa-onnx-offline.exe | ||
| 79 | + | ||
| 80 | + .github/scripts/test-offline-transducer.sh | ||
| 81 | + | ||
| 82 | + - name: Test online transducer for Windows x64 | ||
| 75 | shell: bash | 83 | shell: bash |
| 76 | run: | | 84 | run: | |
| 77 | export PATH=$PWD/build/bin/Release:$PATH | 85 | export PATH=$PWD/build/bin/Release:$PATH |
| @@ -71,7 +71,15 @@ jobs: | @@ -71,7 +71,15 @@ jobs: | ||
| 71 | 71 | ||
| 72 | ls -lh ./bin/Release/sherpa-onnx.exe | 72 | ls -lh ./bin/Release/sherpa-onnx.exe |
| 73 | 73 | ||
| 74 | - - name: Test sherpa-onnx for Windows x86 | 74 | + - name: Test offline transducer for Windows x86 |
| 75 | + shell: bash | ||
| 76 | + run: | | ||
| 77 | + export PATH=$PWD/build/bin/Release:$PATH | ||
| 78 | + export EXE=sherpa-onnx-offline.exe | ||
| 79 | + | ||
| 80 | + .github/scripts/test-offline-transducer.sh | ||
| 81 | + | ||
| 82 | + - name: Test online transducer for Windows x86 | ||
| 75 | shell: bash | 83 | shell: bash |
| 76 | run: | | 84 | run: | |
| 77 | export PATH=$PWD/build/bin/Release:$PATH | 85 | export PATH=$PWD/build/bin/Release:$PATH |
| @@ -41,3 +41,7 @@ android/SherpaOnnx/app/src/main/assets/ | @@ -41,3 +41,7 @@ android/SherpaOnnx/app/src/main/assets/ | ||
| 41 | *.ncnn.* | 41 | *.ncnn.* |
| 42 | run-sherpa-onnx-offline.sh | 42 | run-sherpa-onnx-offline.sh |
| 43 | sherpa-onnx-conformer-en-2023-03-18 | 43 | sherpa-onnx-conformer-en-2023-03-18 |
| 44 | +paraformer-onnxruntime-python-example | ||
| 45 | +run-sherpa-onnx-offline-paraformer.sh | ||
| 46 | +run-sherpa-onnx-offline-transducer.sh | ||
| 47 | +sherpa-onnx-paraformer-zh-2023-03-28 |
| @@ -6,6 +6,10 @@ set(sources | @@ -6,6 +6,10 @@ set(sources | ||
| 6 | features.cc | 6 | features.cc |
| 7 | file-utils.cc | 7 | file-utils.cc |
| 8 | hypothesis.cc | 8 | hypothesis.cc |
| 9 | + offline-model-config.cc | ||
| 10 | + offline-paraformer-greedy-search-decoder.cc | ||
| 11 | + offline-paraformer-model-config.cc | ||
| 12 | + offline-paraformer-model.cc | ||
| 9 | offline-recognizer-impl.cc | 13 | offline-recognizer-impl.cc |
| 10 | offline-recognizer.cc | 14 | offline-recognizer.cc |
| 11 | offline-stream.cc | 15 | offline-stream.cc |
| @@ -57,6 +57,23 @@ | @@ -57,6 +57,23 @@ | ||
| 57 | } \ | 57 | } \ |
| 58 | } while (0) | 58 | } while (0) |
| 59 | 59 | ||
| 60 | +// Read a vector of floats | ||
| 61 | +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ | ||
| 62 | + do { \ | ||
| 63 | + auto value = \ | ||
| 64 | + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ | ||
| 65 | + if (!value) { \ | ||
| 66 | + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ | ||
| 67 | + exit(-1); \ | ||
| 68 | + } \ | ||
| 69 | + \ | ||
| 70 | + bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ | ||
| 71 | + if (!ret) { \ | ||
| 72 | + SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ | ||
| 73 | + exit(-1); \ | ||
| 74 | + } \ | ||
| 75 | + } while (0) | ||
| 76 | + | ||
| 60 | // Read a string | 77 | // Read a string |
| 61 | #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ | 78 | #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ |
| 62 | do { \ | 79 | do { \ |
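The new SHERPA_ONNX_READ_META_DATA_VEC_FLOAT macro expects the metadata value to be a comma-separated list of floats; that is how the CMVN statistics neg_mean and inv_stddev are stored later in this PR. Below is a minimal standalone sketch of that parsing step. The ParseFloatList helper is hypothetical and only stands in for the library's SplitStringToFloats.

    // Standalone illustration only: turning a comma-separated metadata value
    // (e.g. the string stored under "neg_mean") into a std::vector<float>.
    #include <cstdio>
    #include <sstream>
    #include <string>
    #include <vector>

    static std::vector<float> ParseFloatList(const std::string &s) {
      std::vector<float> ans;
      std::stringstream ss(s);
      std::string field;
      while (std::getline(ss, field, ',')) {
        ans.push_back(std::stof(field));
      }
      return ans;
    }

    int main() {
      // A value like this would sit in the ONNX custom metadata map and be
      // read by the macro above.
      auto v = ParseFloatList("-8.31,-8.20,-8.12");
      for (float f : v) printf("%.2f ", f);
      printf("\n");
      return 0;
    }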
sherpa-onnx/csrc/offline-model-config.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 5 | + | ||
| 6 | +#include <string> | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 9 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void OfflineModelConfig::Register(ParseOptions *po) { | ||
| 14 | + transducer.Register(po); | ||
| 15 | + paraformer.Register(po); | ||
| 16 | + | ||
| 17 | + po->Register("tokens", &tokens, "Path to tokens.txt"); | ||
| 18 | + | ||
| 19 | + po->Register("num-threads", &num_threads, | ||
| 20 | + "Number of threads to run the neural network"); | ||
| 21 | + | ||
| 22 | + po->Register("debug", &debug, | ||
| 23 | + "true to print model information while loading it."); | ||
| 24 | +} | ||
| 25 | + | ||
| 26 | +bool OfflineModelConfig::Validate() const { | ||
| 27 | + if (num_threads < 1) { | ||
| 28 | + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
| 29 | + return false; | ||
| 30 | + } | ||
| 31 | + | ||
| 32 | + if (!FileExists(tokens)) { | ||
| 33 | + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); | ||
| 34 | + return false; | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + if (!paraformer.model.empty()) { | ||
| 38 | + return paraformer.Validate(); | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + return transducer.Validate(); | ||
| 42 | +} | ||
| 43 | + | ||
| 44 | +std::string OfflineModelConfig::ToString() const { | ||
| 45 | + std::ostringstream os; | ||
| 46 | + | ||
| 47 | + os << "OfflineModelConfig("; | ||
| 48 | + os << "transducer=" << transducer.ToString() << ", "; | ||
| 49 | + os << "paraformer=" << paraformer.ToString() << ", "; | ||
| 50 | + os << "tokens=\"" << tokens << "\", "; | ||
| 51 | + os << "num_threads=" << num_threads << ", "; | ||
| 52 | + os << "debug=" << (debug ? "True" : "False") << ")"; | ||
| 53 | + | ||
| 54 | + return os.str(); | ||
| 55 | +} | ||
| 56 | + | ||
| 57 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-model-config.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h" | ||
| 10 | +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineModelConfig { | ||
| 15 | + OfflineTransducerModelConfig transducer; | ||
| 16 | + OfflineParaformerModelConfig paraformer; | ||
| 17 | + | ||
| 18 | + std::string tokens; | ||
| 19 | + int32_t num_threads = 2; | ||
| 20 | + bool debug = false; | ||
| 21 | + | ||
| 22 | + OfflineModelConfig() = default; | ||
| 23 | + OfflineModelConfig(const OfflineTransducerModelConfig &transducer, | ||
| 24 | + const OfflineParaformerModelConfig ¶former, | ||
| 25 | + const std::string &tokens, int32_t num_threads, bool debug) | ||
| 26 | + : transducer(transducer), | ||
| 27 | + paraformer(paraformer), | ||
| 28 | + tokens(tokens), | ||
| 29 | + num_threads(num_threads), | ||
| 30 | + debug(debug) {} | ||
| 31 | + | ||
| 32 | + void Register(ParseOptions *po); | ||
| 33 | + bool Validate() const; | ||
| 34 | + | ||
| 35 | + std::string ToString() const; | ||
| 36 | +}; | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx | ||
| 39 | + | ||
| 40 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ |
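A short sketch of how the unified config added above is meant to be populated. The paths are placeholders, and the snippet only compiles against the headers introduced in this PR.

    // Sketch: choosing Paraformer vs. transducer via OfflineModelConfig.
    // Validate() first checks tokens and num_threads; if paraformer.model is
    // non-empty it calls paraformer.Validate(), otherwise transducer.Validate().
    #include "sherpa-onnx/csrc/offline-model-config.h"

    sherpa_onnx::OfflineModelConfig MakeParaformerConfig() {
      sherpa_onnx::OfflineModelConfig config;
      config.tokens = "./sherpa-onnx-paraformer-zh-2023-03-28/tokens.txt";
      config.paraformer.model = "./sherpa-onnx-paraformer-zh-2023-03-28/model.onnx";
      config.num_threads = 2;
      return config;
    }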
| 1 | +// sherpa-onnx/csrc/offline-paraformer-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OfflineParaformerDecoderResult { | ||
| 15 | + /// The decoded token IDs | ||
| 16 | + std::vector<int64_t> tokens; | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +class OfflineParaformerDecoder { | ||
| 20 | + public: | ||
| 21 | + virtual ~OfflineParaformerDecoder() = default; | ||
| 22 | + | ||
| 23 | + /** Run beam search given the output from the paraformer model. | ||
| 24 | + * | ||
| 25 | + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) | ||
| 26 | + * @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t. | ||
| 27 | + * log_probs[i].argmax(axis=-1) equals token_num[i] | ||
| 28 | + * | ||
| 29 | + * @return Return a vector of size `N` containing the decoded results. | ||
| 30 | + */ | ||
| 31 | + virtual std::vector<OfflineParaformerDecoderResult> Decode( | ||
| 32 | + Ort::Value log_probs, Ort::Value token_num) = 0; | ||
| 33 | +}; | ||
| 34 | + | ||
| 35 | +} // namespace sherpa_onnx | ||
| 36 | + | ||
| 37 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <vector> | ||
| 8 | + | ||
| 9 | +namespace sherpa_onnx { | ||
| 10 | + | ||
| 11 | +std::vector<OfflineParaformerDecoderResult> | ||
| 12 | +OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/, | ||
| 13 | + Ort::Value token_num) { | ||
| 14 | + std::vector<int64_t> shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 15 | + int32_t batch_size = shape[0]; | ||
| 16 | + int32_t num_tokens = shape[1]; | ||
| 17 | + | ||
| 18 | + std::vector<OfflineParaformerDecoderResult> results(batch_size); | ||
| 19 | + | ||
| 20 | + const int64_t *p = token_num.GetTensorData<int64_t>(); | ||
| 21 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 22 | + for (int32_t k = 0; k != num_tokens; ++k) { | ||
| 23 | + if (p[k] == eos_id_) break; | ||
| 24 | + | ||
| 25 | + results[i].tokens.push_back(p[k]); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + p += num_tokens; | ||
| 29 | + } | ||
| 30 | + | ||
| 31 | + return results; | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +} // namespace sherpa_onnx |
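In other words, each row of token_num is read left to right and truncated at the first eos id. A toy example with plain arrays follows; the eos id 2 is made up, while in the recognizer it is looked up as symbol_table_["</s>"].

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const int64_t eos_id = 2;                         // hypothetical id of "</s>"
      std::vector<int64_t> row = {23, 481, 902, 2, 2};  // one padded row of token_num
      std::vector<int64_t> tokens;
      for (int64_t t : row) {
        if (t == eos_id) break;
        tokens.push_back(t);
      }
      for (int64_t t : tokens) printf("%lld ", (long long)t);  // prints: 23 481 902
      printf("\n");
      return 0;
    }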
| 1 | +// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-paraformer-decoder.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { | ||
| 15 | + public: | ||
| 16 | + explicit OfflineParaformerGreedySearchDecoder(int32_t eos_id) | ||
| 17 | + : eos_id_(eos_id) {} | ||
| 18 | + | ||
| 19 | + std::vector<OfflineParaformerDecoderResult> Decode( | ||
| 20 | + Ort::Value /*log_probs*/, Ort::Value token_num) override; | ||
| 21 | + | ||
| 22 | + private: | ||
| 23 | + int32_t eos_id_; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ |
| 1 | +// sherpa-onnx/csrc/offline-paraformer-model-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h" | ||
| 6 | + | ||
| 7 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 8 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void OfflineParaformerModelConfig::Register(ParseOptions *po) { | ||
| 13 | + po->Register("paraformer", &model, "Path to model.onnx of paraformer."); | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +bool OfflineParaformerModelConfig::Validate() const { | ||
| 17 | + if (!FileExists(model)) { | ||
| 18 | + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); | ||
| 19 | + return false; | ||
| 20 | + } | ||
| 21 | + | ||
| 22 | + return true; | ||
| 23 | +} | ||
| 24 | + | ||
| 25 | +std::string OfflineParaformerModelConfig::ToString() const { | ||
| 26 | + std::ostringstream os; | ||
| 27 | + | ||
| 28 | + os << "OfflineParaformerModelConfig("; | ||
| 29 | + os << "model=\"" << model << "\")"; | ||
| 30 | + | ||
| 31 | + return os.str(); | ||
| 32 | +} | ||
| 33 | + | ||
| 34 | +} // namespace sherpa_onnx |
| 1 | +// sherpa-onnx/csrc/offline-paraformer-model-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineParaformerModelConfig { | ||
| 14 | + std::string model; | ||
| 15 | + | ||
| 16 | + OfflineParaformerModelConfig() = default; | ||
| 17 | + explicit OfflineParaformerModelConfig(const std::string &model) | ||
| 18 | + : model(model) {} | ||
| 19 | + | ||
| 20 | + void Register(ParseOptions *po); | ||
| 21 | + bool Validate() const; | ||
| 22 | + | ||
| 23 | + std::string ToString() const; | ||
| 24 | +}; | ||
| 25 | + | ||
| 26 | +} // namespace sherpa_onnx | ||
| 27 | + | ||
| 28 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ |
sherpa-onnx/csrc/offline-paraformer-model.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-paraformer-model.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-paraformer-model.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <string> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 12 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OfflineParaformerModel::Impl { | ||
| 17 | + public: | ||
| 18 | + explicit Impl(const OfflineModelConfig &config) | ||
| 19 | + : config_(config), | ||
| 20 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 21 | + sess_opts_{}, | ||
| 22 | + allocator_{} { | ||
| 23 | + sess_opts_.SetIntraOpNumThreads(config_.num_threads); | ||
| 24 | + sess_opts_.SetInterOpNumThreads(config_.num_threads); | ||
| 25 | + | ||
| 26 | + Init(); | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features, | ||
| 30 | + Ort::Value features_length) { | ||
| 31 | + std::array<Ort::Value, 2> inputs = {std::move(features), | ||
| 32 | + std::move(features_length)}; | ||
| 33 | + | ||
| 34 | + auto out = | ||
| 35 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 36 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 37 | + | ||
| 38 | + return {std::move(out[0]), std::move(out[1])}; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + int32_t VocabSize() const { return vocab_size_; } | ||
| 42 | + | ||
| 43 | + int32_t LfrWindowSize() const { return lfr_window_size_; } | ||
| 44 | + | ||
| 45 | + int32_t LfrWindowShift() const { return lfr_window_shift_; } | ||
| 46 | + | ||
| 47 | + const std::vector<float> &NegativeMean() const { return neg_mean_; } | ||
| 48 | + | ||
| 49 | + const std::vector<float> &InverseStdDev() const { return inv_stddev_; } | ||
| 50 | + | ||
| 51 | + OrtAllocator *Allocator() const { return allocator_; } | ||
| 52 | + | ||
| 53 | + private: | ||
| 54 | + void Init() { | ||
| 55 | + auto buf = ReadFile(config_.paraformer.model); | ||
| 56 | + | ||
| 57 | + sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(), | ||
| 58 | + sess_opts_); | ||
| 59 | + | ||
| 60 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 61 | + | ||
| 62 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 63 | + | ||
| 64 | + // get meta data | ||
| 65 | + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); | ||
| 66 | + if (config_.debug) { | ||
| 67 | + std::ostringstream os; | ||
| 68 | + PrintModelMetadata(os, meta_data); | ||
| 69 | + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below | ||
| 73 | + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); | ||
| 74 | + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size"); | ||
| 75 | + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift"); | ||
| 76 | + | ||
| 77 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean"); | ||
| 78 | + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev"); | ||
| 79 | + } | ||
| 80 | + | ||
| 81 | + private: | ||
| 82 | + OfflineModelConfig config_; | ||
| 83 | + Ort::Env env_; | ||
| 84 | + Ort::SessionOptions sess_opts_; | ||
| 85 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 86 | + | ||
| 87 | + std::unique_ptr<Ort::Session> sess_; | ||
| 88 | + | ||
| 89 | + std::vector<std::string> input_names_; | ||
| 90 | + std::vector<const char *> input_names_ptr_; | ||
| 91 | + | ||
| 92 | + std::vector<std::string> output_names_; | ||
| 93 | + std::vector<const char *> output_names_ptr_; | ||
| 94 | + | ||
| 95 | + std::vector<float> neg_mean_; | ||
| 96 | + std::vector<float> inv_stddev_; | ||
| 97 | + | ||
| 98 | + int32_t vocab_size_ = 0; // initialized in Init | ||
| 99 | + int32_t lfr_window_size_ = 0; | ||
| 100 | + int32_t lfr_window_shift_ = 0; | ||
| 101 | +}; | ||
| 102 | + | ||
| 103 | +OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config) | ||
| 104 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 105 | + | ||
| 106 | +OfflineParaformerModel::~OfflineParaformerModel() = default; | ||
| 107 | + | ||
| 108 | +std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward( | ||
| 109 | + Ort::Value features, Ort::Value features_length) { | ||
| 110 | + return impl_->Forward(std::move(features), std::move(features_length)); | ||
| 111 | +} | ||
| 112 | + | ||
| 113 | +int32_t OfflineParaformerModel::VocabSize() const { return impl_->VocabSize(); } | ||
| 114 | + | ||
| 115 | +int32_t OfflineParaformerModel::LfrWindowSize() const { | ||
| 116 | + return impl_->LfrWindowSize(); | ||
| 117 | +} | ||
| 118 | +int32_t OfflineParaformerModel::LfrWindowShift() const { | ||
| 119 | + return impl_->LfrWindowShift(); | ||
| 120 | +} | ||
| 121 | +const std::vector<float> &OfflineParaformerModel::NegativeMean() const { | ||
| 122 | + return impl_->NegativeMean(); | ||
| 123 | +} | ||
| 124 | +const std::vector<float> &OfflineParaformerModel::InverseStdDev() const { | ||
| 125 | + return impl_->InverseStdDev(); | ||
| 126 | +} | ||
| 127 | + | ||
| 128 | +OrtAllocator *OfflineParaformerModel::Allocator() const { | ||
| 129 | + return impl_->Allocator(); | ||
| 130 | +} | ||
| 131 | + | ||
| 132 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-paraformer-model.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-paraformer-model.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ | ||
| 6 | + | ||
| 7 | +#include <memory> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OfflineParaformerModel { | ||
| 17 | + public: | ||
| 18 | + explicit OfflineParaformerModel(const OfflineModelConfig &config); | ||
| 19 | + ~OfflineParaformerModel(); | ||
| 20 | + | ||
| 21 | + /** Run the forward method of the model. | ||
| 22 | + * | ||
| 23 | + * @param features A tensor of shape (N, T, C). It is changed in-place. | ||
| 24 | + * @param features_length A 1-D tensor of shape (N,) containing number of | ||
| 25 | + * valid frames in `features` before padding. | ||
| 26 | + * Its dtype is int32_t. | ||
| 27 | + * | ||
| 28 | + * @return Return a pair containing: | ||
| 29 | + * - log_probs: A 3-D tensor of shape (N, T', vocab_size) | ||
| 30 | + * - token_num: A 2-D tensor of shape (N, T'). Its dtype is int64_t. | ||
| 31 | + * log_probs[i].argmax(axis=-1) equals token_num[i]. | ||
| 32 | + */ | ||
| 33 | + std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features, | ||
| 34 | + Ort::Value features_length); | ||
| 35 | + | ||
| 36 | + /** Return the vocabulary size of the model | ||
| 37 | + */ | ||
| 38 | + int32_t VocabSize() const; | ||
| 39 | + | ||
| 40 | + /** It is lfr_m in config.yaml | ||
| 41 | + */ | ||
| 42 | + int32_t LfrWindowSize() const; | ||
| 43 | + | ||
| 44 | + /** It is lfr_n in config.yaml | ||
| 45 | + */ | ||
| 46 | + int32_t LfrWindowShift() const; | ||
| 47 | + | ||
| 48 | + /** Return negative mean for CMVN | ||
| 49 | + */ | ||
| 50 | + const std::vector<float> &NegativeMean() const; | ||
| 51 | + | ||
| 52 | + /** Return inverse stddev for CMVN | ||
| 53 | + */ | ||
| 54 | + const std::vector<float> &InverseStdDev() const; | ||
| 55 | + | ||
| 56 | + /** Return an allocator for allocating memory | ||
| 57 | + */ | ||
| 58 | + OrtAllocator *Allocator() const; | ||
| 59 | + | ||
| 60 | + private: | ||
| 61 | + class Impl; | ||
| 62 | + std::unique_ptr<Impl> impl_; | ||
| 63 | +}; | ||
| 64 | + | ||
| 65 | +} // namespace sherpa_onnx | ||
| 66 | + | ||
| 67 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | 8 | ||
| 9 | #include "onnxruntime_cxx_api.h" // NOLINT | 9 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 10 | #include "sherpa-onnx/csrc/macros.h" | 10 | #include "sherpa-onnx/csrc/macros.h" |
| 11 | +#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" | ||
| 11 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" | 12 | #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" |
| 12 | #include "sherpa-onnx/csrc/onnx-utils.h" | 13 | #include "sherpa-onnx/csrc/onnx-utils.h" |
| 13 | #include "sherpa-onnx/csrc/text-utils.h" | 14 | #include "sherpa-onnx/csrc/text-utils.h" |
| @@ -16,10 +17,20 @@ namespace sherpa_onnx { | @@ -16,10 +17,20 @@ namespace sherpa_onnx { | ||
| 16 | 17 | ||
| 17 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | 18 | std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( |
| 18 | const OfflineRecognizerConfig &config) { | 19 | const OfflineRecognizerConfig &config) { |
| 19 | - Ort::Env env; | 20 | + Ort::Env env(ORT_LOGGING_LEVEL_ERROR); |
| 20 | 21 | ||
| 21 | Ort::SessionOptions sess_opts; | 22 | Ort::SessionOptions sess_opts; |
| 22 | - auto buf = ReadFile(config.model_config.encoder_filename); | 23 | + std::string model_filename; |
| 24 | + if (!config.model_config.transducer.encoder_filename.empty()) { | ||
| 25 | + model_filename = config.model_config.transducer.encoder_filename; | ||
| 26 | + } else if (!config.model_config.paraformer.model.empty()) { | ||
| 27 | + model_filename = config.model_config.paraformer.model; | ||
| 28 | + } else { | ||
| 29 | + SHERPA_ONNX_LOGE("Please provide a model"); | ||
| 30 | + exit(-1); | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + auto buf = ReadFile(model_filename); | ||
| 23 | 34 | ||
| 24 | auto encoder_sess = | 35 | auto encoder_sess = |
| 25 | std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts); | 36 | std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts); |
| @@ -35,7 +46,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | @@ -35,7 +46,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create( | ||
| 35 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); | 46 | return std::make_unique<OfflineRecognizerTransducerImpl>(config); |
| 36 | } | 47 | } |
| 37 | 48 | ||
| 38 | - SHERPA_ONNX_LOGE("Unsupported model_type: %s\n", model_type.c_str()); | 49 | + if (model_type == "paraformer") { |
| 50 | + return std::make_unique<OfflineRecognizerParaformerImpl>(config); | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + SHERPA_ONNX_LOGE( | ||
| 54 | + "\nUnsupported model_type: %s\n" | ||
| 55 | + "We support only the following model types at present: \n" | ||
| 56 | + " - transducer models from icefall\n" | ||
| 57 | + " - Paraformer models from FunASR\n", | ||
| 58 | + model_type.c_str()); | ||
| 39 | 59 | ||
| 40 | exit(-1); | 60 | exit(-1); |
| 41 | } | 61 | } |
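The dispatch above keys on a model_type string taken from the model's custom metadata. A hedged sketch of how such a key can be looked up with the onnxruntime C++ API; the helper name is ours, not part of the PR.

    #include <string>

    #include "onnxruntime_cxx_api.h"  // NOLINT

    // Returns the value stored under "model_type" in the model's custom
    // metadata map, or an empty string if the key is absent.
    std::string GetModelType(Ort::Session *sess) {
      Ort::AllocatorWithDefaultOptions allocator;
      Ort::ModelMetadata meta_data = sess->GetModelMetadata();
      auto value =
          meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
      return value ? std::string(value.get()) : std::string();
    }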
| 1 | +// sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ | ||
| 7 | + | ||
| 8 | +#include <algorithm> | ||
| 9 | +#include <memory> | ||
| 10 | +#include <string> | ||
| 11 | +#include <utility> | ||
| 12 | +#include <vector> | ||
| 13 | + | ||
| 14 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 15 | +#include "sherpa-onnx/csrc/offline-paraformer-decoder.h" | ||
| 16 | +#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" | ||
| 17 | +#include "sherpa-onnx/csrc/offline-paraformer-model.h" | ||
| 18 | +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" | ||
| 19 | +#include "sherpa-onnx/csrc/offline-recognizer.h" | ||
| 20 | +#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 21 | +#include "sherpa-onnx/csrc/symbol-table.h" | ||
| 22 | + | ||
| 23 | +namespace sherpa_onnx { | ||
| 24 | + | ||
| 25 | +static OfflineRecognitionResult Convert( | ||
| 26 | + const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) { | ||
| 27 | + OfflineRecognitionResult r; | ||
| 28 | + r.tokens.reserve(src.tokens.size()); | ||
| 29 | + | ||
| 30 | + std::string text; | ||
| 31 | + for (auto i : src.tokens) { | ||
| 32 | + auto sym = sym_table[i]; | ||
| 33 | + text.append(sym); | ||
| 34 | + | ||
| 35 | + r.tokens.push_back(std::move(sym)); | ||
| 36 | + } | ||
| 37 | + r.text = std::move(text); | ||
| 38 | + | ||
| 39 | + return r; | ||
| 40 | +} | ||
| 41 | + | ||
| 42 | +class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { | ||
| 43 | + public: | ||
| 44 | + explicit OfflineRecognizerParaformerImpl( | ||
| 45 | + const OfflineRecognizerConfig &config) | ||
| 46 | + : config_(config), | ||
| 47 | + symbol_table_(config_.model_config.tokens), | ||
| 48 | + model_(std::make_unique<OfflineParaformerModel>(config.model_config)) { | ||
| 49 | + if (config.decoding_method == "greedy_search") { | ||
| 50 | + int32_t eos_id = symbol_table_["</s>"]; | ||
| 51 | + decoder_ = std::make_unique<OfflineParaformerGreedySearchDecoder>(eos_id); | ||
| 52 | + } else { | ||
| 53 | + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", | ||
| 54 | + config.decoding_method.c_str()); | ||
| 55 | + exit(-1); | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | + // Paraformer models assume input samples are in the range | ||
| 59 | + // [-32768, 32767], so we set normalize_samples to false | ||
| 60 | + config_.feat_config.normalize_samples = false; | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + std::unique_ptr<OfflineStream> CreateStream() const override { | ||
| 64 | + return std::make_unique<OfflineStream>(config_.feat_config); | ||
| 65 | + } | ||
| 66 | + | ||
| 67 | + void DecodeStreams(OfflineStream **ss, int32_t n) const override { | ||
| 68 | + // 1. Apply LFR | ||
| 69 | + // 2. Apply CMVN | ||
| 70 | + // | ||
| 71 | + // Please refer to | ||
| 72 | + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf | ||
| 73 | + // for what LFR means | ||
| 74 | + // | ||
| 75 | + // "Lower Frame Rate Neural Network Acoustic Models" | ||
| 76 | + auto memory_info = | ||
| 77 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 78 | + | ||
| 79 | + std::vector<Ort::Value> features; | ||
| 80 | + features.reserve(n); | ||
| 81 | + | ||
| 82 | + int32_t feat_dim = | ||
| 83 | + config_.feat_config.feature_dim * model_->LfrWindowSize(); | ||
| 84 | + | ||
| 85 | + std::vector<std::vector<float>> features_vec(n); | ||
| 86 | + std::vector<int32_t> features_length_vec(n); | ||
| 87 | + for (int32_t i = 0; i != n; ++i) { | ||
| 88 | + std::vector<float> f = ss[i]->GetFrames(); | ||
| 89 | + | ||
| 90 | + f = ApplyLFR(f); | ||
| 91 | + ApplyCMVN(&f); | ||
| 92 | + | ||
| 93 | + int32_t num_frames = f.size() / feat_dim; | ||
| 94 | + features_vec[i] = std::move(f); | ||
| 95 | + | ||
| 96 | + features_length_vec[i] = num_frames; | ||
| 97 | + | ||
| 98 | + std::array<int64_t, 2> shape = {num_frames, feat_dim}; | ||
| 99 | + | ||
| 100 | + Ort::Value x = Ort::Value::CreateTensor( | ||
| 101 | + memory_info, features_vec[i].data(), features_vec[i].size(), | ||
| 102 | + shape.data(), shape.size()); | ||
| 103 | + features.push_back(std::move(x)); | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + std::vector<const Ort::Value *> features_pointer(n); | ||
| 107 | + for (int32_t i = 0; i != n; ++i) { | ||
| 108 | + features_pointer[i] = &features[i]; | ||
| 109 | + } | ||
| 110 | + | ||
| 111 | + std::array<int64_t, 1> features_length_shape = {n}; | ||
| 112 | + Ort::Value x_length = Ort::Value::CreateTensor( | ||
| 113 | + memory_info, features_length_vec.data(), n, | ||
| 114 | + features_length_shape.data(), features_length_shape.size()); | ||
| 115 | + | ||
| 116 | + // Caution(fangjun): We cannot pad it with log(eps), | ||
| 117 | + // i.e., -23.025850929940457f | ||
| 118 | + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); | ||
| 119 | + | ||
| 120 | + auto t = model_->Forward(std::move(x), std::move(x_length)); | ||
| 121 | + | ||
| 122 | + auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); | ||
| 123 | + | ||
| 124 | + for (int32_t i = 0; i != n; ++i) { | ||
| 125 | + auto r = Convert(results[i], symbol_table_); | ||
| 126 | + ss[i]->SetResult(r); | ||
| 127 | + } | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + private: | ||
| 131 | + std::vector<float> ApplyLFR(const std::vector<float> &in) const { | ||
| 132 | + int32_t lfr_window_size = model_->LfrWindowSize(); | ||
| 133 | + int32_t lfr_window_shift = model_->LfrWindowShift(); | ||
| 134 | + int32_t in_feat_dim = config_.feat_config.feature_dim; | ||
| 135 | + | ||
| 136 | + int32_t in_num_frames = in.size() / in_feat_dim; | ||
| 137 | + int32_t out_num_frames = | ||
| 138 | + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; | ||
| 139 | + int32_t out_feat_dim = in_feat_dim * lfr_window_size; | ||
| 140 | + | ||
| 141 | + std::vector<float> out(out_num_frames * out_feat_dim); | ||
| 142 | + | ||
| 143 | + const float *p_in = in.data(); | ||
| 144 | + float *p_out = out.data(); | ||
| 145 | + | ||
| 146 | + for (int32_t i = 0; i != out_num_frames; ++i) { | ||
| 147 | + std::copy(p_in, p_in + out_feat_dim, p_out); | ||
| 148 | + | ||
| 149 | + p_out += out_feat_dim; | ||
| 150 | + p_in += lfr_window_shift * in_feat_dim; | ||
| 151 | + } | ||
| 152 | + | ||
| 153 | + return out; | ||
| 154 | + } | ||
| 155 | + | ||
| 156 | + void ApplyCMVN(std::vector<float> *v) const { | ||
| 157 | + const std::vector<float> &neg_mean = model_->NegativeMean(); | ||
| 158 | + const std::vector<float> &inv_stddev = model_->InverseStdDev(); | ||
| 159 | + | ||
| 160 | + int32_t dim = neg_mean.size(); | ||
| 161 | + int32_t num_frames = v->size() / dim; | ||
| 162 | + | ||
| 163 | + float *p = v->data(); | ||
| 164 | + | ||
| 165 | + for (int32_t i = 0; i != num_frames; ++i) { | ||
| 166 | + for (int32_t k = 0; k != dim; ++k) { | ||
| 167 | + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; | ||
| 168 | + } | ||
| 169 | + | ||
| 170 | + p += dim; | ||
| 171 | + } | ||
| 172 | + } | ||
| 173 | + | ||
| 174 | + OfflineRecognizerConfig config_; | ||
| 175 | + SymbolTable symbol_table_; | ||
| 176 | + std::unique_ptr<OfflineParaformerModel> model_; | ||
| 177 | + std::unique_ptr<OfflineParaformerDecoder> decoder_; | ||
| 178 | +}; | ||
| 179 | + | ||
| 180 | +} // namespace sherpa_onnx | ||
| 181 | + | ||
| 182 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ |
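For the shape arithmetic in ApplyLFR, a small worked example. The values lfr_window_size = 7 and lfr_window_shift = 6 are common for Paraformer models but are used here purely for illustration; the real values come from the model metadata.

    #include <cstdint>
    #include <cstdio>

    int main() {
      int32_t in_feat_dim = 80;      // fbank dimension
      int32_t in_num_frames = 100;   // frames produced by the feature extractor
      int32_t lfr_window_size = 7;   // illustrative; read from metadata in practice
      int32_t lfr_window_shift = 6;  // illustrative; read from metadata in practice

      int32_t out_num_frames =
          (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
      int32_t out_feat_dim = in_feat_dim * lfr_window_size;

      // 7 consecutive 80-dim frames are concatenated into one 560-dim frame,
      // and the window then advances by 6 input frames.
      printf("LFR output: %d frames x %d dims\n", out_num_frames, out_feat_dim);
      // prints: LFR output: 16 frames x 560 dims
      return 0;
    }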
| 1 | // sherpa-onnx/csrc/offline-recognizer-transducer-impl.h | 1 | // sherpa-onnx/csrc/offline-recognizer-transducer-impl.h |
| 2 | // | 2 | // |
| 3 | -// Copyright (c) 2022 Xiaomi Corporation | 3 | +// Copyright (c) 2022-2023 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ | 5 | #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ |
| 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ | 6 | #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ |
| @@ -6,6 +6,8 @@ | @@ -6,6 +6,8 @@ | ||
| 6 | 6 | ||
| 7 | #include <memory> | 7 | #include <memory> |
| 8 | 8 | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 9 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" | 11 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" |
| 10 | 12 | ||
| 11 | namespace sherpa_onnx { | 13 | namespace sherpa_onnx { |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <string> | 9 | #include <string> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#include "sherpa-onnx/csrc/offline-model-config.h" | ||
| 12 | #include "sherpa-onnx/csrc/offline-stream.h" | 13 | #include "sherpa-onnx/csrc/offline-stream.h" |
| 13 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 14 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| 14 | #include "sherpa-onnx/csrc/parse-options.h" | 15 | #include "sherpa-onnx/csrc/parse-options.h" |
| @@ -32,7 +33,7 @@ struct OfflineRecognitionResult { | @@ -32,7 +33,7 @@ struct OfflineRecognitionResult { | ||
| 32 | 33 | ||
| 33 | struct OfflineRecognizerConfig { | 34 | struct OfflineRecognizerConfig { |
| 34 | OfflineFeatureExtractorConfig feat_config; | 35 | OfflineFeatureExtractorConfig feat_config; |
| 35 | - OfflineTransducerModelConfig model_config; | 36 | + OfflineModelConfig model_config; |
| 36 | 37 | ||
| 37 | std::string decoding_method = "greedy_search"; | 38 | std::string decoding_method = "greedy_search"; |
| 38 | // only greedy_search is implemented | 39 | // only greedy_search is implemented |
| @@ -40,7 +41,7 @@ struct OfflineRecognizerConfig { | @@ -40,7 +41,7 @@ struct OfflineRecognizerConfig { | ||
| 40 | 41 | ||
| 41 | OfflineRecognizerConfig() = default; | 42 | OfflineRecognizerConfig() = default; |
| 42 | OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, | 43 | OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, |
| 43 | - const OfflineTransducerModelConfig &model_config, | 44 | + const OfflineModelConfig &model_config, |
| 44 | const std::string &decoding_method) | 45 | const std::string &decoding_method) |
| 45 | : feat_config(feat_config), | 46 | : feat_config(feat_config), |
| 46 | model_config(model_config), | 47 | model_config(model_config), |
| @@ -38,7 +38,7 @@ std::string OfflineFeatureExtractorConfig::ToString() const { | @@ -38,7 +38,7 @@ std::string OfflineFeatureExtractorConfig::ToString() const { | ||
| 38 | 38 | ||
| 39 | class OfflineStream::Impl { | 39 | class OfflineStream::Impl { |
| 40 | public: | 40 | public: |
| 41 | - explicit Impl(const OfflineFeatureExtractorConfig &config) { | 41 | + explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) { |
| 42 | opts_.frame_opts.dither = 0; | 42 | opts_.frame_opts.dither = 0; |
| 43 | opts_.frame_opts.snip_edges = false; | 43 | opts_.frame_opts.snip_edges = false; |
| 44 | opts_.frame_opts.samp_freq = config.sampling_rate; | 44 | opts_.frame_opts.samp_freq = config.sampling_rate; |
| @@ -48,6 +48,19 @@ class OfflineStream::Impl { | @@ -48,6 +48,19 @@ class OfflineStream::Impl { | ||
| 48 | } | 48 | } |
| 49 | 49 | ||
| 50 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { | 50 | void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { |
| 51 | + if (config_.normalize_samples) { | ||
| 52 | + AcceptWaveformImpl(sampling_rate, waveform, n); | ||
| 53 | + } else { | ||
| 54 | + std::vector<float> buf(n); | ||
| 55 | + for (int32_t i = 0; i != n; ++i) { | ||
| 56 | + buf[i] = waveform[i] * 32768; | ||
| 57 | + } | ||
| 58 | + AcceptWaveformImpl(sampling_rate, buf.data(), n); | ||
| 59 | + } | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform, | ||
| 63 | + int32_t n) { | ||
| 51 | if (sampling_rate != opts_.frame_opts.samp_freq) { | 64 | if (sampling_rate != opts_.frame_opts.samp_freq) { |
| 52 | SHERPA_ONNX_LOGE( | 65 | SHERPA_ONNX_LOGE( |
| 53 | "Creating a resampler:\n" | 66 | "Creating a resampler:\n" |
| @@ -101,6 +114,7 @@ class OfflineStream::Impl { | @@ -101,6 +114,7 @@ class OfflineStream::Impl { | ||
| 101 | const OfflineRecognitionResult &GetResult() const { return r_; } | 114 | const OfflineRecognitionResult &GetResult() const { return r_; } |
| 102 | 115 | ||
| 103 | private: | 116 | private: |
| 117 | + OfflineFeatureExtractorConfig config_; | ||
| 104 | std::unique_ptr<knf::OnlineFbank> fbank_; | 118 | std::unique_ptr<knf::OnlineFbank> fbank_; |
| 105 | knf::FbankOptions opts_; | 119 | knf::FbankOptions opts_; |
| 106 | OfflineRecognitionResult r_; | 120 | OfflineRecognitionResult r_; |
| @@ -23,6 +23,13 @@ struct OfflineFeatureExtractorConfig { | @@ -23,6 +23,13 @@ struct OfflineFeatureExtractorConfig { | ||
| 23 | // Feature dimension | 23 | // Feature dimension |
| 24 | int32_t feature_dim = 80; | 24 | int32_t feature_dim = 80; |
| 25 | 25 | ||
| 26 | + // Set internally by some models, e.g., paraformer | ||
| 27 | + // This parameter is not exposed to users from the commandline | ||
| 28 | + // If true, the feature extractor expects inputs to be normalized to | ||
| 29 | + // the range [-1, 1]. | ||
| 30 | + // If false, we will multiply the inputs by 32768 | ||
| 31 | + bool normalize_samples = true; | ||
| 32 | + | ||
| 26 | std::string ToString() const; | 33 | std::string ToString() const; |
| 27 | 34 | ||
| 28 | void Register(ParseOptions *po); | 35 | void Register(ParseOptions *po); |
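The effect of the new normalize_samples flag on a single waveform sample, as a toy illustration only:

    #include <cstdio>

    int main() {
      float sample = 0.25f;            // waveform sample in the range [-1, 1]
      bool normalize_samples = false;  // what the Paraformer recognizer sets internally
      float fed_to_fbank = normalize_samples ? sample : sample * 32768;
      printf("%.1f\n", fed_to_fbank);  // prints 8192.0, i.e. Paraformer-style scaling
      return 0;
    }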
| @@ -14,20 +14,9 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { | @@ -14,20 +14,9 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { | ||
| 14 | po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); | 14 | po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); |
| 15 | po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); | 15 | po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); |
| 16 | po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); | 16 | po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); |
| 17 | - po->Register("tokens", &tokens, "Path to tokens.txt"); | ||
| 18 | - po->Register("num_threads", &num_threads, | ||
| 19 | - "Number of threads to run the neural network"); | ||
| 20 | - | ||
| 21 | - po->Register("debug", &debug, | ||
| 22 | - "true to print model information while loading it."); | ||
| 23 | } | 17 | } |
| 24 | 18 | ||
| 25 | bool OfflineTransducerModelConfig::Validate() const { | 19 | bool OfflineTransducerModelConfig::Validate() const { |
| 26 | - if (!FileExists(tokens)) { | ||
| 27 | - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); | ||
| 28 | - return false; | ||
| 29 | - } | ||
| 30 | - | ||
| 31 | if (!FileExists(encoder_filename)) { | 20 | if (!FileExists(encoder_filename)) { |
| 32 | SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); | 21 | SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); |
| 33 | return false; | 22 | return false; |
| @@ -43,11 +32,6 @@ bool OfflineTransducerModelConfig::Validate() const { | @@ -43,11 +32,6 @@ bool OfflineTransducerModelConfig::Validate() const { | ||
| 43 | return false; | 32 | return false; |
| 44 | } | 33 | } |
| 45 | 34 | ||
| 46 | - if (num_threads < 1) { | ||
| 47 | - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); | ||
| 48 | - return false; | ||
| 49 | - } | ||
| 50 | - | ||
| 51 | return true; | 35 | return true; |
| 52 | } | 36 | } |
| 53 | 37 | ||
| @@ -57,10 +41,7 @@ std::string OfflineTransducerModelConfig::ToString() const { | @@ -57,10 +41,7 @@ std::string OfflineTransducerModelConfig::ToString() const { | ||
| 57 | os << "OfflineTransducerModelConfig("; | 41 | os << "OfflineTransducerModelConfig("; |
| 58 | os << "encoder_filename=\"" << encoder_filename << "\", "; | 42 | os << "encoder_filename=\"" << encoder_filename << "\", "; |
| 59 | os << "decoder_filename=\"" << decoder_filename << "\", "; | 43 | os << "decoder_filename=\"" << decoder_filename << "\", "; |
| 60 | - os << "joiner_filename=\"" << joiner_filename << "\", "; | ||
| 61 | - os << "tokens=\"" << tokens << "\", "; | ||
| 62 | - os << "num_threads=" << num_threads << ", "; | ||
| 63 | - os << "debug=" << (debug ? "True" : "False") << ")"; | 44 | + os << "joiner_filename=\"" << joiner_filename << "\")"; |
| 64 | 45 | ||
| 65 | return os.str(); | 46 | return os.str(); |
| 66 | } | 47 | } |
| @@ -14,22 +14,14 @@ struct OfflineTransducerModelConfig { | @@ -14,22 +14,14 @@ struct OfflineTransducerModelConfig { | ||
| 14 | std::string encoder_filename; | 14 | std::string encoder_filename; |
| 15 | std::string decoder_filename; | 15 | std::string decoder_filename; |
| 16 | std::string joiner_filename; | 16 | std::string joiner_filename; |
| 17 | - std::string tokens; | ||
| 18 | - int32_t num_threads = 2; | ||
| 19 | - bool debug = false; | ||
| 20 | 17 | ||
| 21 | OfflineTransducerModelConfig() = default; | 18 | OfflineTransducerModelConfig() = default; |
| 22 | OfflineTransducerModelConfig(const std::string &encoder_filename, | 19 | OfflineTransducerModelConfig(const std::string &encoder_filename, |
| 23 | const std::string &decoder_filename, | 20 | const std::string &decoder_filename, |
| 24 | - const std::string &joiner_filename, | ||
| 25 | - const std::string &tokens, int32_t num_threads, | ||
| 26 | - bool debug) | 21 | + const std::string &joiner_filename) |
| 27 | : encoder_filename(encoder_filename), | 22 | : encoder_filename(encoder_filename), |
| 28 | decoder_filename(decoder_filename), | 23 | decoder_filename(decoder_filename), |
| 29 | - joiner_filename(joiner_filename), | ||
| 30 | - tokens(tokens), | ||
| 31 | - num_threads(num_threads), | ||
| 32 | - debug(debug) {} | 24 | + joiner_filename(joiner_filename) {} |
| 33 | 25 | ||
| 34 | void Register(ParseOptions *po); | 26 | void Register(ParseOptions *po); |
| 35 | bool Validate() const; | 27 | bool Validate() const; |
| @@ -16,7 +16,7 @@ namespace sherpa_onnx { | @@ -16,7 +16,7 @@ namespace sherpa_onnx { | ||
| 16 | 16 | ||
| 17 | class OfflineTransducerModel::Impl { | 17 | class OfflineTransducerModel::Impl { |
| 18 | public: | 18 | public: |
| 19 | - explicit Impl(const OfflineTransducerModelConfig &config) | 19 | + explicit Impl(const OfflineModelConfig &config) |
| 20 | : config_(config), | 20 | : config_(config), |
| 21 | env_(ORT_LOGGING_LEVEL_WARNING), | 21 | env_(ORT_LOGGING_LEVEL_WARNING), |
| 22 | sess_opts_{}, | 22 | sess_opts_{}, |
| @@ -24,17 +24,17 @@ class OfflineTransducerModel::Impl { | @@ -24,17 +24,17 @@ class OfflineTransducerModel::Impl { | ||
| 24 | sess_opts_.SetIntraOpNumThreads(config.num_threads); | 24 | sess_opts_.SetIntraOpNumThreads(config.num_threads); |
| 25 | sess_opts_.SetInterOpNumThreads(config.num_threads); | 25 | sess_opts_.SetInterOpNumThreads(config.num_threads); |
| 26 | { | 26 | { |
| 27 | - auto buf = ReadFile(config.encoder_filename); | 27 | + auto buf = ReadFile(config.transducer.encoder_filename); |
| 28 | InitEncoder(buf.data(), buf.size()); | 28 | InitEncoder(buf.data(), buf.size()); |
| 29 | } | 29 | } |
| 30 | 30 | ||
| 31 | { | 31 | { |
| 32 | - auto buf = ReadFile(config.decoder_filename); | 32 | + auto buf = ReadFile(config.transducer.decoder_filename); |
| 33 | InitDecoder(buf.data(), buf.size()); | 33 | InitDecoder(buf.data(), buf.size()); |
| 34 | } | 34 | } |
| 35 | 35 | ||
| 36 | { | 36 | { |
| 37 | - auto buf = ReadFile(config.joiner_filename); | 37 | + auto buf = ReadFile(config.transducer.joiner_filename); |
| 38 | InitJoiner(buf.data(), buf.size()); | 38 | InitJoiner(buf.data(), buf.size()); |
| 39 | } | 39 | } |
| 40 | } | 40 | } |
| @@ -164,7 +164,7 @@ class OfflineTransducerModel::Impl { | @@ -164,7 +164,7 @@ class OfflineTransducerModel::Impl { | ||
| 164 | } | 164 | } |
| 165 | 165 | ||
| 166 | private: | 166 | private: |
| 167 | - OfflineTransducerModelConfig config_; | 167 | + OfflineModelConfig config_; |
| 168 | Ort::Env env_; | 168 | Ort::Env env_; |
| 169 | Ort::SessionOptions sess_opts_; | 169 | Ort::SessionOptions sess_opts_; |
| 170 | Ort::AllocatorWithDefaultOptions allocator_; | 170 | Ort::AllocatorWithDefaultOptions allocator_; |
| @@ -195,8 +195,7 @@ class OfflineTransducerModel::Impl { | @@ -195,8 +195,7 @@ class OfflineTransducerModel::Impl { | ||
| 195 | int32_t context_size_ = 0; // initialized in InitDecoder | 195 | int32_t context_size_ = 0; // initialized in InitDecoder |
| 196 | }; | 196 | }; |
| 197 | 197 | ||
| 198 | -OfflineTransducerModel::OfflineTransducerModel( | ||
| 199 | - const OfflineTransducerModelConfig &config) | 198 | +OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config) |
| 200 | : impl_(std::make_unique<Impl>(config)) {} | 199 | : impl_(std::make_unique<Impl>(config)) {} |
| 201 | 200 | ||
| 202 | OfflineTransducerModel::~OfflineTransducerModel() = default; | 201 | OfflineTransducerModel::~OfflineTransducerModel() = default; |
| @@ -9,7 +9,7 @@ | @@ -9,7 +9,7 @@ | ||
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | -#include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 12 | +#include "sherpa-onnx/csrc/offline-model-config.h" |
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| @@ -17,7 +17,7 @@ struct OfflineTransducerDecoderResult; | @@ -17,7 +17,7 @@ struct OfflineTransducerDecoderResult; | ||
| 17 | 17 | ||
| 18 | class OfflineTransducerModel { | 18 | class OfflineTransducerModel { |
| 19 | public: | 19 | public: |
| 20 | - explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config); | 20 | + explicit OfflineTransducerModel(const OfflineModelConfig &config); |
| 21 | ~OfflineTransducerModel(); | 21 | ~OfflineTransducerModel(); |
| 22 | 22 | ||
| 23 | /** Run the encoder. | 23 | /** Run the encoder. |
| @@ -25,6 +25,7 @@ class OfflineTransducerModel { | @@ -25,6 +25,7 @@ class OfflineTransducerModel { | ||
| 25 | * @param features A tensor of shape (N, T, C). It is changed in-place. | 25 | * @param features A tensor of shape (N, T, C). It is changed in-place. |
| 26 | * @param features_length A 1-D tensor of shape (N,) containing number of | 26 | * @param features_length A 1-D tensor of shape (N,) containing number of |
| 27 | * valid frames in `features` before padding. | 27 | * valid frames in `features` before padding. |
| 28 | + * Its dtype is int64_t. | ||
| 28 | * | 29 | * |
| 29 | * @return Return a pair containing: | 30 | * @return Return a pair containing: |
| 30 | * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) | 31 | * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) |
| @@ -5,6 +5,7 @@ | @@ -5,6 +5,7 @@ | ||
| 5 | 5 | ||
| 6 | #include <algorithm> | 6 | #include <algorithm> |
| 7 | #include <fstream> | 7 | #include <fstream> |
| 8 | +#include <sstream> | ||
| 8 | #include <string> | 9 | #include <string> |
| 9 | #include <vector> | 10 | #include <vector> |
| 10 | 11 | ||
| @@ -133,19 +134,24 @@ void Print1D(Ort::Value *v) { | @@ -133,19 +134,24 @@ void Print1D(Ort::Value *v) { | ||
| 133 | fprintf(stderr, "\n"); | 134 | fprintf(stderr, "\n"); |
| 134 | } | 135 | } |
| 135 | 136 | ||
| 137 | +template <typename T /*= float*/> | ||
| 136 | void Print2D(Ort::Value *v) { | 138 | void Print2D(Ort::Value *v) { |
| 137 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 139 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 138 | - const float *d = v->GetTensorData<float>(); | 140 | + const T *d = v->GetTensorData<T>(); |
| 139 | 141 | ||
| 142 | + std::ostringstream os; | ||
| 140 | for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) { | 143 | for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) { |
| 141 | for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) { | 144 | for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) { |
| 142 | - fprintf(stderr, "%.3f ", *d); | 145 | + os << *d << " "; |
| 143 | } | 146 | } |
| 144 | - fprintf(stderr, "\n"); | 147 | + os << "\n"; |
| 145 | } | 148 | } |
| 146 | - fprintf(stderr, "\n"); | 149 | + fprintf(stderr, "%s\n", os.str().c_str()); |
| 147 | } | 150 | } |
| 148 | 151 | ||
| 152 | +template void Print2D<int64_t>(Ort::Value *v); | ||
| 153 | +template void Print2D<float>(Ort::Value *v); | ||
| 154 | + | ||
| 149 | void Print3D(Ort::Value *v) { | 155 | void Print3D(Ort::Value *v) { |
| 150 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); | 156 | std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); |
| 151 | const float *d = v->GetTensorData<float>(); | 157 | const float *d = v->GetTensorData<float>(); |
| @@ -24,18 +24,6 @@ | @@ -24,18 +24,6 @@ | ||
| 24 | 24 | ||
| 25 | namespace sherpa_onnx { | 25 | namespace sherpa_onnx { |
| 26 | 26 | ||
| 27 | -#ifdef _MSC_VER | ||
| 28 | -// See | ||
| 29 | -// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t | ||
| 30 | -static std::wstring ToWide(const std::string &s) { | ||
| 31 | - std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter; | ||
| 32 | - return converter.from_bytes(s); | ||
| 33 | -} | ||
| 34 | -#define SHERPA_MAYBE_WIDE(s) ToWide(s) | ||
| 35 | -#else | ||
| 36 | -#define SHERPA_MAYBE_WIDE(s) s | ||
| 37 | -#endif | ||
| 38 | - | ||
| 39 | /** | 27 | /** |
| 40 | * Get the input names of a model. | 28 | * Get the input names of a model. |
| 41 | * | 29 | * |
| @@ -79,6 +67,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | @@ -79,6 +67,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | ||
| 79 | void Print1D(Ort::Value *v); | 67 | void Print1D(Ort::Value *v); |
| 80 | 68 | ||
| 81 | // Print a 2-D tensor to stderr | 69 | // Print a 2-D tensor to stderr |
| 70 | +template <typename T = float> | ||
| 82 | void Print2D(Ort::Value *v); | 71 | void Print2D(Ort::Value *v); |
| 83 | 72 | ||
| 84 | // Print a 3-D tensor to stderr | 73 | // Print a 3-D tensor to stderr |
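A sketch of how the now-templated Print2D can be used with an int64 tensor such as the Paraformer token_num output. It assumes the onnxruntime C++ API and the header changed above; the data values are made up.

    #include <array>
    #include <cstdint>

    #include "onnxruntime_cxx_api.h"  // NOLINT
    #include "sherpa-onnx/csrc/onnx-utils.h"

    void Demo() {
      auto memory_info =
          Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

      std::array<int64_t, 6> data = {23, 481, 902, 2, 2, 2};  // a 2 x 3 tensor
      std::array<int64_t, 2> shape = {2, 3};

      Ort::Value v = Ort::Value::CreateTensor<int64_t>(
          memory_info, data.data(), data.size(), shape.data(), shape.size());

      sherpa_onnx::Print2D<int64_t>(&v);  // prints two rows of three values
    }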
| @@ -9,24 +9,35 @@ | @@ -9,24 +9,35 @@ | ||
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | #include "sherpa-onnx/csrc/offline-recognizer.h" | 11 | #include "sherpa-onnx/csrc/offline-recognizer.h" |
| 12 | -#include "sherpa-onnx/csrc/offline-stream.h" | ||
| 13 | -#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 14 | -#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | ||
| 15 | -#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 16 | -#include "sherpa-onnx/csrc/pad-sequence.h" | ||
| 17 | -#include "sherpa-onnx/csrc/symbol-table.h" | 12 | +#include "sherpa-onnx/csrc/parse-options.h" |
| 18 | #include "sherpa-onnx/csrc/wave-reader.h" | 13 | #include "sherpa-onnx/csrc/wave-reader.h" |
| 19 | 14 | ||
| 20 | int main(int32_t argc, char *argv[]) { | 15 | int main(int32_t argc, char *argv[]) { |
| 21 | - if (argc < 6 || argc > 8) { | ||
| 22 | - const char *usage = R"usage( | 16 | + const char *kUsageMessage = R"usage( |
| 23 | Usage: | 17 | Usage: |
| 18 | + | ||
| 19 | +(1) Transducer from icefall | ||
| 20 | + | ||
| 21 | + ./bin/sherpa-onnx-offline \ | ||
| 22 | + --tokens=/path/to/tokens.txt \ | ||
| 23 | + --encoder=/path/to/encoder.onnx \ | ||
| 24 | + --decoder=/path/to/decoder.onnx \ | ||
| 25 | + --joiner=/path/to/joiner.onnx \ | ||
| 26 | + --num-threads=2 \ | ||
| 27 | + --decoding-method=greedy_search \ | ||
| 28 | + /path/to/foo.wav [bar.wav foobar.wav ...] | ||
| 29 | + | ||
| 30 | + | ||
| 31 | +(2) Paraformer from FunASR | ||
| 32 | + | ||
| 24 | ./bin/sherpa-onnx-offline \ | 33 | ./bin/sherpa-onnx-offline \ |
| 25 | - /path/to/tokens.txt \ | ||
| 26 | - /path/to/encoder.onnx \ | ||
| 27 | - /path/to/decoder.onnx \ | ||
| 28 | - /path/to/joiner.onnx \ | ||
| 29 | - /path/to/foo.wav [num_threads [decoding_method]] | 34 | + --tokens=/path/to/tokens.txt \ |
| 35 | + --paraformer=/path/to/model.onnx \ | ||
| 36 | + --num-threads=2 \ | ||
| 37 | + --decoding-method=greedy_search \ | ||
| 38 | + /path/to/foo.wav [bar.wav foobar.wav ...] | ||
| 39 | + | ||
| 40 | +Note: It supports decoding multiple files in batches | ||
| 30 | 41 | ||
| 31 | Default value for num_threads is 2. | 42 | Default value for num_threads is 2. |
| 32 | Valid values for decoding_method: greedy_search. | 43 | Valid values for decoding_method: greedy_search. |
| @@ -37,29 +48,15 @@ Please refer to | @@ -37,29 +48,15 @@ Please refer to | ||
| 37 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | 48 | https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html |
| 38 | for a list of pre-trained models to download. | 49 | for a list of pre-trained models to download. |
| 39 | )usage"; | 50 | )usage"; |
| 40 | - fprintf(stderr, "%s\n", usage); | ||
| 41 | - | ||
| 42 | - return 0; | ||
| 43 | - } | ||
| 44 | 51 | ||
| 52 | + sherpa_onnx::ParseOptions po(kUsageMessage); | ||
| 45 | sherpa_onnx::OfflineRecognizerConfig config; | 53 | sherpa_onnx::OfflineRecognizerConfig config; |
| 54 | + config.Register(&po); | ||
| 46 | 55 | ||
| 47 | - config.model_config.tokens = argv[1]; | ||
| 48 | - | ||
| 49 | - config.model_config.debug = false; | ||
| 50 | - config.model_config.encoder_filename = argv[2]; | ||
| 51 | - config.model_config.decoder_filename = argv[3]; | ||
| 52 | - config.model_config.joiner_filename = argv[4]; | ||
| 53 | - | ||
| 54 | - std::string wav_filename = argv[5]; | ||
| 55 | - | ||
| 56 | - config.model_config.num_threads = 2; | ||
| 57 | - if (argc == 7 && atoi(argv[6]) > 0) { | ||
| 58 | - config.model_config.num_threads = atoi(argv[6]); | ||
| 59 | - } | ||
| 60 | - | ||
| 61 | - if (argc == 8) { | ||
| 62 | - config.decoding_method = argv[7]; | 56 | + po.Read(argc, argv); |
| 57 | + if (po.NumArgs() < 1) { | ||
| 58 | + po.PrintUsage(); | ||
| 59 | + exit(EXIT_FAILURE); | ||
| 63 | } | 60 | } |
| 64 | 61 | ||
| 65 | fprintf(stderr, "%s\n", config.ToString().c_str()); | 62 | fprintf(stderr, "%s\n", config.ToString().c_str()); |
| @@ -69,35 +66,43 @@ for a list of pre-trained models to download. | @@ -69,35 +66,43 @@ for a list of pre-trained models to download. | ||
| 69 | return -1; | 66 | return -1; |
| 70 | } | 67 | } |
| 71 | 68 | ||
| 72 | - int32_t sampling_rate = -1; | ||
| 73 | - | ||
| 74 | - bool is_ok = false; | ||
| 75 | - std::vector<float> samples = | ||
| 76 | - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 77 | - if (!is_ok) { | ||
| 78 | - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 79 | - return -1; | ||
| 80 | - } | ||
| 81 | - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); | ||
| 82 | - | ||
| 83 | - float duration = samples.size() / static_cast<float>(sampling_rate); | ||
| 84 | - | ||
| 85 | sherpa_onnx::OfflineRecognizer recognizer(config); | 69 | sherpa_onnx::OfflineRecognizer recognizer(config); |
| 86 | - auto s = recognizer.CreateStream(); | ||
| 87 | 70 | ||
| 88 | auto begin = std::chrono::steady_clock::now(); | 71 | auto begin = std::chrono::steady_clock::now(); |
| 89 | fprintf(stderr, "Started\n"); | 72 | fprintf(stderr, "Started\n"); |
| 90 | 73 | ||
| 91 | - s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | 74 | + std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; |
| 75 | + std::vector<sherpa_onnx::OfflineStream *> ss_pointers; | ||
| 76 | + float duration = 0; | ||
| 77 | + for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||
| 78 | + std::string wav_filename = po.GetArg(i); | ||
| 79 | + int32_t sampling_rate = -1; | ||
| 80 | + bool is_ok = false; | ||
| 81 | + std::vector<float> samples = | ||
| 82 | + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||
| 83 | + if (!is_ok) { | ||
| 84 | + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||
| 85 | + return -1; | ||
| 86 | + } | ||
| 87 | + duration += samples.size() / static_cast<float>(sampling_rate); | ||
| 88 | + | ||
| 89 | + auto s = recognizer.CreateStream(); | ||
| 90 | + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||
| 91 | + | ||
| 92 | + ss.push_back(std::move(s)); | ||
| 93 | + ss_pointers.push_back(ss.back().get()); | ||
| 94 | + } | ||
| 92 | 95 | ||
| 93 | - recognizer.DecodeStream(s.get()); | 96 | + recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size()); |
| 94 | 97 | ||
| 95 | - fprintf(stderr, "Done!\n"); | 98 | + auto end = std::chrono::steady_clock::now(); |
| 96 | 99 | ||
| 97 | - fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), | ||
| 98 | - s->GetResult().text.c_str()); | 100 | + fprintf(stderr, "Done!\n\n"); |
| 101 | + for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||
| 102 | + fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(), | ||
| 103 | + ss[i - 1]->GetResult().text.c_str()); | ||
| 104 | + } | ||
| 99 | 105 | ||
| 100 | - auto end = std::chrono::steady_clock::now(); | ||
| 101 | float elapsed_seconds = | 106 | float elapsed_seconds = |
| 102 | std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | 107 | std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) |
| 103 | .count() / | 108 | .count() / |