PF Luo
Committed by GitHub

Add lm rescore to online-modified-beam-search (#133)

... ... @@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() {
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
lmConfig = getOnlineLMConfig(type = type),
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
decodingMethod = "greedy_search",
decodingMethod = "modified_beam_search",
maxActivePaths = 4,
)
... ...
... ... @@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig(
var debug: Boolean = false,
)
data class OnlineLMConfig(
var model: String = "",
var scale: Float = 0.5f,
)
data class FeatureConfig(
var sampleRate: Int = 16000,
var featureDim: Int = 80,
... ... @@ -31,6 +36,7 @@ data class FeatureConfig(
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var lmConfig : OnlineLMConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
... ... @@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
return null;
}
/*
Please see
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models.
We only add a few here. Please change the following code
to add your own LM model. (It should be straightforward to train a new NN LM model
by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
@param type
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
*/
fun getOnlineLMConfig(type : Int): OnlineLMConfig {
when (type) {
0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
return OnlineLMConfig(
model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
scale = 0.5f,
)
}
}
return OnlineLMConfig();
}
fun getEndpointConfig(): EndpointConfig {
return EndpointConfig(
rule1 = EndpointRule(false, 2.4f, 0.0f),
... ...
... ... @@ -22,8 +22,11 @@ fun main() {
var endpointConfig = EndpointConfig()
var lmConfig = OnlineLMConfig()
var config = OnlineRecognizerConfig(
modelConfig = modelConfig,
lmConfig = lmConfig,
featConfig = featConfig,
endpointConfig = endpointConfig,
enableEndpoint = true,
... ...
... ... @@ -34,9 +34,11 @@ set(sources
offline-transducer-model-config.cc
offline-transducer-model.cc
offline-transducer-modified-beam-search-decoder.cc
online-lm.cc
online-lm-config.cc
online-lstm-transducer-model.cc
online-recognizer.cc
online-rnn-lm.cc
online-stream.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
... ...
/**
* Copyright (c) 2023 Xiaomi Corporation
*
* Copyright (c) 2023 Pingfeng Luo
*/
#include "sherpa-onnx/csrc/hypothesis.h"
... ...
/**
* Copyright (c) 2023 Xiaomi Corporation
* Copyright (c) 2023 Pingfeng Luo
*
*/
... ... @@ -12,7 +13,9 @@
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
... ... @@ -31,6 +34,13 @@ struct Hypothesis {
// LM log prob if any.
double lm_log_prob = 0;
int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM
std::vector<CopyableOrtValue> nn_lm_states;
// TODO(fangjun): Make it configurable
// the minimum of tokens in a chunk for streaming RNN LM
int32_t lm_rescore_min_chunk = 2; // a const
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
... ...
... ... @@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
}
}
// TODO(fangjun): use std::partial_sort to replace std::sort.
// Remember also to fix sherpa-ncnn
template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
std::vector<int32_t> vec_index(size);
std::iota(vec_index.begin(), vec_index.end(), 0);
std::sort(vec_index.begin(), vec_index.end(),
[vec](int32_t index_1, int32_t index_2) {
return vec[index_1] > vec[index_2];
});
std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
return vec[index_1] > vec[index_2];
});
int32_t k_num = std::min<int32_t>(size, topk);
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
... ...
... ... @@ -15,7 +15,7 @@ struct OfflineLMConfig {
std::string model;
// LM scale
float scale = 1.0;
float scale = 0.5;
OfflineLMConfig() = default;
... ...
... ... @@ -15,7 +15,7 @@ struct OnlineLMConfig {
std::string model;
// LM scale
float scale = 1.0;
float scale = 0.5;
OnlineLMConfig() = default;
... ...
// sherpa-onnx/csrc/online-lm.cc
//
// Copyright (c) 2023 Pingfeng Luo
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lm.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/online-rnn-lm.h"
namespace sherpa_onnx {
static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
std::vector<CopyableOrtValue> ans;
ans.reserve(values.size());
for (auto &v : values) {
ans.emplace_back(std::move(v));
}
return ans;
}
static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
std::vector<Ort::Value> ans;
ans.reserve(values.size());
for (auto &v : values) {
ans.emplace_back(std::move(v.value));
}
return ans;
}
std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
return std::make_unique<OnlineRnnLM>(config);
}
void OnlineLM::ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps) {
Ort::AllocatorWithDefaultOptions allocator;
for (auto &hyp : *hyps) {
for (auto &h_m : hyp) {
auto &h = h_m.second;
auto &ys = h.ys;
const int32_t token_num_in_chunk =
ys.size() - context_size - h.cur_scored_pos - 1;
if (token_num_in_chunk < 1) {
continue;
}
if (h.nn_lm_states.empty()) {
h.nn_lm_states = Convert(GetInitStates());
}
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
// shape of x and y are same
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
Ort::Value y = Ort::Value::CreateTensor<int64_t>(
allocator, x_shape.data(), x_shape.size());
int64_t *p_x = x.GetTensorMutableData<int64_t>();
int64_t *p_y = y.GetTensorMutableData<int64_t>();
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
p_x);
std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(),
p_y);
// streaming forward by NN LM
auto out = Rescore(std::move(x), std::move(y),
Convert(std::move(h.nn_lm_states)));
// update NN LM score in hyp
const float *p_nll = out.first.GetTensorData<float>();
h.lm_log_prob = -scale * (*p_nll);
// update NN LM states in hyp
h.nn_lm_states = Convert(std::move(out.second));
h.cur_scored_pos += token_num_in_chunk;
}
}
}
}
} // namespace sherpa_onnx
... ...
... ... @@ -34,7 +34,7 @@ class OnlineLM {
*
* Caution: It returns negative log likelihood (nll), not log likelihood
*/
std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore(
virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
// This function updates hyp.lm_lob_prob of hyps.
... ... @@ -44,19 +44,6 @@ class OnlineLM {
// @param hyps It is changed in-place.
void ComputeLMScore(float scale, int32_t context_size,
std::vector<Hypotheses> *hyps);
/** TODO(fangjun):
*
* 1. Add two fields to Hypothesis
* (a) int32_t lm_cur_pos = 0; number of scored tokens so far
* (b) std::vector<Ort::Value> lm_states;
* 2. When we want to score a hypothesis, we construct x and y as follows:
*
* std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos,
* hyp.ys.end() - 1};
* std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1
* hyp.ys.end()};
* hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos;
*/
};
} // namespace sherpa_onnx
... ...
... ... @@ -16,6 +16,8 @@
#include "nlohmann/json.hpp"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-lm.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
... ... @@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
feat_config.Register(po);
model_config.Register(po);
endpoint_config.Register(po);
lm_config.Register(po);
po->Register("enable-endpoint", &enable_endpoint,
"True to enable endpoint detection. False to disable it.");
... ... @@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
}
bool OnlineRecognizerConfig::Validate() const {
if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
if (max_active_paths <= 0) {
SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d",
max_active_paths);
return false;
}
if (!lm_config.Validate()) return false;
}
return model_config.Validate();
}
... ... @@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "OnlineRecognizerConfig(";
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "lm_config=" << lm_config.ToString() << ", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
... ... @@ -116,8 +128,13 @@ class OnlineRecognizer::Impl {
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (config.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OnlineLM::Create(config.lm_config);
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), config_.max_active_paths);
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
... ... @@ -136,7 +153,8 @@ class OnlineRecognizer::Impl {
endpoint_(config_.endpoint_config) {
if (config.decoding_method == "modified_beam_search") {
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), config_.max_active_paths);
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
... ... @@ -246,6 +264,7 @@ class OnlineRecognizer::Impl {
private:
OnlineRecognizerConfig config_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<OnlineLM> lm_;
std::unique_ptr<OnlineTransducerDecoder> decoder_;
SymbolTable sym_;
Endpoint endpoint_;
... ...
... ... @@ -16,6 +16,7 @@
#include "sherpa-onnx/csrc/endpoint.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"
... ... @@ -67,10 +68,11 @@ struct OnlineRecognizerResult {
struct OnlineRecognizerConfig {
FeatureExtractorConfig feat_config;
OnlineTransducerModelConfig model_config;
OnlineLMConfig lm_config;
EndpointConfig endpoint_config;
bool enable_endpoint = true;
std::string decoding_method = "greedy_search";
std::string decoding_method = "modified_beam_search";
// now support modified_beam_search and greedy_search
int32_t max_active_paths = 4; // used only for modified_beam_search
... ... @@ -79,6 +81,7 @@ struct OnlineRecognizerConfig {
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const OnlineLMConfig &lm_config,
const EndpointConfig &endpoint_config,
bool enable_endpoint,
const std::string &decoding_method,
... ...
// sherpa-onnx/csrc/on-rnn-lm.cc
//
// Copyright (c) 2023 Pingfeng Luo
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-rnn-lm.h"
#include <string>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OnlineRnnLM::Impl {
public:
explicit Impl(const OnlineLMConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_{},
allocator_{} {
Init(config);
}
std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
std::array<Ort::Value, 4> inputs = {
std::move(x), std::move(y), std::move(states[0]), std::move(states[1])};
auto out =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
std::vector<Ort::Value> next_states;
next_states.reserve(2);
next_states.push_back(std::move(out[1]));
next_states.push_back(std::move(out[2]));
return {std::move(out[0]), std::move(next_states)};
}
std::vector<Ort::Value> GetInitStates() const {
std::vector<Ort::Value> ans;
ans.reserve(init_states_.size());
for (const auto &s : init_states_) {
ans.emplace_back(Clone(allocator_, &s));
}
return ans;
}
private:
void Init(const OnlineLMConfig &config) {
auto buf = ReadFile(config_.model);
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers");
SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size");
SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id");
ComputeInitStates();
}
void ComputeInitStates() {
constexpr int32_t kBatchSize = 1;
std::array<int64_t, 3> h_shape{rnn_num_layers_, kBatchSize,
rnn_hidden_size_};
std::array<int64_t, 3> c_shape{rnn_num_layers_, kBatchSize,
rnn_hidden_size_};
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
h_shape.size());
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
c_shape.size());
Fill<float>(&h, 0);
Fill<float>(&c, 0);
std::array<int64_t, 2> x_shape{1, 1};
// shape of x and y are same
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
x_shape.size());
Ort::Value y = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
x_shape.size());
*x.GetTensorMutableData<int64_t>() = sos_id_;
*y.GetTensorMutableData<int64_t>() = sos_id_;
std::vector<Ort::Value> states;
states.push_back(std::move(h));
states.push_back(std::move(c));
auto pair = Rescore(std::move(x), std::move(y), std::move(states));
init_states_ = std::move(pair.second);
}
private:
OnlineLMConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
std::vector<Ort::Value> init_states_;
int32_t rnn_num_layers_ = 2;
int32_t rnn_hidden_size_ = 512;
int32_t sos_id_ = 1;
};
OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OnlineRnnLM::~OnlineRnnLM() = default;
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
return impl_->GetInitStates();
}
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore(
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
return impl_->Rescore(std::move(x), std::move(y), std::move(states));
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/online-rnn-lm.h
//
// Copyright (c) 2023 Pingfeng Luo
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
#define SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
#include <memory>
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-lm.h"
namespace sherpa_onnx {
class OnlineRnnLM : public OnlineLM {
public:
~OnlineRnnLM() override;
explicit OnlineRnnLM(const OnlineLMConfig &config);
std::vector<Ort::Value> GetInitStates() override;
/** Rescore a batch of sentences.
*
* @param x A 2-D tensor of shape (N, L) with data type int64.
* @param y A 2-D tensor of shape (N, L) with data type int64.
* @param states It contains the states for the LM model
* @return Return a pair containingo
* - negative loglike
* - updated states
*
* Caution: It returns negative log likelihood (nll), not log likelihood
*/
std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
... ...
... ... @@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
} // for (int32_t b = 0; b != batch_size; ++b)
}
if (lm_) {
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
}
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
... ...
... ... @@ -8,6 +8,7 @@
#include <vector>
#include "sherpa-onnx/csrc/online-lm.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"
... ... @@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder
: public OnlineTransducerDecoder {
public:
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
int32_t max_active_paths)
: model_(model), max_active_paths_(max_active_paths) {}
OnlineLM *lm,
int32_t max_active_paths,
float lm_scale)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale) {}
OnlineTransducerDecoderResult GetEmptyResult() const override;
... ... @@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder
private:
OnlineTransducerModel *model_; // Not owned
OnlineLM *lm_; // Not owned
int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr
};
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/onnx-utils.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#include "sherpa-onnx/csrc/onnx-utils.h"
#include <algorithm>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
... ... @@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
return ans;
}
CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) {
*this = other;
}
CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) {
if (this == &other) {
return *this;
}
if (other.value) {
Ort::AllocatorWithDefaultOptions allocator;
value = Clone(allocator, &other.value);
}
return *this;
}
CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) {
*this = std::move(other);
}
CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) {
if (this == &other) {
return *this;
}
value = std::move(other.value);
return *this;
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/csrc/onnx-utils.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
... ... @@ -13,6 +14,7 @@
#include <cassert>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
... ... @@ -89,6 +91,24 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
// TODO(fangjun): Document it
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split);
struct CopyableOrtValue {
Ort::Value value{nullptr};
CopyableOrtValue() = default;
/*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT
: value(std::move(v)) {}
CopyableOrtValue(const CopyableOrtValue &other);
CopyableOrtValue &operator=(const CopyableOrtValue &other);
CopyableOrtValue(CopyableOrtValue &&other);
CopyableOrtValue &operator=(CopyableOrtValue &&other);
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
... ...
... ... @@ -13,8 +13,9 @@
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// TODO(fangjun): Use ParseOptions as we are getting more args
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 8) {
if (argc < 6 || argc > 9) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx \
... ... @@ -22,7 +23,7 @@ Usage:
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/foo.wav [num_threads [decoding_method]]
/path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
... ... @@ -53,10 +54,12 @@ for a list of pre-trained models to download.
if (argc == 7 && atoi(argv[6]) > 0) {
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
if (argc == 9) {
config.lm_config.model = argv[8];
}
config.max_active_paths = 4;
fprintf(stderr, "%s\n", config.ToString().c_str());
... ...
... ... @@ -16,9 +16,8 @@
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#else
#include <fstream>
#endif
#include <fstream>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
... ... @@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid);
//---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
jobject lm_model_config = env->GetObjectField(config, fid);
jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(lm_model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.lm_config.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
return ans;
}
... ...
... ... @@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
offline-recognizer.cc
offline-stream.cc
offline-transducer-model-config.cc
online-lm-config.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
... ...
// sherpa-onnx/python/csrc/online-lm-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/online-lm-config.h"
#include <string>
#include "sherpa-onnx//csrc/online-lm-config.h"
namespace sherpa_onnx {
void PybindOnlineLMConfig(py::module *m) {
using PyClass = OnlineLMConfig;
py::class_<PyClass>(*m, "OnlineLMConfig")
.def(py::init<const std::string &, float>(), py::arg("model") = "",
py::arg("scale") = 0.5f)
.def_readwrite("model", &PyClass::model)
.def_readwrite("scale", &PyClass::scale)
.def("__str__", &PyClass::ToString);
}
} // namespace sherpa_onnx
... ...
// sherpa-onnx/python/csrc/online-lm-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOnlineLMConfig(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
... ...
... ... @@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const EndpointConfig &,
bool, const std::string &, int32_t>(),
const OnlineTransducerModelConfig &, const OnlineLMConfig &,
const EndpointConfig &, bool, const std::string &,
int32_t>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("endpoint_config"), py::arg("enable_endpoint"),
py::arg("decoding_method"), py::arg("max_active_paths"))
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
... ...
... ... @@ -11,6 +11,7 @@
#include "sherpa-onnx/python/csrc/offline-model-config.h"
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
#include "sherpa-onnx/python/csrc/offline-stream.h"
#include "sherpa-onnx/python/csrc/online-lm-config.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
... ... @@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindFeatures(&m);
PybindOnlineTransducerModelConfig(&m);
PybindOnlineLMConfig(&m);
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
... ...