Fangjun Kuang
Committed by GitHub

Support batch greedy search decoding (#30)
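In short: this commit replaces the single-stream placeholders with real batched implementations. StackStates/UnStackStates on the LSTM transducer model now convert between per-stream {h, c} LSTM states and batched (num_layers, batch, dim) tensors, BuildDecoderInput takes one decoder result per stream instead of a single token history, OnlineRecognizer::Impl::DecodeStreams accepts any number of ready streams, and the greedy-search decoder walks a (batch_size, vocab_size) logit matrix frame by frame.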

@@ -3,6 +3,7 @@
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
+#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
@@ -10,6 +11,7 @@
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \
@@ -114,23 +116,85 @@ void OnlineLstmTransducerModel::InitJoiner(const std::string &filename) {
}
}
-Ort::Value OnlineLstmTransducerModel::StackStates(
-    const std::vector<Ort::Value> &states) const {
-  fprintf(stderr, "implement me: %s:%d!\n", __func__,
-          static_cast<int>(__LINE__));
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-  int64_t a;
-  std::array<int64_t, 3> x_shape{1, 1, 1};
-  Ort::Value x = Ort::Value::CreateTensor(memory_info, &a, 0, &a, 0);
-  return x;
+std::vector<Ort::Value> OnlineLstmTransducerModel::StackStates(
+    const std::vector<std::vector<Ort::Value>> &states) const {
+  int32_t batch_size = static_cast<int32_t>(states.size());
+
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, batch_size, d_model_};
+  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                 h_shape.size());
+
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, batch_size,
+                                 rnn_hidden_size_};
+  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                 c_shape.size());
+
+  float *dst_h = h.GetTensorMutableData<float>();
+  float *dst_c = c.GetTensorMutableData<float>();
+
+  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const float *src_h =
+          states[i][0].GetTensorData<float>() + layer * d_model_;
+      const float *src_c =
+          states[i][1].GetTensorData<float>() + layer * rnn_hidden_size_;
+
+      std::copy(src_h, src_h + d_model_, dst_h);
+      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
+
+      dst_h += d_model_;
+      dst_c += rnn_hidden_size_;
+    }
+  }
+
+  std::vector<Ort::Value> ans;
+  ans.reserve(2);
+  ans.push_back(std::move(h));
+  ans.push_back(std::move(c));
+
+  return ans;
}
-std::vector<Ort::Value> OnlineLstmTransducerModel::UnStackStates(
-    Ort::Value states) const {
-  fprintf(stderr, "implement me: %s:%d!\n", __func__,
-          static_cast<int>(__LINE__));
-  return {};
+std::vector<std::vector<Ort::Value>> OnlineLstmTransducerModel::UnStackStates(
+    const std::vector<Ort::Value> &states) const {
+  int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[1];
+  std::vector<std::vector<Ort::Value>> ans(batch_size);
+
+  // allocate space
+  std::array<int64_t, 3> h_shape{num_encoder_layers_, 1, d_model_};
+  std::array<int64_t, 3> c_shape{num_encoder_layers_, 1, rnn_hidden_size_};
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                   h_shape.size());
+    Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                   c_shape.size());
+    ans[i].push_back(std::move(h));
+    ans[i].push_back(std::move(c));
+  }
+
+  for (int32_t layer = 0; layer != num_encoder_layers_; ++layer) {
+    for (int32_t i = 0; i != batch_size; ++i) {
+      const float *src_h = states[0].GetTensorData<float>() +
+                           layer * batch_size * d_model_ + i * d_model_;
+      const float *src_c = states[1].GetTensorData<float>() +
+                           layer * batch_size * rnn_hidden_size_ +
+                           i * rnn_hidden_size_;
+
+      float *dst_h =
+          ans[i][0].GetTensorMutableData<float>() + layer * d_model_;
+      float *dst_c =
+          ans[i][1].GetTensorMutableData<float>() + layer * rnn_hidden_size_;
+
+      std::copy(src_h, src_h + d_model_, dst_h);
+      std::copy(src_c, src_c + rnn_hidden_size_, dst_c);
+    }
+  }
+
+  return ans;
}
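Note on the layout used by StackStates/UnStackStates above: each per-stream h is a row-major (num_encoder_layers_, 1, d_model_) tensor and the batched h is (num_encoder_layers_, batch_size, d_model_), so stacking is an interleave along the batch axis. A minimal standalone sketch of the same index math on plain buffers (hypothetical helper, not part of this diff; c is handled identically with rnn_hidden_size_ in place of d_model_):

#include <algorithm>
#include <cstdint>
#include <vector>

// Stacks per-stream hidden states, each a row-major (num_layers, 1, d_model)
// buffer, into one row-major (num_layers, batch, d_model) buffer. Element
// (layer, i, k) of the result lives at offset (layer * batch + i) * d_model + k.
std::vector<float> StackH(const std::vector<std::vector<float>> &h,
                          int32_t num_layers, int32_t d_model) {
  int32_t batch = static_cast<int32_t>(h.size());
  std::vector<float> ans(static_cast<size_t>(num_layers) * batch * d_model);
  for (int32_t layer = 0; layer != num_layers; ++layer) {
    for (int32_t i = 0; i != batch; ++i) {
      const float *src = h[i].data() + layer * d_model;
      float *dst =
          ans.data() + (static_cast<size_t>(layer) * batch + i) * d_model;
      std::copy(src, src + d_model, dst);
    }
  }
  return ans;
}

UnStackStates is the exact inverse of this interleave, so unstacking a freshly stacked state reproduces the per-stream tensors unchanged.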
std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
@@ -189,16 +253,21 @@ OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
}
Ort::Value OnlineLstmTransducerModel::BuildDecoderInput(
-    const std::vector<int64_t> &hyp) {
-  auto memory_info =
-      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-  std::array<int64_t, 2> shape{1, context_size_};
+    const std::vector<OnlineTransducerDecoderResult> &results) {
+  int32_t batch_size = static_cast<int32_t>(results.size());
+  std::array<int64_t, 2> shape{batch_size, context_size_};
+
+  Ort::Value decoder_input =
+      Ort::Value::CreateTensor<int64_t>(allocator_, shape.data(), shape.size());
+  int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
+
+  for (const auto &r : results) {
+    const int64_t *begin = r.tokens.data() + r.tokens.size() - context_size_;
+    const int64_t *end = r.tokens.data() + r.tokens.size();
+    std::copy(begin, end, p);
+    p += context_size_;
+  }
-  return Ort::Value::CreateTensor(
-      memory_info,
-      const_cast<int64_t *>(hyp.data() + hyp.size() - context_size_),
-      context_size_, shape.data(), shape.size());
+  return decoder_input;
}
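A worked example of the new decoder input, with hypothetical numbers: for context_size_ == 2 and two streams whose token histories end in ..., 17, 80 and ..., 25, 3, the loop above fills the (2, 2) tensor with rows {17, 80} and {25, 3}. Each row is simply the last context_size_ tokens of that stream's hypothesis, which is all the fixed-context decoder network looks at.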
Ort::Value OnlineLstmTransducerModel::RunDecoder(Ort::Value decoder_input) {
@@ -19,16 +19,19 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
public:
explicit OnlineLstmTransducerModel(const OnlineTransducerModelConfig &config);
-  Ort::Value StackStates(const std::vector<Ort::Value> &states) const override;
+  std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const override;

-  std::vector<Ort::Value> UnStackStates(Ort::Value states) const override;
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const override;
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> &states) override;
-  Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) override;
+  Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results) override;
Ort::Value RunDecoder(Ort::Value decoder_input) override;
@@ -41,6 +44,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
int32_t ChunkShift() const override { return decode_chunk_len_; }
int32_t VocabSize() const override { return vocab_size_; }
+  OrtAllocator *Allocator() override { return allocator_; }
private:
void InitEncoder(const std::string &encoder_filename);
@@ -50,7 +54,6 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
@@ -6,6 +6,7 @@
#include <assert.h>
+#include <algorithm>
#include <memory>
#include <sstream>
#include <utility>
@@ -64,39 +65,50 @@ class OnlineRecognizer::Impl {
}
void DecodeStreams(OnlineStream **ss, int32_t n) {
-    if (n != 1) {
-      fprintf(stderr, "only n == 1 is implemented\n");
-      exit(-1);
-    }
-    OnlineStream *s = ss[0];
-    assert(IsReady(s));
int32_t chunk_size = model_->ChunkSize();
int32_t chunk_shift = model_->ChunkShift();
-    int32_t feature_dim = s->FeatureDim();
-    std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
+    int32_t feature_dim = ss[0]->FeatureDim();
+
+    std::vector<OnlineTransducerDecoderResult> results(n);
+    std::vector<float> features_vec(n * chunk_size * feature_dim);
+    std::vector<std::vector<Ort::Value>> states_vec(n);
+
+    for (int32_t i = 0; i != n; ++i) {
+      std::vector<float> features =
+          ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
+      ss[i]->GetNumProcessedFrames() += chunk_shift;
+
+      std::copy(features.begin(), features.end(),
+                features_vec.data() + i * chunk_size * feature_dim);
+
+      results[i] = std::move(ss[i]->GetResult());
+      states_vec[i] = std::move(ss[i]->GetStates());
+    }
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-    std::vector<float> features =
-        s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
-    s->GetNumProcessedFrames() += chunk_shift;
-    Ort::Value x =
-        Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
-                                 x_shape.data(), x_shape.size());
+    std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
+    Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
+                                            features_vec.size(), x_shape.data(),
+                                            x_shape.size());
-    auto pair = model_->RunEncoder(std::move(x), s->GetStates());
-    s->SetStates(std::move(pair.second));
-    std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
+    auto states = model_->StackStates(states_vec);
+    auto pair = model_->RunEncoder(std::move(x), states);
decoder_->Decode(std::move(pair.first), &results);
-    s->SetResult(results[0]);
+    std::vector<std::vector<Ort::Value>> next_states =
+        model_->UnStackStates(pair.second);
+
+    for (int32_t i = 0; i != n; ++i) {
+      ss[i]->SetResult(results[i]);
+      ss[i]->SetStates(std::move(next_states[i]));
+    }
}
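With the n != 1 guard gone, a caller can batch whatever streams currently have a full chunk of features buffered; GetFrames is called unconditionally for each of the n streams, so readiness is the caller's responsibility. A rough usage sketch, assuming the public OnlineRecognizer forwards IsReady and DecodeStreams to this Impl:

// Collect every stream that is ready, then decode them in one batched call.
// `streams` is assumed to be a std::vector<std::unique_ptr<OnlineStream>>.
std::vector<OnlineStream *> ready;
for (auto &s : streams) {
  if (recognizer.IsReady(s.get())) {
    ready.push_back(s.get());
  }
}
if (!ready.empty()) {
  recognizer.DecodeStreams(ready.data(), static_cast<int32_t>(ready.size()));
}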
OnlineRecognizerResult GetResult(OnlineStream *s) {
@@ -32,6 +32,30 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
encoder_out_dim, shape.data(), shape.size());
}
+static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
+                         int32_t n) {
+  if (n == 1) {
+    return std::move(*cur_encoder_out);
+  }
+
+  std::vector<int64_t> cur_encoder_out_shape =
+      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
+
+  std::array<int64_t, 2> ans_shape{n, cur_encoder_out_shape[1]};
+
+  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
+                                                   ans_shape.size());
+
+  const float *src = cur_encoder_out->GetTensorData<float>();
+  float *dst = ans.GetTensorMutableData<float>();
+
+  for (int32_t i = 0; i != n; ++i) {
+    std::copy(src, src + cur_encoder_out_shape[1], dst);
+    dst += cur_encoder_out_shape[1];
+  }
+
+  return ans;
+}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
@@ -66,33 +90,33 @@ void OnlineTransducerGreedySearchDecoder::Decode(
exit(-1);
}
-  if (result->size() != 1) {
-    fprintf(stderr, "only batch size == 1 is implemented. Given: %d",
-            static_cast<int32_t>(result->size()));
-    exit(-1);
-  }
-  auto &hyp = (*result)[0].tokens;
-  int32_t num_frames = encoder_out_shape[1];
+  int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
+  int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
int32_t vocab_size = model_->VocabSize();
-  Ort::Value decoder_input = model_->BuildDecoderInput(hyp);
+  Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
+    cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
Ort::Value logit =
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
const float *p_logit = logit.GetTensorData<float>();
-    auto y = static_cast<int32_t>(std::distance(
-        static_cast<const float *>(p_logit),
-        std::max_element(static_cast<const float *>(p_logit),
-                         static_cast<const float *>(p_logit) + vocab_size)));
-    if (y != 0) {
-      hyp.push_back(y);
-      decoder_input = model_->BuildDecoderInput(hyp);
-      decoder_out = model_->RunDecoder(std::move(decoder_input));
-    }
+    bool emitted = false;
+    for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) {
+      auto y = static_cast<int32_t>(std::distance(
+          static_cast<const float *>(p_logit),
+          std::max_element(static_cast<const float *>(p_logit),
+                           static_cast<const float *>(p_logit) + vocab_size)));
+      if (y != 0) {
+        emitted = true;
+        (*result)[i].tokens.push_back(y);
+      }
+    }
+
+    if (emitted) {
+      decoder_input = model_->BuildDecoderInput(*result);
+      decoder_out = model_->RunDecoder(std::move(decoder_input));
+    }
}
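The inner loop above is a per-row argmax over the (batch_size, vocab_size) logit matrix; a winning index of 0 is the blank and emits nothing for that stream on this frame. A minimal sketch of that step in isolation (hypothetical helper, not part of the diff):

#include <algorithm>
#include <cstdint>
#include <iterator>

// Returns the argmax over one row of a (batch_size, vocab_size) logit matrix.
// In the loop above, a result of 0 (blank) leaves the stream's hypothesis
// unchanged for this frame.
static int32_t ArgMax(const float *row, int32_t vocab_size) {
  return static_cast<int32_t>(
      std::distance(row, std::max_element(row, row + vocab_size)));
}

Rebuilding the decoder input only when emitted is set avoids a decoder run on frames where every stream produced blank, which is the common case in streaming transducers.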
@@ -13,6 +13,8 @@
namespace sherpa_onnx {
+class OnlineTransducerDecoderResult;
+
class OnlineTransducerModel {
public:
virtual ~OnlineTransducerModel() = default;
@@ -27,8 +29,8 @@ class OnlineTransducerModel {
* @param states states[i] contains the state for the i-th utterance.
* @return Return a single value representing the batched state.
*/
-  virtual Ort::Value StackStates(
-      const std::vector<Ort::Value> &states) const = 0;
+  virtual std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const = 0;
/** Unstack a batch state into a list of individual states.
*
@@ -37,7 +39,8 @@ class OnlineTransducerModel {
* @param states A batched state.
* @return ans[i] contains the state for the i-th utterance.
*/
-  virtual std::vector<Ort::Value> UnStackStates(Ort::Value states) const = 0;
+  virtual std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const = 0;
/** Get the initial encoder states.
*
@@ -58,7 +61,8 @@ class OnlineTransducerModel {
Ort::Value features,
std::vector<Ort::Value> &states) = 0; // NOLINT
-  virtual Ort::Value BuildDecoderInput(const std::vector<int64_t> &hyp) = 0;
+  virtual Ort::Value BuildDecoderInput(
+      const std::vector<OnlineTransducerDecoderResult> &results) = 0;
/** Run the decoder network.
*
@@ -111,6 +115,7 @@ class OnlineTransducerModel {
virtual int32_t VocabSize() const = 0;
virtual int32_t SubsamplingFactor() const { return 4; }
+  virtual OrtAllocator *Allocator() = 0;
};
} // namespace sherpa_onnx