正在显示
12 个修改的文件
包含
176 行增加
和
71 行删除
| 1 | function(download_kaldi_native_fbank) | 1 | function(download_kaldi_native_fbank) |
| 2 | include(FetchContent) | 2 | include(FetchContent) |
| 3 | 3 | ||
| 4 | - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.11.tar.gz") | ||
| 5 | - set(kaldi_native_fbank_HASH "SHA256=e69ae25ef6f30566ef31ca949dd1b0b8ec3a827caeba93a61d82bb848dac5d69") | 4 | + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz") |
| 5 | + set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44") | ||
| 6 | 6 | ||
| 7 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) | 7 | set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) |
| 8 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) | 8 | set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) |
| @@ -11,10 +11,11 @@ function(download_kaldi_native_fbank) | @@ -11,10 +11,11 @@ function(download_kaldi_native_fbank) | ||
| 11 | # If you don't have access to the Internet, | 11 | # If you don't have access to the Internet, |
| 12 | # please pre-download kaldi-native-fbank | 12 | # please pre-download kaldi-native-fbank |
| 13 | set(possible_file_locations | 13 | set(possible_file_locations |
| 14 | - $ENV{HOME}/Downloads/kaldi-native-fbank-1.11.tar.gz | ||
| 15 | - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.11.tar.gz | ||
| 16 | - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.11.tar.gz | ||
| 17 | - /tmp/kaldi-native-fbank-1.11.tar.gz | 14 | + $ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz |
| 15 | + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz | ||
| 16 | + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz | ||
| 17 | + /tmp/kaldi-native-fbank-1.12.tar.gz | ||
| 18 | + /star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz | ||
| 18 | ) | 19 | ) |
| 19 | 20 | ||
| 20 | foreach(f IN LISTS possible_file_locations) | 21 | foreach(f IN LISTS possible_file_locations) |
| @@ -9,6 +9,7 @@ function(download_onnxruntime) | @@ -9,6 +9,7 @@ function(download_onnxruntime) | ||
| 9 | ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz | 9 | ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz |
| 10 | ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz | 10 | ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz |
| 11 | /tmp/onnxruntime-linux-x64-1.14.0.tgz | 11 | /tmp/onnxruntime-linux-x64-1.14.0.tgz |
| 12 | + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz | ||
| 12 | ) | 13 | ) |
| 13 | 14 | ||
| 14 | set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") | 15 | set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") |
| 1 | include_directories(${CMAKE_SOURCE_DIR}) | 1 | include_directories(${CMAKE_SOURCE_DIR}) |
| 2 | 2 | ||
| 3 | add_executable(sherpa-onnx | 3 | add_executable(sherpa-onnx |
| 4 | - decode.cc | ||
| 5 | features.cc | 4 | features.cc |
| 6 | online-lstm-transducer-model.cc | 5 | online-lstm-transducer-model.cc |
| 6 | + online-transducer-greedy-search-decoder.cc | ||
| 7 | online-transducer-model-config.cc | 7 | online-transducer-model-config.cc |
| 8 | online-transducer-model.cc | 8 | online-transducer-model.cc |
| 9 | onnx-utils.cc | 9 | onnx-utils.cc |
sherpa-onnx/csrc/decode.h
已删除
100644 → 0
| 1 | -// sherpa/csrc/decode.h | ||
| 2 | -// | ||
| 3 | -// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | - | ||
| 5 | -#ifndef SHERPA_ONNX_CSRC_DECODE_H_ | ||
| 6 | -#define SHERPA_ONNX_CSRC_DECODE_H_ | ||
| 7 | - | ||
| 8 | -#include <vector> | ||
| 9 | - | ||
| 10 | -#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 11 | - | ||
| 12 | -namespace sherpa_onnx { | ||
| 13 | - | ||
| 14 | -/** Greedy search for non-streaming ASR. | ||
| 15 | - * | ||
| 16 | - * @TODO(fangjun) Support batch size > 1 | ||
| 17 | - * | ||
| 18 | - * @param model The RnntModel | ||
| 19 | - * @param encoder_out Its shape is (1, num_frames, encoder_out_dim). | ||
| 20 | - */ | ||
| 21 | -void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, | ||
| 22 | - std::vector<int64_t> *hyp); | ||
| 23 | - | ||
| 24 | -} // namespace sherpa_onnx | ||
| 25 | - | ||
| 26 | -#endif // SHERPA_ONNX_CSRC_DECODE_H_ |
| @@ -15,16 +15,16 @@ namespace sherpa_onnx { | @@ -15,16 +15,16 @@ namespace sherpa_onnx { | ||
| 15 | 15 | ||
| 16 | class FeatureExtractor::Impl { | 16 | class FeatureExtractor::Impl { |
| 17 | public: | 17 | public: |
| 18 | - Impl(int32_t sampling_rate, int32_t feature_dim) { | 18 | + explicit Impl(const FeatureExtractorConfig &config) { |
| 19 | opts_.frame_opts.dither = 0; | 19 | opts_.frame_opts.dither = 0; |
| 20 | opts_.frame_opts.snip_edges = false; | 20 | opts_.frame_opts.snip_edges = false; |
| 21 | - opts_.frame_opts.samp_freq = sampling_rate; | 21 | + opts_.frame_opts.samp_freq = config.sampling_rate; |
| 22 | 22 | ||
| 23 | // cache 100 seconds of feature frames, which is more than enough | 23 | // cache 100 seconds of feature frames, which is more than enough |
| 24 | // for real needs | 24 | // for real needs |
| 25 | opts_.frame_opts.max_feature_vectors = 100 * 100; | 25 | opts_.frame_opts.max_feature_vectors = 100 * 100; |
| 26 | 26 | ||
| 27 | - opts_.mel_opts.num_bins = feature_dim; | 27 | + opts_.mel_opts.num_bins = config.feature_dim; |
| 28 | 28 | ||
| 29 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); | 29 | fbank_ = std::make_unique<knf::OnlineFbank>(opts_); |
| 30 | } | 30 | } |
| @@ -80,9 +80,8 @@ class FeatureExtractor::Impl { | @@ -80,9 +80,8 @@ class FeatureExtractor::Impl { | ||
| 80 | mutable std::mutex mutex_; | 80 | mutable std::mutex mutex_; |
| 81 | }; | 81 | }; |
| 82 | 82 | ||
| 83 | -FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/, | ||
| 84 | - int32_t feature_dim /*=80*/) | ||
| 85 | - : impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {} | 83 | +FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) |
| 84 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 86 | 85 | ||
| 87 | FeatureExtractor::~FeatureExtractor() = default; | 86 | FeatureExtractor::~FeatureExtractor() = default; |
| 88 | 87 |
| @@ -10,14 +10,18 @@ | @@ -10,14 +10,18 @@ | ||
| 10 | 10 | ||
| 11 | namespace sherpa_onnx { | 11 | namespace sherpa_onnx { |
| 12 | 12 | ||
| 13 | +struct FeatureExtractorConfig { | ||
| 14 | + int32_t sampling_rate = 16000; | ||
| 15 | + int32_t feature_dim = 80; | ||
| 16 | +}; | ||
| 17 | + | ||
| 13 | class FeatureExtractor { | 18 | class FeatureExtractor { |
| 14 | public: | 19 | public: |
| 15 | /** | 20 | /** |
| 16 | * @param sampling_rate Sampling rate of the data used to train the model. | 21 | * @param sampling_rate Sampling rate of the data used to train the model. |
| 17 | * @param feature_dim Dimension of the features used to train the model. | 22 | * @param feature_dim Dimension of the features used to train the model. |
| 18 | */ | 23 | */ |
| 19 | - explicit FeatureExtractor(int32_t sampling_rate = 16000, | ||
| 20 | - int32_t feature_dim = 80); | 24 | + explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); |
| 21 | ~FeatureExtractor(); | 25 | ~FeatureExtractor(); |
| 22 | 26 | ||
| 23 | /** | 27 | /** |
sherpa-onnx/csrc/online-transducer-decoder.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-transducer-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +struct OnlineTransducerDecoderResult { | ||
| 15 | + /// The decoded token IDs so far | ||
| 16 | + std::vector<int64_t> tokens; | ||
| 17 | +}; | ||
| 18 | + | ||
| 19 | +class OnlineTransducerDecoder { | ||
| 20 | + public: | ||
| 21 | + virtual ~OnlineTransducerDecoder() = default; | ||
| 22 | + | ||
| 23 | + /* Return an empty result. | ||
| 24 | + * | ||
| 25 | + * To simplify the decoding code, we add `context_size` blanks | ||
| 26 | + * to the beginning of the decoding result, which will be | ||
| 27 | + * stripped by calling `StripLeadingBlanks()`. | ||
| 28 | + */ | ||
| 29 | + virtual OnlineTransducerDecoderResult GetEmptyResult() = 0; | ||
| 30 | + | ||
| 31 | + /** Strip blanks added by `GetEmptyResult()`. | ||
| 32 | + * | ||
| 33 | + * @param r It is changed in-place. | ||
| 34 | + */ | ||
| 35 | + virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {} | ||
| 36 | + | ||
| 37 | + /** Run transducer decoding given the output from the encoder model. | ||
| 38 | + * | ||
| 39 | + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim) | ||
| 40 | + * @param result It is modified in-place. | ||
| 41 | + * | ||
| 42 | + * @note There is no need to pass encoder_out_length here since for the | ||
| 43 | + * online decoding case, each utterance has the same number of frames | ||
| 44 | + * and there is no padding. | ||
| 45 | + */ | ||
| 46 | + virtual void Decode(Ort::Value encoder_out, | ||
| 47 | + std::vector<OnlineTransducerDecoderResult> *result) = 0; | ||
| 48 | +}; | ||
| 49 | + | ||
| 50 | +} // namespace sherpa_onnx | ||
| 51 | + | ||
| 52 | +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ |
| 1 | -// sherpa/csrc/decode.cc | 1 | +// sherpa/csrc/online-transducer-greedy-search-decoder.cc |
| 2 | // | 2 | // |
| 3 | // Copyright (c) 2023 Xiaomi Corporation | 3 | // Copyright (c) 2023 Xiaomi Corporation |
| 4 | 4 | ||
| 5 | -#include "sherpa-onnx/csrc/decode.h" | 5 | +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" |
| 6 | 6 | ||
| 7 | #include <assert.h> | 7 | #include <assert.h> |
| 8 | 8 | ||
| @@ -10,19 +10,9 @@ | @@ -10,19 +10,9 @@ | ||
| 10 | #include <utility> | 10 | #include <utility> |
| 11 | #include <vector> | 11 | #include <vector> |
| 12 | 12 | ||
| 13 | -namespace sherpa_onnx { | ||
| 14 | - | ||
| 15 | -static Ort::Value Clone(Ort::Value *v) { | ||
| 16 | - auto type_and_shape = v->GetTensorTypeAndShapeInfo(); | ||
| 17 | - std::vector<int64_t> shape = type_and_shape.GetShape(); | ||
| 18 | - | ||
| 19 | - auto memory_info = | ||
| 20 | - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 13 | +#include "sherpa-onnx/csrc/onnx-utils.h" |
| 21 | 14 | ||
| 22 | - return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(), | ||
| 23 | - type_and_shape.GetElementCount(), | ||
| 24 | - shape.data(), shape.size()); | ||
| 25 | -} | 15 | +namespace sherpa_onnx { |
| 26 | 16 | ||
| 27 | static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { | 17 | static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { |
| 28 | std::vector<int64_t> encoder_out_shape = | 18 | std::vector<int64_t> encoder_out_shape = |
| @@ -42,26 +32,58 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { | @@ -42,26 +32,58 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { | ||
| 42 | encoder_out_dim, shape.data(), shape.size()); | 32 | encoder_out_dim, shape.data(), shape.size()); |
| 43 | } | 33 | } |
| 44 | 34 | ||
| 45 | -void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, | ||
| 46 | - std::vector<int64_t> *hyp) { | 35 | +OnlineTransducerDecoderResult |
| 36 | +OnlineTransducerGreedySearchDecoder::GetEmptyResult() { | ||
| 37 | + int32_t context_size = model_->ContextSize(); | ||
| 38 | + int32_t blank_id = 0; // always 0 | ||
| 39 | + OnlineTransducerDecoderResult r; | ||
| 40 | + r.tokens.resize(context_size, blank_id); | ||
| 41 | + | ||
| 42 | + return r; | ||
| 43 | +} | ||
| 44 | + | ||
| 45 | +void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( | ||
| 46 | + OnlineTransducerDecoderResult *r) { | ||
| 47 | + int32_t context_size = model_->ContextSize(); | ||
| 48 | + | ||
| 49 | + auto start = r->tokens.begin() + context_size; | ||
| 50 | + auto end = r->tokens.end(); | ||
| 51 | + | ||
| 52 | + r->tokens = std::vector<int64_t>(start, end); | ||
| 53 | +} | ||
| 54 | + | ||
| 55 | +void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 56 | + Ort::Value encoder_out, | ||
| 57 | + std::vector<OnlineTransducerDecoderResult> *result) { | ||
| 47 | std::vector<int64_t> encoder_out_shape = | 58 | std::vector<int64_t> encoder_out_shape = |
| 48 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 59 | encoder_out.GetTensorTypeAndShapeInfo().GetShape(); |
| 49 | 60 | ||
| 50 | - if (encoder_out_shape[0] > 1) { | ||
| 51 | - fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n", | ||
| 52 | - static_cast<int32_t>(encoder_out_shape[0])); | 61 | + if (encoder_out_shape[0] != result->size()) { |
| 62 | + fprintf(stderr, | ||
| 63 | + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", | ||
| 64 | + static_cast<int32_t>(encoder_out_shape[0]), | ||
| 65 | + static_cast<int32_t>(result->size())); | ||
| 66 | + exit(-1); | ||
| 67 | + } | ||
| 68 | + | ||
| 69 | + if (result->size() != 1) { | ||
| 70 | + fprintf(stderr, "only batch size == 1 is implemented. Given: %d", | ||
| 71 | + static_cast<int32_t>(result->size())); | ||
| 72 | + exit(-1); | ||
| 53 | } | 73 | } |
| 54 | 74 | ||
| 75 | + auto &hyp = (*result)[0].tokens; | ||
| 76 | + | ||
| 55 | int32_t num_frames = encoder_out_shape[1]; | 77 | int32_t num_frames = encoder_out_shape[1]; |
| 56 | - int32_t vocab_size = model->VocabSize(); | 78 | + int32_t vocab_size = model_->VocabSize(); |
| 57 | 79 | ||
| 58 | - Ort::Value decoder_input = model->BuildDecoderInput(*hyp); | ||
| 59 | - Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input)); | 80 | + Ort::Value decoder_input = model_->BuildDecoderInput(hyp); |
| 81 | + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 60 | 82 | ||
| 61 | for (int32_t t = 0; t != num_frames; ++t) { | 83 | for (int32_t t = 0; t != num_frames; ++t) { |
| 62 | Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); | 84 | Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); |
| 63 | Ort::Value logit = | 85 | Ort::Value logit = |
| 64 | - model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); | 86 | + model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); |
| 65 | const float *p_logit = logit.GetTensorData<float>(); | 87 | const float *p_logit = logit.GetTensorData<float>(); |
| 66 | 88 | ||
| 67 | auto y = static_cast<int32_t>(std::distance( | 89 | auto y = static_cast<int32_t>(std::distance( |
| @@ -69,9 +91,9 @@ void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, | @@ -69,9 +91,9 @@ void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, | ||
| 69 | std::max_element(static_cast<const float *>(p_logit), | 91 | std::max_element(static_cast<const float *>(p_logit), |
| 70 | static_cast<const float *>(p_logit) + vocab_size))); | 92 | static_cast<const float *>(p_logit) + vocab_size))); |
| 71 | if (y != 0) { | 93 | if (y != 0) { |
| 72 | - hyp->push_back(y); | ||
| 73 | - decoder_input = model->BuildDecoderInput(*hyp); | ||
| 74 | - decoder_out = model->RunDecoder(std::move(decoder_input)); | 94 | + hyp.push_back(y); |
| 95 | + decoder_input = model_->BuildDecoderInput(hyp); | ||
| 96 | + decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 75 | } | 97 | } |
| 76 | } | 98 | } |
| 77 | } | 99 | } |
| 1 | +// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/online-transducer-decoder.h" | ||
| 11 | +#include "sherpa-onnx/csrc/online-transducer-model.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { | ||
| 16 | + public: | ||
| 17 | + explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) | ||
| 18 | + : model_(model) {} | ||
| 19 | + | ||
| 20 | + OnlineTransducerDecoderResult GetEmptyResult() override; | ||
| 21 | + | ||
| 22 | + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override; | ||
| 23 | + | ||
| 24 | + void Decode(Ort::Value encoder_out, | ||
| 25 | + std::vector<OnlineTransducerDecoderResult> *result) override; | ||
| 26 | + | ||
| 27 | + private: | ||
| 28 | + OnlineTransducerModel *model_; // Not owned | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ |
| @@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | @@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | ||
| 46 | } | 46 | } |
| 47 | } | 47 | } |
| 48 | 48 | ||
| 49 | +Ort::Value Clone(Ort::Value *v) { | ||
| 50 | + auto type_and_shape = v->GetTensorTypeAndShapeInfo(); | ||
| 51 | + std::vector<int64_t> shape = type_and_shape.GetShape(); | ||
| 52 | + | ||
| 53 | + auto memory_info = | ||
| 54 | + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
| 55 | + | ||
| 56 | + return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(), | ||
| 57 | + type_and_shape.GetElementCount(), | ||
| 58 | + shape.data(), shape.size()); | ||
| 59 | +} | ||
| 60 | + | ||
| 49 | } // namespace sherpa_onnx | 61 | } // namespace sherpa_onnx |
| @@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | @@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | ||
| 55 | void PrintModelMetadata(std::ostream &os, | 55 | void PrintModelMetadata(std::ostream &os, |
| 56 | const Ort::ModelMetadata &meta_data); // NOLINT | 56 | const Ort::ModelMetadata &meta_data); // NOLINT |
| 57 | 57 | ||
| 58 | +// Return a shallow copy of v | ||
| 59 | +Ort::Value Clone(Ort::Value *v); | ||
| 60 | + | ||
| 58 | } // namespace sherpa_onnx | 61 | } // namespace sherpa_onnx |
| 59 | 62 | ||
| 60 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 63 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| @@ -9,8 +9,8 @@ | @@ -9,8 +9,8 @@ | ||
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | #include "kaldi-native-fbank/csrc/online-feature.h" | 11 | #include "kaldi-native-fbank/csrc/online-feature.h" |
| 12 | -#include "sherpa-onnx/csrc/decode.h" | ||
| 13 | #include "sherpa-onnx/csrc/features.h" | 12 | #include "sherpa-onnx/csrc/features.h" |
| 13 | +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" | ||
| 14 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" | 14 | #include "sherpa-onnx/csrc/online-transducer-model-config.h" |
| 15 | #include "sherpa-onnx/csrc/online-transducer-model.h" | 15 | #include "sherpa-onnx/csrc/online-transducer-model.h" |
| 16 | #include "sherpa-onnx/csrc/symbol-table.h" | 16 | #include "sherpa-onnx/csrc/symbol-table.h" |
| @@ -64,8 +64,6 @@ for a list of pre-trained models to download. | @@ -64,8 +64,6 @@ for a list of pre-trained models to download. | ||
| 64 | 64 | ||
| 65 | std::vector<Ort::Value> states = model->GetEncoderInitStates(); | 65 | std::vector<Ort::Value> states = model->GetEncoderInitStates(); |
| 66 | 66 | ||
| 67 | - std::vector<int64_t> hyp(model->ContextSize(), 0); | ||
| 68 | - | ||
| 69 | int32_t expected_sampling_rate = 16000; | 67 | int32_t expected_sampling_rate = 16000; |
| 70 | 68 | ||
| 71 | bool is_ok = false; | 69 | bool is_ok = false; |
| @@ -100,6 +98,10 @@ for a list of pre-trained models to download. | @@ -100,6 +98,10 @@ for a list of pre-trained models to download. | ||
| 100 | 98 | ||
| 101 | std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; | 99 | std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim}; |
| 102 | 100 | ||
| 101 | + sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); | ||
| 102 | + std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = { | ||
| 103 | + decoder.GetEmptyResult()}; | ||
| 104 | + | ||
| 103 | for (int32_t start = 0; start + chunk_size < num_frames; | 105 | for (int32_t start = 0; start + chunk_size < num_frames; |
| 104 | start += chunk_shift) { | 106 | start += chunk_shift) { |
| 105 | std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); | 107 | std::vector<float> features = feat_extractor.GetFrames(start, chunk_size); |
| @@ -109,8 +111,10 @@ for a list of pre-trained models to download. | @@ -109,8 +111,10 @@ for a list of pre-trained models to download. | ||
| 109 | x_shape.data(), x_shape.size()); | 111 | x_shape.data(), x_shape.size()); |
| 110 | auto pair = model->RunEncoder(std::move(x), states); | 112 | auto pair = model->RunEncoder(std::move(x), states); |
| 111 | states = std::move(pair.second); | 113 | states = std::move(pair.second); |
| 112 | - sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp); | 114 | + decoder.Decode(std::move(pair.first), &result); |
| 113 | } | 115 | } |
| 116 | + decoder.StripLeadingBlanks(&result[0]); | ||
| 117 | + const auto &hyp = result[0].tokens; | ||
| 114 | std::string text; | 118 | std::string text; |
| 115 | for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { | 119 | for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { |
| 116 | text += sym[hyp[i]]; | 120 | text += sym[hyp[i]]; |
-
请 注册 或 登录 后发表评论