Zhong-Yi Li
Committed by GitHub

offline transducer: treat unk as blank (#1005)

Co-authored-by: chungyi.li <chungyi.li@ailabs.tw>
... ... @@ -78,9 +78,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
config_(config),
symbol_table_(config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
if (symbol_table_.Contains("<unk>")) {
unk_id_ = symbol_table_["<unk>"];
}
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
model_.get(), unk_id_, config_.blank_penalty);
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(config.lm_config);
... ... @@ -97,7 +101,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.blank_penalty);
config_.lm_config.scale, unk_id_, config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
... ... @@ -113,9 +117,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineTransducerModel>(mgr,
config_.model_config)) {
if (symbol_table_.Contains("<unk>")) {
unk_id_ = symbol_table_["<unk>"];
}
if (config_.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
model_.get(), unk_id_, config_.blank_penalty);
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(mgr, config.lm_config);
... ... @@ -133,7 +141,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, config_.blank_penalty);
config_.lm_config.scale, unk_id_, config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
... ... @@ -293,6 +301,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
std::unique_ptr<OfflineTransducerModel> model_;
std::unique_ptr<OfflineTransducerDecoder> decoder_;
std::unique_ptr<OfflineLM> lm_;
int32_t unk_id_ = -1;
};
} // namespace sherpa_onnx
... ...
... ... @@ -57,7 +57,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));
p_logit += vocab_size;
if (y != 0) {
// The blank id is hardcoded to 0.
// <unk> is handled like blank: no token or timestamp is recorded for it.
if (y != 0 && y != unk_id_) {
ans[i].tokens.push_back(y);
ans[i].timestamps.push_back(t);
emitted = true;
... ...
... ... @@ -15,8 +15,9 @@ namespace sherpa_onnx {
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
public:
OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
int32_t unk_id,
float blank_penalty)
: model_(model), blank_penalty_(blank_penalty) {}
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
... ... @@ -24,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
private:
OfflineTransducerModel *model_; // Not owned
int32_t unk_id_;
float blank_penalty_;
};
... ...
... ... @@ -131,8 +131,9 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
float context_score = 0;
auto context_state = new_hyp.context_state;
if (new_token != 0) {
// blank id is fixed to 0
// The blank id is hardcoded to 0.
// <unk> is handled like blank: it is not appended to the hypothesis.
if (new_token != 0 && new_token != unk_id_) {
new_hyp.ys.push_back(new_token);
new_hyp.timestamps.push_back(t);
if (context_graphs[i] != nullptr) {
... ...
... ... @@ -19,12 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
OfflineLM *lm,
int32_t max_active_paths,
float lm_scale,
float lm_scale, int32_t unk_id,
float blank_penalty)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale),
unk_id_(unk_id),
blank_penalty_(blank_penalty) {}
std::vector<OfflineTransducerDecoderResult> Decode(
... ... @@ -37,6 +38,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr
int32_t unk_id_;
float blank_penalty_;
};
... ...