Fangjun Kuang
Committed by GitHub

Add RNN LM rescore for offline ASR with modified_beam_search (#125)

... ... @@ -56,3 +56,4 @@ run-offline-decode-files.sh
sherpa-onnx-nemo-ctc-en-citrinet-512
run-offline-decode-files-nemo-ctc.sh
*.jar
sherpa-onnx-nemo-ctc-*
... ...
... ... @@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI)
set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
endif()
if(SHERPA_ONNX_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS)
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_PYTHON is ON")
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
endif()
if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS)
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON")
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
... ...
... ... @@ -18,6 +18,8 @@ set(sources
hypothesis.cc
offline-ctc-greedy-search-decoder.cc
offline-ctc-model.cc
offline-lm-config.cc
offline-lm.cc
offline-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc
offline-nemo-enc-dec-ctc-model.cc
... ... @@ -26,10 +28,13 @@ set(sources
offline-paraformer-model.cc
offline-recognizer-impl.cc
offline-recognizer.cc
offline-rnn-lm.cc
offline-stream.cc
offline-transducer-greedy-search-decoder.cc
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
online-lm-config.cc
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
... ...
... ... @@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) {
hyps_dict_[key] = std::move(hyp);
} else {
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
it->second.lm_log_prob =
LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
}
}
... ... @@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
if (length_norm == false) {
return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
[](const auto &left, auto &right) -> bool {
return left.second.log_prob <
right.second.log_prob;
return left.second.TotalLogProb() <
right.second.TotalLogProb();
})
->second;
} else {
... ... @@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
return std::max_element(
hyps_dict_.begin(), hyps_dict_.end(),
[](const auto &left, const auto &right) -> bool {
return left.second.log_prob / left.second.ys.size() <
right.second.log_prob / right.second.ys.size();
return left.second.TotalLogProb() / left.second.ys.size() <
right.second.TotalLogProb() / right.second.ys.size();
})
->second;
}
... ... @@ -47,15 +50,16 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
std::vector<Hypothesis> all_hyps = Vec();
if (length_norm == false) {
std::partial_sort(
all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
[](const auto &a, const auto &b) { return a.log_prob > b.log_prob; });
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
[](const auto &a, const auto &b) {
return a.TotalLogProb() > b.TotalLogProb();
});
} else {
// for length_norm is true
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
[](const auto &a, const auto &b) {
return a.log_prob / a.ys.size() >
b.log_prob / b.ys.size();
return a.TotalLogProb() / a.ys.size() >
b.TotalLogProb() / b.ys.size();
});
}
... ...
... ... @@ -25,14 +25,20 @@ struct Hypothesis {
std::vector<int32_t> timestamps;
// The total score of ys in log space.
// It contains only acoustic scores
double log_prob = 0;
// LM log prob if any.
double lm_log_prob = 0;
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
double TotalLogProb() const { return log_prob + lm_log_prob; }
// If two Hypotheses have the same `Key`, then they contain
// the same token sequence.
std::string Key() const {
... ... @@ -94,6 +100,9 @@ class Hypotheses {
const auto begin() const { return hyps_dict_.begin(); }
const auto end() const { return hyps_dict_.end(); }
auto begin() { return hyps_dict_.begin(); }
auto end() { return hyps_dict_.end(); }
void Clear() { hyps_dict_.clear(); }
private:
... ...
... ... @@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) {
}
}
template <typename T>
void LogSoftmax(T *in, int32_t w, int32_t h) {
for (int32_t i = 0; i != h; ++i) {
LogSoftmax(in, w);
in += w;
}
}
// TODO(fangjun): use std::partial_sort to replace std::sort.
// Remember also to fix sherpa-ncnn
template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
std::vector<int32_t> vec_index(size);
... ...
// sherpa-onnx/csrc/offline-lm-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OfflineLMConfig::Register(ParseOptions *po) {
po->Register("lm", &model, "Path to LM model.");
po->Register("lm-scale", &scale, "LM scale.");
}
bool OfflineLMConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
return false;
}
return true;
}
std::string OfflineLMConfig::ToString() const {
std::ostringstream os;
os << "OfflineLMConfig(";
os << "model=\"" << model << "\", ";
os << "scale=" << scale << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-lm-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OfflineLMConfig {
// path to the onnx model
std::string model;
// LM scale
float scale = 1.0;
OfflineLMConfig() = default;
OfflineLMConfig(const std::string &model, float scale)
: model(model), scale(scale) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
... ...
// sherpa-onnx/csrc/offline-lm.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-lm.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
namespace sherpa_onnx {
std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
return std::make_unique<OfflineRnnLM>(config);
}
void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
// compute the max token seq so that we know how much space to allocate
int32_t max_token_seq = 0;
int32_t num_hyps = 0;
// we subtract context_size below since each token sequence is prepended
// with context_size blanks
for (const auto &h : *hyps) {
num_hyps += h.Size();
for (const auto &t : h) {
max_token_seq =
std::max<int32_t>(max_token_seq, t.second.ys.size() - context_size);
}
}
Ort::AllocatorWithDefaultOptions allocator;
std::array<int64_t, 2> x_shape{num_hyps, max_token_seq};
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator, x_shape.data(),
x_shape.size());
std::array<int64_t, 1> x_lens_shape{num_hyps};
Ort::Value x_lens = Ort::Value::CreateTensor<int64_t>(
allocator, x_lens_shape.data(), x_lens_shape.size());
int64_t *p = x.GetTensorMutableData<int64_t>();
std::fill(p, p + num_hyps * max_token_seq, 0);
int64_t *p_lens = x_lens.GetTensorMutableData<int64_t>();
for (const auto &h : *hyps) {
for (const auto &t : h) {
const auto &ys = t.second.ys;
int32_t len = ys.size() - context_size;
std::copy(ys.begin() + context_size, ys.end(), p);
*p_lens = len;
p += max_token_seq;
++p_lens;
}
}
auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
const float *p_nll = negative_loglike.GetTensorData<float>();
for (auto &h : *hyps) {
for (auto &t : h) {
// Use -scale here since we want to change negative loglike to loglike.
t.second.lm_log_prob = -scale * (*p_nll);
++p_nll;
}
}
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-lm.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_
#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_
#include <memory>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
namespace sherpa_onnx {
class OfflineLM {
public:
virtual ~OfflineLM() = default;
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
* @param x_lens A 1-D tensor of shape (N,) with data type int64.
* It contains number of valid tokens in x before padding.
* @return Return a 1-D tensor of shape (N,) containing the negative log
* likelihood of each utterance. Its data type is float32.
*
* Caution: It returns negative log likelihood (nll), not log likelihood
*/
virtual Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) = 0;
// This function updates hyp.lm_lob_prob of hyps.
//
// @param scale LM score
// @param context_size Context size of the transducer decoder model
// @param hyps It is changed in-place.
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps);
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_H_
... ...
... ... @@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
#include "sherpa-onnx/csrc/pad-sequence.h"
#include "sherpa-onnx/csrc/symbol-table.h"
... ... @@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
} else if (config_.decoding_method == "modified_beam_search") {
SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
exit(-1);
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(config.lm_config);
}
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
... ... @@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
SymbolTable symbol_table_;
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -8,6 +8,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
namespace sherpa_onnx {
... ... @@ -15,13 +16,28 @@ namespace sherpa_onnx {
void OfflineRecognizerConfig::Register(ParseOptions *po) {
feat_config.Register(po);
model_config.Register(po);
lm_config.Register(po);
po->Register("decoding-method", &decoding_method,
po->Register(
"decoding-method", &decoding_method,
"decoding method,"
"Valid values: greedy_search.");
"Valid values: greedy_search, modified_beam_search. "
"modified_beam_search is applicable only for transducer models.");
po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search");
}
bool OfflineRecognizerConfig::Validate() const {
if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
if (max_active_paths <= 0) {
SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d",
max_active_paths);
return false;
}
if (!lm_config.Validate()) return false;
}
return model_config.Validate();
}
... ... @@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "OfflineRecognizerConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
os << "lm_config=" << lm_config.ToString() << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ")";
return os.str();
}
... ...
... ... @@ -9,6 +9,7 @@
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
... ... @@ -21,18 +22,24 @@ struct OfflineRecognitionResult;
struct OfflineRecognizerConfig {
OfflineFeatureExtractorConfig feat_config;
OfflineModelConfig model_config;
OfflineLMConfig lm_config;
std::string decoding_method = "greedy_search";
int32_t max_active_paths = 4;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
OfflineRecognizerConfig() = default;
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
const OfflineModelConfig &model_config,
const std::string &decoding_method)
const OfflineLMConfig &lm_config,
const std::string &decoding_method,
int32_t max_active_paths)
: feat_config(feat_config),
model_config(model_config),
decoding_method(decoding_method) {}
lm_config(lm_config),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
// sherpa-onnx/csrc/offline-rnn-lm.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OfflineRnnLM::Impl {
public:
explicit Impl(const OfflineLMConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{},
allocator_{} {
Init(config);
}
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(out[0]);
}
private:
void Init(const OfflineLMConfig &config) {
auto buf = ReadFile(config_.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
}
private:
OfflineLMConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
};
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineRnnLM::~OfflineRnnLM() = default;
Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {
return impl_->Rescore(std::move(x), std::move(x_lens));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-rnn-lm.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-lm-config.h"
#include "sherpa-onnx/csrc/offline-lm.h"
namespace sherpa_onnx {
class OfflineRnnLM : public OfflineLM {
public:
~OfflineRnnLM() override;
explicit OfflineRnnLM(const OfflineLMConfig &config);
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
* @param x_lens A 1-D tensor of shape (N,) with data type int64.
* It contains number of valid tokens in x before padding.
* @return Return a 1-D tensor of shape (N,) containing the log likelihood
* of each utterance. Its data type is float32.
*
* Caution: It returns log likelihood, not negative log likelihood (nll).
*/
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
... ...
... ... @@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl {
std::copy(begin, end, p);
p += context_size;
}
return decoder_input;
}
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
int32_t end_index) const {
assert(end_index <= results.size());
int32_t batch_size = end_index;
int32_t context_size = ContextSize();
std::array<int64_t, 2> shape{batch_size, context_size};
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
Allocator(), shape.data(), shape.size());
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
for (int32_t i = 0; i != batch_size; ++i) {
const auto &r = results[i];
const int64_t *begin = r.ys.data() + r.ys.size() - context_size;
const int64_t *end = r.ys.data() + r.ys.size();
std::copy(begin, end, p);
p += context_size;
}
return decoder_input;
}
... ... @@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput(
return impl_->BuildDecoderInput(results, end_index);
}
Ort::Value OfflineTransducerModel::BuildDecoderInput(
const std::vector<Hypothesis> &results, int32_t end_index) const {
return impl_->BuildDecoderInput(results, end_index);
}
} // namespace sherpa_onnx
... ...
... ... @@ -9,6 +9,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
namespace sherpa_onnx {
... ... @@ -79,13 +80,16 @@ class OfflineTransducerModel {
*
* @param results Current decoded results.
* @param end_index We only use results[0:end_index] to build
* the decoder_input.
* the decoder_input. results[end_index] is not used.
* @return Return a tensor of shape (results.size(), ContextSize())
*/
Ort::Value BuildDecoderInput(
const std::vector<OfflineTransducerDecoderResult> &results,
int32_t end_index) const;
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
int32_t end_index) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
... ...
// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
#include <deque>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/packed-sequence.h"
#include "sherpa-onnx/csrc/slice.h"
namespace sherpa_onnx {
static std::vector<int32_t> GetHypsRowSplits(
const std::vector<Hypotheses> &hyps) {
std::vector<int32_t> row_splits;
row_splits.reserve(hyps.size() + 1);
row_splits.push_back(0);
int32_t s = 0;
for (const auto &h : hyps) {
s += h.Size();
row_splits.push_back(s);
}
return row_splits;
}
std::vector<OfflineTransducerDecoderResult>
OfflineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) {
PackedSequence packed_encoder_out = PackPaddedSequence(
model_->Allocator(), &encoder_out, &encoder_out_length);
int32_t batch_size =
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
int32_t vocab_size = model_->VocabSize();
int32_t context_size = model_->ContextSize();
std::vector<int64_t> blanks(context_size, 0);
Hypotheses blank_hyp({{blanks, 0}});
std::deque<Hypotheses> finalized;
std::vector<Hypotheses> cur(batch_size, blank_hyp);
std::vector<Hypothesis> prev;
int32_t start = 0;
int32_t t = 0;
for (auto n : packed_encoder_out.batch_sizes) {
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
start += n;
if (n < static_cast<int32_t>(cur.size())) {
for (int32_t k = static_cast<int32_t>(cur.size()) - 1; k >= n; --k) {
finalized.push_front(std::move(cur[k]));
}
cur.erase(cur.begin() + n, cur.end());
} // if (n < static_cast<int32_t>(cur.size()))
// Due to merging paths with identical token sequences,
// not all utterances have "max_active_paths" paths.
auto hyps_row_splits = GetHypsRowSplits(cur);
int32_t num_hyps = hyps_row_splits.back();
prev.clear();
prev.reserve(num_hyps);
for (auto &hyps : cur) {
for (auto &h : hyps) {
prev.push_back(std::move(h.second));
}
}
cur.clear();
cur.reserve(n);
auto decoder_input = model_->BuildDecoderInput(prev, num_hyps);
// decoder_input shape: (num_hyps, context_size)
auto decoder_out = model_->RunDecoder(std::move(decoder_input));
// decoder_out is (num_hyps, joiner_dim)
cur_encoder_out =
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
Ort::Value logit = model_->RunJoiner(
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
LogSoftmax(p_logit, vocab_size, num_hyps);
// now p_logit contains log_softmax output, we rename it to p_logprob
// to match what it actually contains
float *p_logprob = p_logit;
// add log_prob of each hypothesis to p_logprob before taking top_k
for (int32_t i = 0; i != num_hyps; ++i) {
float log_prob = prev[i].log_prob;
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
*p_logprob += log_prob;
}
}
p_logprob = p_logit; // we changed p_logprob in the above for loop
// Now compute top_k for each utterance
for (int32_t i = 0; i != n; ++i) {
int32_t start = hyps_row_splits[i];
int32_t end = hyps_row_splits[i + 1];
auto topk =
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
Hypotheses hyps;
for (auto k : topk) {
int32_t hyp_index = k / vocab_size + start;
int32_t new_token = k % vocab_size;
Hypothesis new_hyp = prev[hyp_index];
if (new_token != 0) {
// blank id is fixed to 0
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t);
}
new_hyp.log_prob = p_logprob[k];
hyps.Add(std::move(new_hyp));
} // for (auto k : topk)
p_logprob += (end - start) * vocab_size;
cur.push_back(std::move(hyps));
} // for (int32_t i = 0; i != n; ++i)
++t;
} // for (auto n : packed_encoder_out.batch_sizes)
for (auto &h : finalized) {
cur.push_back(std::move(h));
}
if (lm_) {
// use LM for rescoring
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
}
std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
for (int32_t i = 0; i != batch_size; ++i) {
Hypothesis hyp = cur[i].GetMostProbable(true);
auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]];
// strip leading blanks
r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()};
r.timestamps = std::move(hyp.timestamps);
}
return unsorted_ans;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-lm.h"
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
#include "sherpa-onnx/csrc/offline-transducer-model.h"
namespace sherpa_onnx {
class OfflineTransducerModifiedBeamSearchDecoder
: public OfflineTransducerDecoder {
public:
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
OfflineLM *lm,
int32_t max_active_paths,
float lm_scale)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
private:
OfflineTransducerModel *model_; // Not owned
OfflineLM *lm_; // Not owned; may be nullptr
int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
... ...
// sherpa-onnx/csrc/online-lm-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lm-config.h"
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineLMConfig::Register(ParseOptions *po) {
po->Register("lm", &model, "Path to LM model.");
po->Register("lm-scale", &scale, "LM scale.");
}
bool OnlineLMConfig::Validate() const {
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
return false;
}
return true;
}
std::string OnlineLMConfig::ToString() const {
std::ostringstream os;
os << "OnlineLMConfig(";
os << "model=\"" << model << "\", ";
os << "scale=" << scale << ")";
return os.str();
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-lm-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineLMConfig {
// path to the onnx model
std::string model;
// LM scale
float scale = 1.0;
OnlineLMConfig() = default;
OnlineLMConfig(const std::string &model, float scale)
: model(model), scale(scale) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
... ...
// sherpa-onnx/csrc/online-lm.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_
#define SHERPA_ONNX_CSRC_ONLINE_LM_H_
#include <memory>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/hypothesis.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
namespace sherpa_onnx {
class OnlineLM {
public:
virtual ~OnlineLM() = default;
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
virtual std::vector<Ort::Value> GetInitStates() = 0;
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
* @param y A 2-D tensor of shape (N, L) with data type int64.
* @param states It contains the states for the LM model
* @return Return a pair containingo
* - negative loglike
* - updated states
*
* Caution: It returns negative log likelihood (nll), not log likelihood
*/
std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore(
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
// This function updates hyp.lm_lob_prob of hyps.
//
// @param scale LM score
// @param context_size Context size of the transducer decoder model
// @param hyps It is changed in-place.
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps);
/** TODO(fangjun):
*
* 1. Add two fields to Hypothesis
* (a) int32_t lm_cur_pos = 0; number of scored tokens so far
* (b) std::vector<Ort::Value> lm_states;
* 2. When we want to score a hypothesis, we construct x and y as follows:
*
* std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos,
* hyp.ys.end() - 1};
* std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1
* hyp.ys.end()};
* hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos;
*/
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_LM_H_
... ...
... ... @@ -36,38 +36,6 @@ static void UseCachedDecoderOut(
}
}
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
cur_encoder_out_shape[1]};
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
ans_shape.size());
const float *src = cur_encoder_out->GetTensorData<float>();
float *dst = ans.GetTensorMutableData<float>();
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
for (int32_t b = 0; b != batch_size; ++b) {
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
std::copy(src, src + cur_encoder_out_shape[1], dst);
dst += cur_encoder_out_shape[1];
}
src += cur_encoder_out_shape[1];
}
return ans;
}
static void LogSoftmax(float *in, int32_t w, int32_t h) {
for (int32_t i = 0; i != h; ++i) {
LogSoftmax(in, w);
in += w;
}
}
OnlineTransducerDecoderResult
OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
... ...
... ... @@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
}
#endif
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
cur_encoder_out_shape[1]};
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
ans_shape.size());
const float *src = cur_encoder_out->GetTensorData<float>();
float *dst = ans.GetTensorMutableData<float>();
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
for (int32_t b = 0; b != batch_size; ++b) {
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
std::copy(src, src + cur_encoder_out_shape[1], dst);
dst += cur_encoder_out_shape[1];
}
src += cur_encoder_out_shape[1];
}
return ans;
}
} // namespace sherpa_onnx
... ...
... ... @@ -86,6 +86,9 @@ std::vector<char> ReadFile(const std::string &filename);
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
#endif
// TODO(fangjun): Document it
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
... ...
... ... @@ -111,6 +111,9 @@ for a list of pre-trained models to download.
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
if (config.decoding_method == "modified_beam_search") {
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
}
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
... ...
... ... @@ -117,6 +117,9 @@ for a list of pre-trained models to download.
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
if (config.decoding_method == "modified_beam_search") {
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
}
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
... ...
... ... @@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc
offline-lm-config.cc
offline-model-config.cc
offline-nemo-enc-dec-ctc-model-config.cc
offline-paraformer-model-config.cc
... ...
// sherpa-onnx/python/csrc/offline-lm-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
#include <string>
#include "sherpa-onnx//csrc/offline-lm-config.h"
namespace sherpa_onnx {
void PybindOfflineLMConfig(py::module *m) {
using PyClass = OfflineLMConfig;
py::class_<PyClass>(*m, "OfflineLMConfig")
.def(py::init<const std::string &, float>(), py::arg("model"),
py::arg("scale"))
.def_readwrite("model", &PyClass::model)
.def_readwrite("scale", &PyClass::scale)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/offline-lm-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineLMConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
... ...
... ... @@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
using PyClass = OfflineRecognizerConfig;
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const std::string &>(),
const OfflineModelConfig &, const OfflineLMConfig &,
const std::string &, int32_t>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("decoding_method"))
py::arg("lm_config") = OfflineLMConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -7,6 +7,7 @@
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
#include "sherpa-onnx/python/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
#include "sherpa-onnx/python/csrc/offline-stream.h"
... ... @@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindDisplay(&m);
PybindOfflineStream(&m);
PybindOfflineLMConfig(&m);
PybindOfflineModelConfig(&m);
PybindOfflineRecognizer(&m);
}
... ...