Committed by GitHub
Add RNN LM rescore for offline ASR with modified_beam_search (#125)
Showing 32 changed files with 842 additions and 52 deletions
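This PR adds optional RNN LM rescoring to the offline transducer recognizer: `modified_beam_search` keeps up to `max_active_paths` hypotheses per utterance, and when an LM is configured the finished beams are rescored through the new `OfflineLM`/`OfflineRnnLM` classes before the best path is selected. The options are exposed on the command line (`--decoding-method`, `--max-active-paths`, `--lm`, `--lm-scale`) and in the C++/Python configs. A minimal sketch of wiring the new fields together (the model path below is a placeholder, not a file shipped with this PR):

```cpp
// Sketch only: how the new OfflineLMConfig plugs into OfflineRecognizerConfig.
// "./rnn-lm.onnx" is a hypothetical path; feat_config/model_config are filled
// in elsewhere as before.
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"

sherpa_onnx::OfflineRecognizerConfig MakeConfig() {
  sherpa_onnx::OfflineRecognizerConfig config;

  config.decoding_method = "modified_beam_search";  // enables the new decoder
  config.max_active_paths = 4;                      // beam size per utterance

  // New in this PR: optional LM, used only to rescore the finished beams.
  config.lm_config =
      sherpa_onnx::OfflineLMConfig(/*model=*/"./rnn-lm.onnx", /*scale=*/0.5);

  return config;
}
```

If `lm_config.model` is left empty, the decoder still runs modified beam search but simply skips the rescoring step.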
| @@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI) | @@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI) | ||
| 51 | set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE) | 51 | set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE) |
| 52 | endif() | 52 | endif() |
| 53 | 53 | ||
| 54 | +if(SHERPA_ONNX_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS) | ||
| 55 | + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_PYTHON is ON") | ||
| 56 | + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) | ||
| 57 | +endif() | ||
| 58 | + | ||
| 54 | if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS) | 59 | if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS) |
| 55 | message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON") | 60 | message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON") |
| 56 | set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) | 61 | set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) |
| @@ -18,6 +18,8 @@ set(sources | @@ -18,6 +18,8 @@ set(sources | ||
| 18 | hypothesis.cc | 18 | hypothesis.cc |
| 19 | offline-ctc-greedy-search-decoder.cc | 19 | offline-ctc-greedy-search-decoder.cc |
| 20 | offline-ctc-model.cc | 20 | offline-ctc-model.cc |
| 21 | + offline-lm-config.cc | ||
| 22 | + offline-lm.cc | ||
| 21 | offline-model-config.cc | 23 | offline-model-config.cc |
| 22 | offline-nemo-enc-dec-ctc-model-config.cc | 24 | offline-nemo-enc-dec-ctc-model-config.cc |
| 23 | offline-nemo-enc-dec-ctc-model.cc | 25 | offline-nemo-enc-dec-ctc-model.cc |
| @@ -26,10 +28,13 @@ set(sources | @@ -26,10 +28,13 @@ set(sources | ||
| 26 | offline-paraformer-model.cc | 28 | offline-paraformer-model.cc |
| 27 | offline-recognizer-impl.cc | 29 | offline-recognizer-impl.cc |
| 28 | offline-recognizer.cc | 30 | offline-recognizer.cc |
| 31 | + offline-rnn-lm.cc | ||
| 29 | offline-stream.cc | 32 | offline-stream.cc |
| 30 | offline-transducer-greedy-search-decoder.cc | 33 | offline-transducer-greedy-search-decoder.cc |
| 31 | offline-transducer-model-config.cc | 34 | offline-transducer-model-config.cc |
| 32 | offline-transducer-model.cc | 35 | offline-transducer-model.cc |
| 36 | + offline-transducer-modified-beam-search-decoder.cc | ||
| 37 | + online-lm-config.cc | ||
| 33 | online-lstm-transducer-model.cc | 38 | online-lstm-transducer-model.cc |
| 34 | online-recognizer.cc | 39 | online-recognizer.cc |
| 35 | online-stream.cc | 40 | online-stream.cc |
| @@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) { | @@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) { | ||
| 17 | hyps_dict_[key] = std::move(hyp); | 17 | hyps_dict_[key] = std::move(hyp); |
| 18 | } else { | 18 | } else { |
| 19 | it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob); | 19 | it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob); |
| 20 | + | ||
| 21 | + it->second.lm_log_prob = | ||
| 22 | + LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob); | ||
| 20 | } | 23 | } |
| 21 | } | 24 | } |
| 22 | 25 | ||
| @@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { | @@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { | ||
| 24 | if (length_norm == false) { | 27 | if (length_norm == false) { |
| 25 | return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), | 28 | return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), |
| 26 | [](const auto &left, auto &right) -> bool { | 29 | [](const auto &left, auto &right) -> bool { |
| 27 | - return left.second.log_prob < | ||
| 28 | - right.second.log_prob; | 30 | + return left.second.TotalLogProb() < |
| 31 | + right.second.TotalLogProb(); | ||
| 29 | }) | 32 | }) |
| 30 | ->second; | 33 | ->second; |
| 31 | } else { | 34 | } else { |
| @@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { | @@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { | ||
| 33 | return std::max_element( | 36 | return std::max_element( |
| 34 | hyps_dict_.begin(), hyps_dict_.end(), | 37 | hyps_dict_.begin(), hyps_dict_.end(), |
| 35 | [](const auto &left, const auto &right) -> bool { | 38 | [](const auto &left, const auto &right) -> bool { |
| 36 | - return left.second.log_prob / left.second.ys.size() < | ||
| 37 | - right.second.log_prob / right.second.ys.size(); | 39 | + return left.second.TotalLogProb() / left.second.ys.size() < |
| 40 | + right.second.TotalLogProb() / right.second.ys.size(); | ||
| 38 | }) | 41 | }) |
| 39 | ->second; | 42 | ->second; |
| 40 | } | 43 | } |
| @@ -47,15 +50,16 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const { | @@ -47,15 +50,16 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const { | ||
| 47 | std::vector<Hypothesis> all_hyps = Vec(); | 50 | std::vector<Hypothesis> all_hyps = Vec(); |
| 48 | 51 | ||
| 49 | if (length_norm == false) { | 52 | if (length_norm == false) { |
| 50 | - std::partial_sort( | ||
| 51 | - all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), | ||
| 52 | - [](const auto &a, const auto &b) { return a.log_prob > b.log_prob; }); | 53 | + std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), |
| 54 | + [](const auto &a, const auto &b) { | ||
| 55 | + return a.TotalLogProb() > b.TotalLogProb(); | ||
| 56 | + }); | ||
| 53 | } else { | 57 | } else { |
| 54 | // for length_norm is true | 58 | // for length_norm is true |
| 55 | std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), | 59 | std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), |
| 56 | [](const auto &a, const auto &b) { | 60 | [](const auto &a, const auto &b) { |
| 57 | - return a.log_prob / a.ys.size() > | ||
| 58 | - b.log_prob / b.ys.size(); | 61 | + return a.TotalLogProb() / a.ys.size() > |
| 62 | + b.TotalLogProb() / b.ys.size(); | ||
| 59 | }); | 63 | }); |
| 60 | } | 64 | } |
| 61 | 65 |
| @@ -25,14 +25,20 @@ struct Hypothesis { | @@ -25,14 +25,20 @@ struct Hypothesis { | ||
| 25 | std::vector<int32_t> timestamps; | 25 | std::vector<int32_t> timestamps; |
| 26 | 26 | ||
| 27 | // The total score of ys in log space. | 27 | // The total score of ys in log space. |
| 28 | + // It contains only acoustic scores | ||
| 28 | double log_prob = 0; | 29 | double log_prob = 0; |
| 29 | 30 | ||
| 31 | + // LM log prob if any. | ||
| 32 | + double lm_log_prob = 0; | ||
| 33 | + | ||
| 30 | int32_t num_trailing_blanks = 0; | 34 | int32_t num_trailing_blanks = 0; |
| 31 | 35 | ||
| 32 | Hypothesis() = default; | 36 | Hypothesis() = default; |
| 33 | Hypothesis(const std::vector<int64_t> &ys, double log_prob) | 37 | Hypothesis(const std::vector<int64_t> &ys, double log_prob) |
| 34 | : ys(ys), log_prob(log_prob) {} | 38 | : ys(ys), log_prob(log_prob) {} |
| 35 | 39 | ||
| 40 | + double TotalLogProb() const { return log_prob + lm_log_prob; } | ||
| 41 | + | ||
| 36 | // If two Hypotheses have the same `Key`, then they contain | 42 | // If two Hypotheses have the same `Key`, then they contain |
| 37 | // the same token sequence. | 43 | // the same token sequence. |
| 38 | std::string Key() const { | 44 | std::string Key() const { |
| @@ -94,6 +100,9 @@ class Hypotheses { | @@ -94,6 +100,9 @@ class Hypotheses { | ||
| 94 | const auto begin() const { return hyps_dict_.begin(); } | 100 | const auto begin() const { return hyps_dict_.begin(); } |
| 95 | const auto end() const { return hyps_dict_.end(); } | 101 | const auto end() const { return hyps_dict_.end(); } |
| 96 | 102 | ||
| 103 | + auto begin() { return hyps_dict_.begin(); } | ||
| 104 | + auto end() { return hyps_dict_.end(); } | ||
| 105 | + | ||
| 97 | void Clear() { hyps_dict_.clear(); } | 106 | void Clear() { hyps_dict_.clear(); } |
| 98 | 107 | ||
| 99 | private: | 108 | private: |
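These two hunks are the heart of the rescoring support: each `Hypothesis` now carries an acoustic score (`log_prob`) and an optional LM score (`lm_log_prob`), and every ranking in `Hypotheses` compares `TotalLogProb()`, their sum. A standalone sketch (not the real class) of the length-normalized comparison that `GetMostProbable(true)` performs:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for sherpa_onnx::Hypothesis, just to show the ranking rule.
struct Hyp {
  std::vector<int64_t> ys;  // token sequence (prepended with blanks)
  double log_prob;          // acoustic score in log space
  double lm_log_prob;       // LM score filled in by rescoring (0 without LM)
  double TotalLogProb() const { return log_prob + lm_log_prob; }
};

int main() {
  std::vector<Hyp> hyps = {
      {{0, 0, 5, 9}, -3.2, -1.0},     // shorter path, total -4.2
      {{0, 0, 5, 9, 7}, -3.0, -2.5},  // longer path, total -5.5
  };
  auto best = std::max_element(
      hyps.begin(), hyps.end(), [](const Hyp &a, const Hyp &b) {
        return a.TotalLogProb() / a.ys.size() < b.TotalLogProb() / b.ys.size();
      });
  std::printf("best hypothesis has %zu tokens\n", best->ys.size());  // 4 here
}
```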
| @@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) { | @@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) { | ||
| 88 | } | 88 | } |
| 89 | } | 89 | } |
| 90 | 90 | ||
| 91 | +template <typename T> | ||
| 92 | +void LogSoftmax(T *in, int32_t w, int32_t h) { | ||
| 93 | + for (int32_t i = 0; i != h; ++i) { | ||
| 94 | + LogSoftmax(in, w); | ||
| 95 | + in += w; | ||
| 96 | + } | ||
| 97 | +} | ||
| 98 | + | ||
| 99 | +// TODO(fangjun): use std::partial_sort to replace std::sort. | ||
| 100 | +// Remember also to fix sherpa-ncnn | ||
| 91 | template <class T> | 101 | template <class T> |
| 92 | std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | 102 | std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { |
| 93 | std::vector<int32_t> vec_index(size); | 103 | std::vector<int32_t> vec_index(size); |
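The new two-argument `LogSoftmax(in, w, h)` simply applies the existing 1-D `LogSoftmax` to each of the `h` rows of a row-major `h x w` buffer; the offline beam-search decoder uses it to normalize the joiner logits of all active hypotheses in one call. For reference, a self-contained, numerically stable version of the same per-row operation (independent of this repository's `math.h`):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Row-wise log-softmax over a row-major h x w buffer, written from scratch for
// illustration; the PR itself just reuses the existing 1-D LogSoftmax per row.
void RowwiseLogSoftmax(float *in, int32_t w, int32_t h) {
  for (int32_t i = 0; i != h; ++i, in += w) {
    float max_val = *std::max_element(in, in + w);
    float sum = 0.0f;
    for (int32_t k = 0; k != w; ++k) sum += std::exp(in[k] - max_val);
    float offset = max_val + std::log(sum);
    for (int32_t k = 0; k != w; ++k) in[k] -= offset;
  }
}
```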
sherpa-onnx/csrc/offline-lm-config.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-lm-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void OfflineLMConfig::Register(ParseOptions *po) { | ||
| 15 | + po->Register("lm", &model, "Path to LM model."); | ||
| 16 | + po->Register("lm-scale", &scale, "LM scale."); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +bool OfflineLMConfig::Validate() const { | ||
| 20 | + if (!FileExists(model)) { | ||
| 21 | + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); | ||
| 22 | + return false; | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + return true; | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +std::string OfflineLMConfig::ToString() const { | ||
| 29 | + std::ostringstream os; | ||
| 30 | + | ||
| 31 | + os << "OfflineLMConfig("; | ||
| 32 | + os << "model=\"" << model << "\", "; | ||
| 33 | + os << "scale=" << scale << ")"; | ||
| 34 | + | ||
| 35 | + return os.str(); | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/offline-lm-config.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-lm-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OfflineLMConfig { | ||
| 14 | + // path to the onnx model | ||
| 15 | + std::string model; | ||
| 16 | + | ||
| 17 | + // LM scale | ||
| 18 | + float scale = 1.0; | ||
| 19 | + | ||
| 20 | + OfflineLMConfig() = default; | ||
| 21 | + | ||
| 22 | + OfflineLMConfig(const std::string &model, float scale) | ||
| 23 | + : model(model), scale(scale) {} | ||
| 24 | + | ||
| 25 | + void Register(ParseOptions *po); | ||
| 26 | + bool Validate() const; | ||
| 27 | + | ||
| 28 | + std::string ToString() const; | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ |
sherpa-onnx/csrc/offline-lm.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-lm.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-lm.h" | ||
| 6 | + | ||
| 7 | +#include <algorithm> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/offline-rnn-lm.h" | ||
| 12 | + | ||
| 13 | +namespace sherpa_onnx { | ||
| 14 | + | ||
| 15 | +std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) { | ||
| 16 | + return std::make_unique<OfflineRnnLM>(config); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +void OfflineLM::ComputeLMScore(float scale, int32_t context_size, | ||
| 20 | + std::vector<Hypotheses> *hyps) { | ||
| 21 | + // compute the max token seq so that we know how much space to allocate | ||
| 22 | + int32_t max_token_seq = 0; | ||
| 23 | + int32_t num_hyps = 0; | ||
| 24 | + | ||
| 25 | + // we subtract context_size below since each token sequence is prepended | ||
| 26 | + // with context_size blanks | ||
| 27 | + for (const auto &h : *hyps) { | ||
| 28 | + num_hyps += h.Size(); | ||
| 29 | + for (const auto &t : h) { | ||
| 30 | + max_token_seq = | ||
| 31 | + std::max<int32_t>(max_token_seq, t.second.ys.size() - context_size); | ||
| 32 | + } | ||
| 33 | + } | ||
| 34 | + | ||
| 35 | + Ort::AllocatorWithDefaultOptions allocator; | ||
| 36 | + std::array<int64_t, 2> x_shape{num_hyps, max_token_seq}; | ||
| 37 | + Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator, x_shape.data(), | ||
| 38 | + x_shape.size()); | ||
| 39 | + | ||
| 40 | + std::array<int64_t, 1> x_lens_shape{num_hyps}; | ||
| 41 | + Ort::Value x_lens = Ort::Value::CreateTensor<int64_t>( | ||
| 42 | + allocator, x_lens_shape.data(), x_lens_shape.size()); | ||
| 43 | + | ||
| 44 | + int64_t *p = x.GetTensorMutableData<int64_t>(); | ||
| 45 | + std::fill(p, p + num_hyps * max_token_seq, 0); | ||
| 46 | + | ||
| 47 | + int64_t *p_lens = x_lens.GetTensorMutableData<int64_t>(); | ||
| 48 | + | ||
| 49 | + for (const auto &h : *hyps) { | ||
| 50 | + for (const auto &t : h) { | ||
| 51 | + const auto &ys = t.second.ys; | ||
| 52 | + int32_t len = ys.size() - context_size; | ||
| 53 | + std::copy(ys.begin() + context_size, ys.end(), p); | ||
| 54 | + *p_lens = len; | ||
| 55 | + | ||
| 56 | + p += max_token_seq; | ||
| 57 | + ++p_lens; | ||
| 58 | + } | ||
| 59 | + } | ||
| 60 | + auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); | ||
| 61 | + const float *p_nll = negative_loglike.GetTensorData<float>(); | ||
| 62 | + for (auto &h : *hyps) { | ||
| 63 | + for (auto &t : h) { | ||
| 64 | + // Use -scale here since we want to change negative loglike to loglike. | ||
| 65 | + t.second.lm_log_prob = -scale * (*p_nll); | ||
| 66 | + ++p_nll; | ||
| 67 | + } | ||
| 68 | + } | ||
| 69 | +} | ||
| 70 | + | ||
| 71 | +} // namespace sherpa_onnx |
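`ComputeLMScore` batches every hypothesis of every utterance into a single LM call: the leading `context_size` blanks are stripped from each `ys`, the remaining tokens are copied into one zero-padded `(num_hyps, max_token_seq)` int64 matrix together with a length vector, and each returned negative log likelihood is folded back as `lm_log_prob = -scale * nll`. A plain-vector illustration of that padding layout (made-up token IDs, no Ort involved):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustration of the layout built in ComputeLMScore, using plain vectors
// instead of Ort::Value. Token IDs are made up.
int main() {
  int32_t context_size = 2;  // each hyp.ys starts with 2 blanks (token 0)
  std::vector<std::vector<int64_t>> ys = {
      {0, 0, 7, 12, 5},  // hyp 0 of utterance 0
      {0, 0, 7, 12},     // hyp 1 of utterance 0
      {0, 0, 9},         // only hyp of utterance 1
  };

  int32_t num_hyps = ys.size();
  int32_t max_token_seq = 0;
  for (const auto &y : ys) {
    max_token_seq = std::max<int32_t>(max_token_seq, y.size() - context_size);
  }

  // x is (num_hyps, max_token_seq), zero padded; x_lens holds valid lengths.
  std::vector<int64_t> x(num_hyps * max_token_seq, 0);
  std::vector<int64_t> x_lens(num_hyps);
  for (int32_t i = 0; i != num_hyps; ++i) {
    x_lens[i] = ys[i].size() - context_size;
    std::copy(ys[i].begin() + context_size, ys[i].end(),
              x.begin() + i * max_token_seq);
  }
  // x is now {7,12,5,  7,12,0,  9,0,0} and x_lens is {3, 2, 1}.
  // Rescore(x, x_lens) returns one nll per row, and
  // lm_log_prob = -scale * nll is written back to each hypothesis.
}
```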
sherpa-onnx/csrc/offline-lm.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-lm.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 13 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 14 | + | ||
| 15 | +namespace sherpa_onnx { | ||
| 16 | + | ||
| 17 | +class OfflineLM { | ||
| 18 | + public: | ||
| 19 | + virtual ~OfflineLM() = default; | ||
| 20 | + | ||
| 21 | + static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config); | ||
| 22 | + | ||
| 23 | + /** Rescore a batch of sentences. | ||
| 24 | + * | ||
| 25 | + * @param x A 2-D tensor of shape (N, L) with data type int64. | ||
| 26 | + * @param x_lens A 1-D tensor of shape (N,) with data type int64. | ||
| 27 | + * It contains number of valid tokens in x before padding. | ||
| 28 | + * @return Return a 1-D tensor of shape (N,) containing the negative log | ||
| 29 | + * likelihood of each utterance. Its data type is float32. | ||
| 30 | + * | ||
| 31 | + * Caution: It returns negative log likelihood (nll), not log likelihood | ||
| 32 | + */ | ||
| 33 | + virtual Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) = 0; | ||
| 34 | + | ||
| 35 | +  // This function updates hyp.lm_log_prob of hyps. | ||
| 36 | + // | ||
| 37 | +  // @param scale LM scale | ||
| 38 | + // @param context_size Context size of the transducer decoder model | ||
| 39 | + // @param hyps It is changed in-place. | ||
| 40 | + void ComputeLMScore(float scale, int32_t context_size, | ||
| 41 | + std::vector<Hypotheses> *hyps); | ||
| 42 | +}; | ||
| 43 | + | ||
| 44 | +} // namespace sherpa_onnx | ||
| 45 | + | ||
| 46 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_H_ |
| @@ -16,6 +16,7 @@ | @@ -16,6 +16,7 @@ | ||
| 16 | #include "sherpa-onnx/csrc/offline-transducer-decoder.h" | 16 | #include "sherpa-onnx/csrc/offline-transducer-decoder.h" |
| 17 | #include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" | 17 | #include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" |
| 18 | #include "sherpa-onnx/csrc/offline-transducer-model.h" | 18 | #include "sherpa-onnx/csrc/offline-transducer-model.h" |
| 19 | +#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" | ||
| 19 | #include "sherpa-onnx/csrc/pad-sequence.h" | 20 | #include "sherpa-onnx/csrc/pad-sequence.h" |
| 20 | #include "sherpa-onnx/csrc/symbol-table.h" | 21 | #include "sherpa-onnx/csrc/symbol-table.h" |
| 21 | 22 | ||
| @@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 57 | decoder_ = | 58 | decoder_ = |
| 58 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); | 59 | std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get()); |
| 59 | } else if (config_.decoding_method == "modified_beam_search") { | 60 | } else if (config_.decoding_method == "modified_beam_search") { |
| 60 | - SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented"); | ||
| 61 | - exit(-1); | 61 | + if (!config_.lm_config.model.empty()) { |
| 62 | + lm_ = OfflineLM::Create(config.lm_config); | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>( | ||
| 66 | + model_.get(), lm_.get(), config_.max_active_paths, | ||
| 67 | + config_.lm_config.scale); | ||
| 62 | } else { | 68 | } else { |
| 63 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", | 69 | SHERPA_ONNX_LOGE("Unsupported decoding method: %s", |
| 64 | config_.decoding_method.c_str()); | 70 | config_.decoding_method.c_str()); |
| @@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | @@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { | ||
| 127 | SymbolTable symbol_table_; | 133 | SymbolTable symbol_table_; |
| 128 | std::unique_ptr<OfflineTransducerModel> model_; | 134 | std::unique_ptr<OfflineTransducerModel> model_; |
| 129 | std::unique_ptr<OfflineTransducerDecoder> decoder_; | 135 | std::unique_ptr<OfflineTransducerDecoder> decoder_; |
| 136 | + std::unique_ptr<OfflineLM> lm_; | ||
| 130 | }; | 137 | }; |
| 131 | 138 | ||
| 132 | } // namespace sherpa_onnx | 139 | } // namespace sherpa_onnx |
| @@ -8,6 +8,7 @@ | @@ -8,6 +8,7 @@ | ||
| 8 | 8 | ||
| 9 | #include "sherpa-onnx/csrc/file-utils.h" | 9 | #include "sherpa-onnx/csrc/file-utils.h" |
| 10 | #include "sherpa-onnx/csrc/macros.h" | 10 | #include "sherpa-onnx/csrc/macros.h" |
| 11 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 11 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" | 12 | #include "sherpa-onnx/csrc/offline-recognizer-impl.h" |
| 12 | 13 | ||
| 13 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| @@ -15,13 +16,28 @@ namespace sherpa_onnx { | @@ -15,13 +16,28 @@ namespace sherpa_onnx { | ||
| 15 | void OfflineRecognizerConfig::Register(ParseOptions *po) { | 16 | void OfflineRecognizerConfig::Register(ParseOptions *po) { |
| 16 | feat_config.Register(po); | 17 | feat_config.Register(po); |
| 17 | model_config.Register(po); | 18 | model_config.Register(po); |
| 19 | + lm_config.Register(po); | ||
| 18 | 20 | ||
| 19 | - po->Register("decoding-method", &decoding_method, | ||
| 20 | - "decoding method," | ||
| 21 | - "Valid values: greedy_search."); | 21 | + po->Register( |
| 22 | + "decoding-method", &decoding_method, | ||
| 23 | + "decoding method," | ||
| 24 | + "Valid values: greedy_search, modified_beam_search. " | ||
| 25 | + "modified_beam_search is applicable only for transducer models."); | ||
| 26 | + | ||
| 27 | + po->Register("max-active-paths", &max_active_paths, | ||
| 28 | + "Used only when decoding_method is modified_beam_search"); | ||
| 22 | } | 29 | } |
| 23 | 30 | ||
| 24 | bool OfflineRecognizerConfig::Validate() const { | 31 | bool OfflineRecognizerConfig::Validate() const { |
| 32 | + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { | ||
| 33 | + if (max_active_paths <= 0) { | ||
| 34 | + SHERPA_ONNX_LOGE("max_active_paths must be positive! Given: %d", | ||
| 35 | + max_active_paths); | ||
| 36 | + return false; | ||
| 37 | + } | ||
| 38 | + if (!lm_config.Validate()) return false; | ||
| 39 | + } | ||
| 40 | + | ||
| 25 | return model_config.Validate(); | 41 | return model_config.Validate(); |
| 26 | } | 42 | } |
| 27 | 43 | ||
| @@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const { | @@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const { | ||
| 31 | os << "OfflineRecognizerConfig("; | 47 | os << "OfflineRecognizerConfig("; |
| 32 | os << "feat_config=" << feat_config.ToString() << ", "; | 48 | os << "feat_config=" << feat_config.ToString() << ", "; |
| 33 | os << "model_config=" << model_config.ToString() << ", "; | 49 | os << "model_config=" << model_config.ToString() << ", "; |
| 34 | - os << "decoding_method=\"" << decoding_method << "\")"; | 50 | + os << "lm_config=" << lm_config.ToString() << ", "; |
| 51 | + os << "decoding_method=\"" << decoding_method << "\", "; | ||
| 52 | + os << "max_active_paths=" << max_active_paths << ")"; | ||
| 35 | 53 | ||
| 36 | return os.str(); | 54 | return os.str(); |
| 37 | } | 55 | } |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <string> | 9 | #include <string> |
| 10 | #include <vector> | 10 | #include <vector> |
| 11 | 11 | ||
| 12 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 12 | #include "sherpa-onnx/csrc/offline-model-config.h" | 13 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 13 | #include "sherpa-onnx/csrc/offline-stream.h" | 14 | #include "sherpa-onnx/csrc/offline-stream.h" |
| 14 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" | 15 | #include "sherpa-onnx/csrc/offline-transducer-model-config.h" |
| @@ -21,18 +22,24 @@ struct OfflineRecognitionResult; | @@ -21,18 +22,24 @@ struct OfflineRecognitionResult; | ||
| 21 | struct OfflineRecognizerConfig { | 22 | struct OfflineRecognizerConfig { |
| 22 | OfflineFeatureExtractorConfig feat_config; | 23 | OfflineFeatureExtractorConfig feat_config; |
| 23 | OfflineModelConfig model_config; | 24 | OfflineModelConfig model_config; |
| 25 | + OfflineLMConfig lm_config; | ||
| 24 | 26 | ||
| 25 | std::string decoding_method = "greedy_search"; | 27 | std::string decoding_method = "greedy_search"; |
| 28 | + int32_t max_active_paths = 4; | ||
| 26 | // only greedy_search is implemented | 29 | // only greedy_search is implemented |
| 27 | // TODO(fangjun): Implement modified_beam_search | 30 | // TODO(fangjun): Implement modified_beam_search |
| 28 | 31 | ||
| 29 | OfflineRecognizerConfig() = default; | 32 | OfflineRecognizerConfig() = default; |
| 30 | OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, | 33 | OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, |
| 31 | const OfflineModelConfig &model_config, | 34 | const OfflineModelConfig &model_config, |
| 32 | - const std::string &decoding_method) | 35 | + const OfflineLMConfig &lm_config, |
| 36 | + const std::string &decoding_method, | ||
| 37 | + int32_t max_active_paths) | ||
| 33 | : feat_config(feat_config), | 38 | : feat_config(feat_config), |
| 34 | model_config(model_config), | 39 | model_config(model_config), |
| 35 | - decoding_method(decoding_method) {} | 40 | + lm_config(lm_config), |
| 41 | + decoding_method(decoding_method), | ||
| 42 | + max_active_paths(max_active_paths) {} | ||
| 36 | 43 | ||
| 37 | void Register(ParseOptions *po); | 44 | void Register(ParseOptions *po); |
| 38 | bool Validate() const; | 45 | bool Validate() const; |
sherpa-onnx/csrc/offline-rnn-lm.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-rnn-lm.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-rnn-lm.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 12 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 13 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 14 | +#include "sherpa-onnx/csrc/text-utils.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class OfflineRnnLM::Impl { | ||
| 19 | + public: | ||
| 20 | + explicit Impl(const OfflineLMConfig &config) | ||
| 21 | + : config_(config), | ||
| 22 | + env_(ORT_LOGGING_LEVEL_ERROR), | ||
| 23 | + sess_opts_{}, | ||
| 24 | + allocator_{} { | ||
| 25 | + Init(config); | ||
| 26 | + } | ||
| 27 | + | ||
| 28 | + Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) { | ||
| 29 | + std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)}; | ||
| 30 | + | ||
| 31 | + auto out = | ||
| 32 | + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | ||
| 33 | + output_names_ptr_.data(), output_names_ptr_.size()); | ||
| 34 | + | ||
| 35 | + return std::move(out[0]); | ||
| 36 | + } | ||
| 37 | + | ||
| 38 | + private: | ||
| 39 | + void Init(const OfflineLMConfig &config) { | ||
| 40 | + auto buf = ReadFile(config_.model); | ||
| 41 | + | ||
| 42 | + sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(), | ||
| 43 | + sess_opts_); | ||
| 44 | + | ||
| 45 | + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); | ||
| 46 | + | ||
| 47 | + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + private: | ||
| 51 | + OfflineLMConfig config_; | ||
| 52 | + Ort::Env env_; | ||
| 53 | + Ort::SessionOptions sess_opts_; | ||
| 54 | + Ort::AllocatorWithDefaultOptions allocator_; | ||
| 55 | + | ||
| 56 | + std::unique_ptr<Ort::Session> sess_; | ||
| 57 | + | ||
| 58 | + std::vector<std::string> input_names_; | ||
| 59 | + std::vector<const char *> input_names_ptr_; | ||
| 60 | + | ||
| 61 | + std::vector<std::string> output_names_; | ||
| 62 | + std::vector<const char *> output_names_ptr_; | ||
| 63 | +}; | ||
| 64 | + | ||
| 65 | +OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) | ||
| 66 | + : impl_(std::make_unique<Impl>(config)) {} | ||
| 67 | + | ||
| 68 | +OfflineRnnLM::~OfflineRnnLM() = default; | ||
| 69 | + | ||
| 70 | +Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) { | ||
| 71 | + return impl_->Rescore(std::move(x), std::move(x_lens)); | ||
| 72 | +} | ||
| 73 | + | ||
| 74 | +} // namespace sherpa_onnx |
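`OfflineRnnLM` is a thin pimpl wrapper around an `Ort::Session`: it loads the model bytes with `ReadFile`, caches the input/output names once, and forwards `(x, x_lens)` straight to `Run()`. A hedged sketch of calling it directly, building the two int64 inputs the same way `ComputeLMScore` does (the model path is hypothetical):

```cpp
#include <array>
#include <cstdint>
#include <utility>

#include "onnxruntime_cxx_api.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-rnn-lm.h"

// Sketch only: score a single token sequence with the new OfflineRnnLM.
// "./rnn-lm.onnx" is a placeholder path.
float ScoreOneSentence() {
  sherpa_onnx::OfflineLMConfig config("./rnn-lm.onnx", /*scale=*/0.5);
  sherpa_onnx::OfflineRnnLM lm(config);

  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> x_shape{1, 4};  // one sentence, four tokens
  Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator, x_shape.data(),
                                                   x_shape.size());
  int64_t *p = x.GetTensorMutableData<int64_t>();
  p[0] = 7; p[1] = 12; p[2] = 5; p[3] = 3;  // made-up token IDs

  std::array<int64_t, 1> x_lens_shape{1};
  Ort::Value x_lens = Ort::Value::CreateTensor<int64_t>(
      allocator, x_lens_shape.data(), x_lens_shape.size());
  x_lens.GetTensorMutableData<int64_t>()[0] = 4;

  Ort::Value nll = lm.Rescore(std::move(x), std::move(x_lens));
  return nll.GetTensorData<float>()[0];  // negative log likelihood
}
```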
sherpa-onnx/csrc/offline-rnn-lm.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/offline-rnn-lm.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | + | ||
| 10 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 11 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-lm.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OfflineRnnLM : public OfflineLM { | ||
| 17 | + public: | ||
| 18 | + ~OfflineRnnLM() override; | ||
| 19 | + | ||
| 20 | + explicit OfflineRnnLM(const OfflineLMConfig &config); | ||
| 21 | + | ||
| 22 | + /** Rescore a batch of sentences. | ||
| 23 | + * | ||
| 24 | + * @param x A 2-D tensor of shape (N, L) with data type int64. | ||
| 25 | + * @param x_lens A 1-D tensor of shape (N,) with data type int64. | ||
| 26 | + * It contains number of valid tokens in x before padding. | ||
| 556 | + * @return Return a 1-D tensor of shape (N,) containing the negative log | ||
| 557 | + * likelihood of each utterance. Its data type is float32. | ||
| 558 | + * | ||
| 559 | + * Caution: It returns negative log likelihood (nll), matching OfflineLM::Rescore. | ||
| 31 | + */ | ||
| 32 | + Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) override; | ||
| 33 | + | ||
| 34 | + private: | ||
| 35 | + class Impl; | ||
| 36 | + std::unique_ptr<Impl> impl_; | ||
| 37 | +}; | ||
| 38 | + | ||
| 39 | +} // namespace sherpa_onnx | ||
| 40 | + | ||
| 41 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ |
| @@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl { | @@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl { | ||
| 95 | std::copy(begin, end, p); | 95 | std::copy(begin, end, p); |
| 96 | p += context_size; | 96 | p += context_size; |
| 97 | } | 97 | } |
| 98 | + | ||
| 99 | + return decoder_input; | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results, | ||
| 103 | + int32_t end_index) const { | ||
| 104 | + assert(end_index <= results.size()); | ||
| 105 | + | ||
| 106 | + int32_t batch_size = end_index; | ||
| 107 | + int32_t context_size = ContextSize(); | ||
| 108 | + std::array<int64_t, 2> shape{batch_size, context_size}; | ||
| 109 | + | ||
| 110 | + Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>( | ||
| 111 | + Allocator(), shape.data(), shape.size()); | ||
| 112 | + int64_t *p = decoder_input.GetTensorMutableData<int64_t>(); | ||
| 113 | + | ||
| 114 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 115 | + const auto &r = results[i]; | ||
| 116 | + const int64_t *begin = r.ys.data() + r.ys.size() - context_size; | ||
| 117 | + const int64_t *end = r.ys.data() + r.ys.size(); | ||
| 118 | + std::copy(begin, end, p); | ||
| 119 | + p += context_size; | ||
| 120 | + } | ||
| 121 | + | ||
| 98 | return decoder_input; | 122 | return decoder_input; |
| 99 | } | 123 | } |
| 100 | 124 | ||
| @@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput( | @@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput( | ||
| 234 | return impl_->BuildDecoderInput(results, end_index); | 258 | return impl_->BuildDecoderInput(results, end_index); |
| 235 | } | 259 | } |
| 236 | 260 | ||
| 261 | +Ort::Value OfflineTransducerModel::BuildDecoderInput( | ||
| 262 | + const std::vector<Hypothesis> &results, int32_t end_index) const { | ||
| 263 | + return impl_->BuildDecoderInput(results, end_index); | ||
| 264 | +} | ||
| 265 | + | ||
| 237 | } // namespace sherpa_onnx | 266 | } // namespace sherpa_onnx |
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | #include <vector> | 9 | #include <vector> |
| 10 | 10 | ||
| 11 | #include "onnxruntime_cxx_api.h" // NOLINT | 11 | #include "onnxruntime_cxx_api.h" // NOLINT |
| 12 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 12 | #include "sherpa-onnx/csrc/offline-model-config.h" | 13 | #include "sherpa-onnx/csrc/offline-model-config.h" |
| 13 | 14 | ||
| 14 | namespace sherpa_onnx { | 15 | namespace sherpa_onnx { |
| @@ -79,13 +80,16 @@ class OfflineTransducerModel { | @@ -79,13 +80,16 @@ class OfflineTransducerModel { | ||
| 79 | * | 80 | * |
| 80 | * @param results Current decoded results. | 81 | * @param results Current decoded results. |
| 81 | * @param end_index We only use results[0:end_index] to build | 82 | * @param end_index We only use results[0:end_index] to build |
| 82 | - * the decoder_input. | 83 | + * the decoder_input. results[end_index] is not used. |
| 83 | * @return Return a tensor of shape (results.size(), ContextSize()) | 84 | * @return Return a tensor of shape (results.size(), ContextSize()) |
| 84 | */ | 85 | */ |
| 85 | Ort::Value BuildDecoderInput( | 86 | Ort::Value BuildDecoderInput( |
| 86 | const std::vector<OfflineTransducerDecoderResult> &results, | 87 | const std::vector<OfflineTransducerDecoderResult> &results, |
| 87 | int32_t end_index) const; | 88 | int32_t end_index) const; |
| 88 | 89 | ||
| 90 | + Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results, | ||
| 91 | + int32_t end_index) const; | ||
| 92 | + | ||
| 89 | private: | 93 | private: |
| 90 | class Impl; | 94 | class Impl; |
| 91 | std::unique_ptr<Impl> impl_; | 95 | std::unique_ptr<Impl> impl_; |
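The new `BuildDecoderInput` overload mirrors the existing one for `OfflineTransducerDecoderResult`, but reads from `Hypothesis`: row `i` of the returned `(end_index, ContextSize())` int64 tensor holds the last `ContextSize()` tokens of `results[i].ys`, which is exactly what the stateless transducer decoder consumes. The same copy with plain vectors, for clarity:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Standalone illustration of the new BuildDecoderInput overload: row i of the
// (end_index, context_size) matrix holds the last context_size tokens of
// hypothesis i. Plain vectors stand in for Ort::Value here.
std::vector<int64_t> BuildDecoderInputSketch(
    const std::vector<std::vector<int64_t>> &ys_list, int32_t end_index,
    int32_t context_size) {
  std::vector<int64_t> decoder_input(end_index * context_size);
  int64_t *p = decoder_input.data();
  for (int32_t i = 0; i != end_index; ++i) {
    const auto &ys = ys_list[i];
    std::copy(ys.end() - context_size, ys.end(), p);
    p += context_size;
  }
  return decoder_input;
}
// e.g. ys_list[0] = {0, 0, 7, 12, 5}, context_size = 2  ->  row 0 is {12, 5}
```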
| 1 | +// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h" | ||
| 6 | + | ||
| 7 | +#include <deque> | ||
| 8 | +#include <utility> | ||
| 9 | +#include <vector> | ||
| 10 | + | ||
| 11 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 12 | +#include "sherpa-onnx/csrc/onnx-utils.h" | ||
| 13 | +#include "sherpa-onnx/csrc/packed-sequence.h" | ||
| 14 | +#include "sherpa-onnx/csrc/slice.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +static std::vector<int32_t> GetHypsRowSplits( | ||
| 19 | + const std::vector<Hypotheses> &hyps) { | ||
| 20 | + std::vector<int32_t> row_splits; | ||
| 21 | + row_splits.reserve(hyps.size() + 1); | ||
| 22 | + | ||
| 23 | + row_splits.push_back(0); | ||
| 24 | + int32_t s = 0; | ||
| 25 | + for (const auto &h : hyps) { | ||
| 26 | + s += h.Size(); | ||
| 27 | + row_splits.push_back(s); | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + return row_splits; | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +std::vector<OfflineTransducerDecoderResult> | ||
| 34 | +OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 35 | + Ort::Value encoder_out, Ort::Value encoder_out_length) { | ||
| 36 | + PackedSequence packed_encoder_out = PackPaddedSequence( | ||
| 37 | + model_->Allocator(), &encoder_out, &encoder_out_length); | ||
| 38 | + | ||
| 39 | + int32_t batch_size = | ||
| 40 | + static_cast<int32_t>(packed_encoder_out.sorted_indexes.size()); | ||
| 41 | + | ||
| 42 | + int32_t vocab_size = model_->VocabSize(); | ||
| 43 | + int32_t context_size = model_->ContextSize(); | ||
| 44 | + | ||
| 45 | + std::vector<int64_t> blanks(context_size, 0); | ||
| 46 | + Hypotheses blank_hyp({{blanks, 0}}); | ||
| 47 | + | ||
| 48 | + std::deque<Hypotheses> finalized; | ||
| 49 | + std::vector<Hypotheses> cur(batch_size, blank_hyp); | ||
| 50 | + std::vector<Hypothesis> prev; | ||
| 51 | + | ||
| 52 | + int32_t start = 0; | ||
| 53 | + int32_t t = 0; | ||
| 54 | + for (auto n : packed_encoder_out.batch_sizes) { | ||
| 55 | + Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n); | ||
| 56 | + start += n; | ||
| 57 | + | ||
| 58 | + if (n < static_cast<int32_t>(cur.size())) { | ||
| 59 | + for (int32_t k = static_cast<int32_t>(cur.size()) - 1; k >= n; --k) { | ||
| 60 | + finalized.push_front(std::move(cur[k])); | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + cur.erase(cur.begin() + n, cur.end()); | ||
| 64 | + } // if (n < static_cast<int32_t>(cur.size())) | ||
| 65 | + | ||
| 66 | + // Due to merging paths with identical token sequences, | ||
| 67 | + // not all utterances have "max_active_paths" paths. | ||
| 68 | + auto hyps_row_splits = GetHypsRowSplits(cur); | ||
| 69 | + int32_t num_hyps = hyps_row_splits.back(); | ||
| 70 | + | ||
| 71 | + prev.clear(); | ||
| 72 | + prev.reserve(num_hyps); | ||
| 73 | + | ||
| 74 | + for (auto &hyps : cur) { | ||
| 75 | + for (auto &h : hyps) { | ||
| 76 | + prev.push_back(std::move(h.second)); | ||
| 77 | + } | ||
| 78 | + } | ||
| 79 | + cur.clear(); | ||
| 80 | + cur.reserve(n); | ||
| 81 | + | ||
| 82 | + auto decoder_input = model_->BuildDecoderInput(prev, num_hyps); | ||
| 83 | + // decoder_input shape: (num_hyps, context_size) | ||
| 84 | + | ||
| 85 | + auto decoder_out = model_->RunDecoder(std::move(decoder_input)); | ||
| 86 | + // decoder_out is (num_hyps, joiner_dim) | ||
| 87 | + | ||
| 88 | + cur_encoder_out = | ||
| 89 | + Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); | ||
| 90 | + // now cur_encoder_out is of shape (num_hyps, joiner_dim) | ||
| 91 | + | ||
| 92 | + Ort::Value logit = model_->RunJoiner( | ||
| 93 | + std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | ||
| 94 | + | ||
| 95 | + float *p_logit = logit.GetTensorMutableData<float>(); | ||
| 96 | + LogSoftmax(p_logit, vocab_size, num_hyps); | ||
| 97 | + | ||
| 98 | + // now p_logit contains log_softmax output, we rename it to p_logprob | ||
| 99 | + // to match what it actually contains | ||
| 100 | + float *p_logprob = p_logit; | ||
| 101 | + | ||
| 102 | + // add log_prob of each hypothesis to p_logprob before taking top_k | ||
| 103 | + for (int32_t i = 0; i != num_hyps; ++i) { | ||
| 104 | + float log_prob = prev[i].log_prob; | ||
| 105 | + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { | ||
| 106 | + *p_logprob += log_prob; | ||
| 107 | + } | ||
| 108 | + } | ||
| 109 | + p_logprob = p_logit; // we changed p_logprob in the above for loop | ||
| 110 | + | ||
| 111 | + // Now compute top_k for each utterance | ||
| 112 | + for (int32_t i = 0; i != n; ++i) { | ||
| 113 | + int32_t start = hyps_row_splits[i]; | ||
| 114 | + int32_t end = hyps_row_splits[i + 1]; | ||
| 115 | + auto topk = | ||
| 116 | + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); | ||
| 117 | + | ||
| 118 | + Hypotheses hyps; | ||
| 119 | + for (auto k : topk) { | ||
| 120 | + int32_t hyp_index = k / vocab_size + start; | ||
| 121 | + int32_t new_token = k % vocab_size; | ||
| 122 | + Hypothesis new_hyp = prev[hyp_index]; | ||
| 123 | + | ||
| 124 | + if (new_token != 0) { | ||
| 125 | + // blank id is fixed to 0 | ||
| 126 | + new_hyp.ys.push_back(new_token); | ||
| 127 | + new_hyp.timestamps.push_back(t); | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + new_hyp.log_prob = p_logprob[k]; | ||
| 131 | + hyps.Add(std::move(new_hyp)); | ||
| 132 | + } // for (auto k : topk) | ||
| 133 | + p_logprob += (end - start) * vocab_size; | ||
| 134 | + cur.push_back(std::move(hyps)); | ||
| 135 | + } // for (int32_t i = 0; i != n; ++i) | ||
| 136 | + | ||
| 137 | + ++t; | ||
| 138 | + } // for (auto n : packed_encoder_out.batch_sizes) | ||
| 139 | + | ||
| 140 | + for (auto &h : finalized) { | ||
| 141 | + cur.push_back(std::move(h)); | ||
| 142 | + } | ||
| 143 | + | ||
| 144 | + if (lm_) { | ||
| 145 | + // use LM for rescoring | ||
| 146 | + lm_->ComputeLMScore(lm_scale_, context_size, &cur); | ||
| 147 | + } | ||
| 148 | + | ||
| 149 | + std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size); | ||
| 150 | + for (int32_t i = 0; i != batch_size; ++i) { | ||
| 151 | + Hypothesis hyp = cur[i].GetMostProbable(true); | ||
| 152 | + | ||
| 153 | + auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]]; | ||
| 154 | + | ||
| 155 | + // strip leading blanks | ||
| 156 | + r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()}; | ||
| 157 | + r.timestamps = std::move(hyp.timestamps); | ||
| 158 | + } | ||
| 159 | + | ||
| 160 | + return unsorted_ans; | ||
| 161 | +} | ||
| 162 | + | ||
| 163 | +} // namespace sherpa_onnx |
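For each frame, the decoder lays out the joiner log-probs of all hypotheses of one utterance as a contiguous `(num_hyps_i, vocab_size)` block, adds each hypothesis' running `log_prob` to its own row, and takes a single top-k over the flattened block. The flat index therefore encodes both the parent hypothesis and the emitted token, and a result of token 0 (blank) leaves `ys` unchanged. A tiny worked example of that index arithmetic:

```cpp
#include <cstdint>
#include <cstdio>

// Worked example of the arithmetic applied to each index from TopkIndex:
// with vocab_size = 500 and hyps_row_splits[i] = 3 for this utterance, a flat
// index k selects parent hypothesis k / vocab_size and token k % vocab_size.
int main() {
  int32_t vocab_size = 500;
  int32_t start = 3;   // hyps_row_splits[i]
  int32_t k = 1203;    // a flat index returned by TopkIndex

  int32_t hyp_index = k / vocab_size + start;  // = 2 + 3 = 5, index into prev
  int32_t new_token = k % vocab_size;          // = 203; 0 would mean blank
  std::printf("parent hyp %d, token %d\n", hyp_index, new_token);
}
```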
| 1 | +// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ | ||
| 7 | + | ||
| 8 | +#include <vector> | ||
| 9 | + | ||
| 10 | +#include "sherpa-onnx/csrc/offline-lm.h" | ||
| 11 | +#include "sherpa-onnx/csrc/offline-transducer-decoder.h" | ||
| 12 | +#include "sherpa-onnx/csrc/offline-transducer-model.h" | ||
| 13 | + | ||
| 14 | +namespace sherpa_onnx { | ||
| 15 | + | ||
| 16 | +class OfflineTransducerModifiedBeamSearchDecoder | ||
| 17 | + : public OfflineTransducerDecoder { | ||
| 18 | + public: | ||
| 19 | + OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, | ||
| 20 | + OfflineLM *lm, | ||
| 21 | + int32_t max_active_paths, | ||
| 22 | + float lm_scale) | ||
| 23 | + : model_(model), | ||
| 24 | + lm_(lm), | ||
| 25 | + max_active_paths_(max_active_paths), | ||
| 26 | + lm_scale_(lm_scale) {} | ||
| 27 | + | ||
| 28 | + std::vector<OfflineTransducerDecoderResult> Decode( | ||
| 29 | + Ort::Value encoder_out, Ort::Value encoder_out_length) override; | ||
| 30 | + | ||
| 31 | + private: | ||
| 32 | + OfflineTransducerModel *model_; // Not owned | ||
| 33 | + OfflineLM *lm_; // Not owned; may be nullptr | ||
| 34 | + | ||
| 35 | + int32_t max_active_paths_; | ||
| 36 | + float lm_scale_; // used only when lm_ is not nullptr | ||
| 37 | +}; | ||
| 38 | + | ||
| 39 | +} // namespace sherpa_onnx | ||
| 40 | + | ||
| 41 | +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ |
sherpa-onnx/csrc/online-lm-config.cc
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-lm-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/csrc/online-lm-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/file-utils.h" | ||
| 10 | +#include "sherpa-onnx/csrc/macros.h" | ||
| 11 | + | ||
| 12 | +namespace sherpa_onnx { | ||
| 13 | + | ||
| 14 | +void OnlineLMConfig::Register(ParseOptions *po) { | ||
| 15 | + po->Register("lm", &model, "Path to LM model."); | ||
| 16 | + po->Register("lm-scale", &scale, "LM scale."); | ||
| 17 | +} | ||
| 18 | + | ||
| 19 | +bool OnlineLMConfig::Validate() const { | ||
| 20 | + if (!FileExists(model)) { | ||
| 21 | + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); | ||
| 22 | + return false; | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + return true; | ||
| 26 | +} | ||
| 27 | + | ||
| 28 | +std::string OnlineLMConfig::ToString() const { | ||
| 29 | + std::ostringstream os; | ||
| 30 | + | ||
| 31 | + os << "OnlineLMConfig("; | ||
| 32 | + os << "model=\"" << model << "\", "; | ||
| 33 | + os << "scale=" << scale << ")"; | ||
| 34 | + | ||
| 35 | + return os.str(); | ||
| 36 | +} | ||
| 37 | + | ||
| 38 | +} // namespace sherpa_onnx |
sherpa-onnx/csrc/online-lm-config.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-lm-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ | ||
| 5 | +#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/parse-options.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +struct OnlineLMConfig { | ||
| 14 | + // path to the onnx model | ||
| 15 | + std::string model; | ||
| 16 | + | ||
| 17 | + // LM scale | ||
| 18 | + float scale = 1.0; | ||
| 19 | + | ||
| 20 | + OnlineLMConfig() = default; | ||
| 21 | + | ||
| 22 | + OnlineLMConfig(const std::string &model, float scale) | ||
| 23 | + : model(model), scale(scale) {} | ||
| 24 | + | ||
| 25 | + void Register(ParseOptions *po); | ||
| 26 | + bool Validate() const; | ||
| 27 | + | ||
| 28 | + std::string ToString() const; | ||
| 29 | +}; | ||
| 30 | + | ||
| 31 | +} // namespace sherpa_onnx | ||
| 32 | + | ||
| 33 | +#endif // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ |
sherpa-onnx/csrc/online-lm.h
0 → 100644
| 1 | +// sherpa-onnx/csrc/online-lm.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_ | ||
| 6 | +#define SHERPA_ONNX_CSRC_ONLINE_LM_H_ | ||
| 7 | + | ||
| 8 | +#include <memory> | ||
| 9 | +#include <utility> | ||
| 10 | +#include <vector> | ||
| 11 | + | ||
| 12 | +#include "onnxruntime_cxx_api.h" // NOLINT | ||
| 13 | +#include "sherpa-onnx/csrc/hypothesis.h" | ||
| 14 | +#include "sherpa-onnx/csrc/online-lm-config.h" | ||
| 15 | + | ||
| 16 | +namespace sherpa_onnx { | ||
| 17 | + | ||
| 18 | +class OnlineLM { | ||
| 19 | + public: | ||
| 20 | + virtual ~OnlineLM() = default; | ||
| 21 | + | ||
| 22 | + static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); | ||
| 23 | + | ||
| 24 | + virtual std::vector<Ort::Value> GetInitStates() = 0; | ||
| 25 | + | ||
| 26 | + /** Rescore a batch of sentences. | ||
| 27 | + * | ||
| 28 | + * @param x A 2-D tensor of shape (N, L) with data type int64. | ||
| 29 | + * @param y A 2-D tensor of shape (N, L) with data type int64. | ||
| 30 | + * @param states It contains the states for the LM model | ||
| 31 | + * @return Return a pair containing: | ||
| 32 | + * - negative loglike | ||
| 33 | + * - updated states | ||
| 34 | + * | ||
| 35 | + * Caution: It returns negative log likelihood (nll), not log likelihood | ||
| 36 | + */ | ||
| 37 | +  virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( | ||
| 38 | + Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0; | ||
| 39 | + | ||
| 40 | +  // This function updates hyp.lm_log_prob of hyps. | ||
| 41 | + // | ||
| 42 | +  // @param scale LM scale | ||
| 43 | + // @param context_size Context size of the transducer decoder model | ||
| 44 | + // @param hyps It is changed in-place. | ||
| 45 | + void ComputeLMScore(float scale, int32_t context_size, | ||
| 46 | + std::vector<Hypotheses> *hyps); | ||
| 47 | + /** TODO(fangjun): | ||
| 48 | + * | ||
| 49 | + * 1. Add two fields to Hypothesis | ||
| 50 | + * (a) int32_t lm_cur_pos = 0; number of scored tokens so far | ||
| 51 | + * (b) std::vector<Ort::Value> lm_states; | ||
| 52 | + * 2. When we want to score a hypothesis, we construct x and y as follows: | ||
| 53 | + * | ||
| 54 | + * std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos, | ||
| 55 | + * hyp.ys.end() - 1}; | ||
| 56 | + * std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1 | ||
| 57 | + * hyp.ys.end()}; | ||
| 58 | + * hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos; | ||
| 59 | + */ | ||
| 60 | +}; | ||
| 61 | + | ||
| 62 | +} // namespace sherpa_onnx | ||
| 63 | + | ||
| 64 | +#endif // SHERPA_ONNX_CSRC_ONLINE_LM_H_ |
| @@ -36,38 +36,6 @@ static void UseCachedDecoderOut( | @@ -36,38 +36,6 @@ static void UseCachedDecoderOut( | ||
| 36 | } | 36 | } |
| 37 | } | 37 | } |
| 38 | 38 | ||
| 39 | -static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | ||
| 40 | - const std::vector<int32_t> &hyps_num_split) { | ||
| 41 | - std::vector<int64_t> cur_encoder_out_shape = | ||
| 42 | - cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 43 | - | ||
| 44 | - std::array<int64_t, 2> ans_shape{hyps_num_split.back(), | ||
| 45 | - cur_encoder_out_shape[1]}; | ||
| 46 | - | ||
| 47 | - Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(), | ||
| 48 | - ans_shape.size()); | ||
| 49 | - | ||
| 50 | - const float *src = cur_encoder_out->GetTensorData<float>(); | ||
| 51 | - float *dst = ans.GetTensorMutableData<float>(); | ||
| 52 | - int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1; | ||
| 53 | - for (int32_t b = 0; b != batch_size; ++b) { | ||
| 54 | - int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; | ||
| 55 | - for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { | ||
| 56 | - std::copy(src, src + cur_encoder_out_shape[1], dst); | ||
| 57 | - dst += cur_encoder_out_shape[1]; | ||
| 58 | - } | ||
| 59 | - src += cur_encoder_out_shape[1]; | ||
| 60 | - } | ||
| 61 | - return ans; | ||
| 62 | -} | ||
| 63 | - | ||
| 64 | -static void LogSoftmax(float *in, int32_t w, int32_t h) { | ||
| 65 | - for (int32_t i = 0; i != h; ++i) { | ||
| 66 | - LogSoftmax(in, w); | ||
| 67 | - in += w; | ||
| 68 | - } | ||
| 69 | -} | ||
| 70 | - | ||
| 71 | OnlineTransducerDecoderResult | 39 | OnlineTransducerDecoderResult |
| 72 | OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { | 40 | OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { |
| 73 | int32_t context_size = model_->ContextSize(); | 41 | int32_t context_size = model_->ContextSize(); |
| @@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) { | @@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) { | ||
| 193 | } | 193 | } |
| 194 | #endif | 194 | #endif |
| 195 | 195 | ||
| 196 | +Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | ||
| 197 | + const std::vector<int32_t> &hyps_num_split) { | ||
| 198 | + std::vector<int64_t> cur_encoder_out_shape = | ||
| 199 | + cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape(); | ||
| 200 | + | ||
| 201 | + std::array<int64_t, 2> ans_shape{hyps_num_split.back(), | ||
| 202 | + cur_encoder_out_shape[1]}; | ||
| 203 | + | ||
| 204 | + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(), | ||
| 205 | + ans_shape.size()); | ||
| 206 | + | ||
| 207 | + const float *src = cur_encoder_out->GetTensorData<float>(); | ||
| 208 | + float *dst = ans.GetTensorMutableData<float>(); | ||
| 209 | + int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1; | ||
| 210 | + for (int32_t b = 0; b != batch_size; ++b) { | ||
| 211 | + int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; | ||
| 212 | + for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { | ||
| 213 | + std::copy(src, src + cur_encoder_out_shape[1], dst); | ||
| 214 | + dst += cur_encoder_out_shape[1]; | ||
| 215 | + } | ||
| 216 | + src += cur_encoder_out_shape[1]; | ||
| 217 | + } | ||
| 218 | + return ans; | ||
| 219 | +} | ||
| 220 | + | ||
| 196 | } // namespace sherpa_onnx | 221 | } // namespace sherpa_onnx |
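`Repeat` moves from the online modified-beam-search decoder into `onnx-utils.cc` so the new offline decoder can share it. Given a `(batch, dim)` tensor and row splits `[0, n_0, n_0 + n_1, ...]`, row `b` of the input is copied `n_b` times, producing a tensor whose rows line up with the flattened hypothesis list. A plain-vector sketch of that addressing:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Plain-vector illustration of Repeat(): row b of `src` (shape batch x dim)
// is duplicated (row_splits[b+1] - row_splits[b]) times.
std::vector<float> RepeatRows(const std::vector<float> &src, int32_t dim,
                              const std::vector<int32_t> &row_splits) {
  std::vector<float> ans(row_splits.back() * dim);
  float *dst = ans.data();
  const float *s = src.data();
  int32_t batch = static_cast<int32_t>(row_splits.size()) - 1;
  for (int32_t b = 0; b != batch; ++b, s += dim) {
    for (int32_t i = row_splits[b]; i != row_splits[b + 1]; ++i, dst += dim) {
      std::copy(s, s + dim, dst);
    }
  }
  return ans;
}
// e.g. src = {a0, a1,  b0, b1}, dim = 2, row_splits = {0, 2, 3}
//  ->  {a0, a1,  a0, a1,  b0, b1}  (utterance 0 has 2 hyps, utterance 1 has 1)
```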
| @@ -86,6 +86,9 @@ std::vector<char> ReadFile(const std::string &filename); | @@ -86,6 +86,9 @@ std::vector<char> ReadFile(const std::string &filename); | ||
| 86 | std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename); | 86 | std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename); |
| 87 | #endif | 87 | #endif |
| 88 | 88 | ||
| 89 | +// TODO(fangjun): Document it | ||
| 90 | +Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, | ||
| 91 | + const std::vector<int32_t> &hyps_num_split); | ||
| 89 | } // namespace sherpa_onnx | 92 | } // namespace sherpa_onnx |
| 90 | 93 | ||
| 91 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 94 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| @@ -111,6 +111,9 @@ for a list of pre-trained models to download. | @@ -111,6 +111,9 @@ for a list of pre-trained models to download. | ||
| 111 | 111 | ||
| 112 | fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | 112 | fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); |
| 113 | fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | 113 | fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); |
| 114 | + if (config.decoding_method == "modified_beam_search") { | ||
| 115 | + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); | ||
| 116 | + } | ||
| 114 | 117 | ||
| 115 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | 118 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); |
| 116 | float rtf = elapsed_seconds / duration; | 119 | float rtf = elapsed_seconds / duration; |
| @@ -117,6 +117,9 @@ for a list of pre-trained models to download. | @@ -117,6 +117,9 @@ for a list of pre-trained models to download. | ||
| 117 | 117 | ||
| 118 | fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | 118 | fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); |
| 119 | fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | 119 | fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); |
| 120 | + if (config.decoding_method == "modified_beam_search") { | ||
| 121 | + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); | ||
| 122 | + } | ||
| 120 | 123 | ||
| 121 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | 124 | fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); |
| 122 | float rtf = elapsed_seconds / duration; | 125 | float rtf = elapsed_seconds / duration; |
| @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx | @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx | ||
| 4 | display.cc | 4 | display.cc |
| 5 | endpoint.cc | 5 | endpoint.cc |
| 6 | features.cc | 6 | features.cc |
| 7 | + offline-lm-config.cc | ||
| 7 | offline-model-config.cc | 8 | offline-model-config.cc |
| 8 | offline-nemo-enc-dec-ctc-model-config.cc | 9 | offline-nemo-enc-dec-ctc-model-config.cc |
| 9 | offline-paraformer-model-config.cc | 10 | offline-paraformer-model-config.cc |
sherpa-onnx/python/csrc/offline-lm-config.cc
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/offline-lm-config.cc | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#include "sherpa-onnx/python/csrc/offline-lm-config.h" | ||
| 6 | + | ||
| 7 | +#include <string> | ||
| 8 | + | ||
| 9 | +#include "sherpa-onnx/csrc/offline-lm-config.h" | ||
| 10 | + | ||
| 11 | +namespace sherpa_onnx { | ||
| 12 | + | ||
| 13 | +void PybindOfflineLMConfig(py::module *m) { | ||
| 14 | + using PyClass = OfflineLMConfig; | ||
| 15 | + py::class_<PyClass>(*m, "OfflineLMConfig") | ||
| 16 | + .def(py::init<const std::string &, float>(), py::arg("model"), | ||
| 17 | + py::arg("scale")) | ||
| 18 | + .def_readwrite("model", &PyClass::model) | ||
| 19 | + .def_readwrite("scale", &PyClass::scale) | ||
| 20 | + .def("__str__", &PyClass::ToString); | ||
| 21 | +} | ||
| 22 | + | ||
| 23 | +} // namespace sherpa_onnx |
sherpa-onnx/python/csrc/offline-lm-config.h
0 → 100644
| 1 | +// sherpa-onnx/python/csrc/offline-lm-config.h | ||
| 2 | +// | ||
| 3 | +// Copyright (c) 2023 Xiaomi Corporation | ||
| 4 | + | ||
| 5 | +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ | ||
| 6 | +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ | ||
| 7 | + | ||
| 8 | +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" | ||
| 9 | + | ||
| 10 | +namespace sherpa_onnx { | ||
| 11 | + | ||
| 12 | +void PybindOfflineLMConfig(py::module *m); | ||
| 13 | + | ||
| 14 | +} | ||
| 15 | + | ||
| 16 | +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ |
| @@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | @@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) { | ||
| 15 | using PyClass = OfflineRecognizerConfig; | 15 | using PyClass = OfflineRecognizerConfig; |
| 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") | 16 | py::class_<PyClass>(*m, "OfflineRecognizerConfig") |
| 17 | .def(py::init<const OfflineFeatureExtractorConfig &, | 17 | .def(py::init<const OfflineFeatureExtractorConfig &, |
| 18 | - const OfflineModelConfig &, const std::string &>(), | 18 | + const OfflineModelConfig &, const OfflineLMConfig &, |
| 19 | + const std::string &, int32_t>(), | ||
| 19 | py::arg("feat_config"), py::arg("model_config"), | 20 | py::arg("feat_config"), py::arg("model_config"), |
| 20 | - py::arg("decoding_method")) | 21 | + py::arg("lm_config") = OfflineLMConfig(), |
| 22 | + py::arg("decoding_method") = "greedy_search", | ||
| 23 | + py::arg("max_active_paths") = 4) | ||
| 21 | .def_readwrite("feat_config", &PyClass::feat_config) | 24 | .def_readwrite("feat_config", &PyClass::feat_config) |
| 22 | .def_readwrite("model_config", &PyClass::model_config) | 25 | .def_readwrite("model_config", &PyClass::model_config) |
| 26 | + .def_readwrite("lm_config", &PyClass::lm_config) | ||
| 23 | .def_readwrite("decoding_method", &PyClass::decoding_method) | 27 | .def_readwrite("decoding_method", &PyClass::decoding_method) |
| 28 | + .def_readwrite("max_active_paths", &PyClass::max_active_paths) | ||
| 24 | .def("__str__", &PyClass::ToString); | 29 | .def("__str__", &PyClass::ToString); |
| 25 | } | 30 | } |
| 26 | 31 |
| @@ -7,6 +7,7 @@ | @@ -7,6 +7,7 @@ | ||
| 7 | #include "sherpa-onnx/python/csrc/display.h" | 7 | #include "sherpa-onnx/python/csrc/display.h" |
| 8 | #include "sherpa-onnx/python/csrc/endpoint.h" | 8 | #include "sherpa-onnx/python/csrc/endpoint.h" |
| 9 | #include "sherpa-onnx/python/csrc/features.h" | 9 | #include "sherpa-onnx/python/csrc/features.h" |
| 10 | +#include "sherpa-onnx/python/csrc/offline-lm-config.h" | ||
| 10 | #include "sherpa-onnx/python/csrc/offline-model-config.h" | 11 | #include "sherpa-onnx/python/csrc/offline-model-config.h" |
| 11 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" | 12 | #include "sherpa-onnx/python/csrc/offline-recognizer.h" |
| 12 | #include "sherpa-onnx/python/csrc/offline-stream.h" | 13 | #include "sherpa-onnx/python/csrc/offline-stream.h" |
| @@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | @@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { | ||
| 28 | PybindDisplay(&m); | 29 | PybindDisplay(&m); |
| 29 | 30 | ||
| 30 | PybindOfflineStream(&m); | 31 | PybindOfflineStream(&m); |
| 32 | + PybindOfflineLMConfig(&m); | ||
| 31 | PybindOfflineModelConfig(&m); | 33 | PybindOfflineModelConfig(&m); |
| 32 | PybindOfflineRecognizer(&m); | 34 | PybindOfflineRecognizer(&m); |
| 33 | } | 35 | } |