chiiyeh
Committed by GitHub

add blank_penalty for offline transducer (#542)

... ... @@ -383,6 +383,19 @@ def add_hotwords_args(parser: argparse.ArgumentParser):
""",
)
# Register the ``--blank-penalty`` command-line option on ``parser``.
# The value is parsed as a float and defaults to 0.0; after parsing it is
# available as ``args.blank_penalty``.
def add_blank_penalty_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
def check_args(args):
if not Path(args.tokens).is_file():
... ... @@ -414,6 +427,7 @@ def get_args():
add_feature_config_args(parser)
add_decoding_args(parser)
add_hotwords_args(parser)
add_blank_penalty_args(parser)
parser.add_argument(
"--port",
... ... @@ -862,6 +876,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
max_active_paths=args.max_active_paths,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
provider=args.provider,
)
elif args.paraformer:
... ...
... ... @@ -232,6 +232,18 @@ def get_args():
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
... ... @@ -335,6 +347,7 @@ def main():
decoding_method=args.decoding_method,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
blank_penalty=args.blank_penalty,
debug=args.debug,
)
elif args.paraformer:
... ...
... ... @@ -178,6 +178,18 @@ def get_args():
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
... ... @@ -237,6 +249,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
blank_penalty=args.blank_penalty,
debug=args.debug,
)
elif args.paraformer:
... ...
... ... @@ -96,6 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
}
}
// Subtracts `blank_penalty` from one column of a row-major matrix, in place.
//
// @param in            Pointer to the first element of the matrix; modified.
// @param w             Number of elements per row (stride between rows).
// @param h             Number of rows to process.
// @param blank_idx     Column index of the blank symbol within each row.
// @param blank_penalty Amount subtracted from in[row * w + blank_idx].
template <typename T>
void SubtractBlank(T *in, int32_t w, int32_t h,
                   int32_t blank_idx, float blank_penalty) {
  for (int32_t row = 0; row != h; ++row) {
    // Widen before multiplying so the flat offset is computed in 64 bits.
    const int64_t offset = static_cast<int64_t>(row) * w + blank_idx;
    in[offset] -= blank_penalty;
  }
}
template <class T>
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
std::vector<int32_t> vec_index(size);
... ...
... ... @@ -79,7 +79,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(config.lm_config);
... ... @@ -87,7 +88,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
config_.lm_config.scale, config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
... ... @@ -104,7 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
config_.model_config)) {
if (config_.decoding_method == "greedy_search") {
decoder_ =
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
std::make_unique<OfflineTransducerGreedySearchDecoder>(
model_.get(), config_.blank_penalty);
} else if (config_.decoding_method == "modified_beam_search") {
if (!config_.lm_config.model.empty()) {
lm_ = OfflineLM::Create(mgr, config.lm_config);
... ... @@ -112,7 +114,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale);
config_.lm_config.scale, config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config_.decoding_method.c_str());
... ...
... ... @@ -28,6 +28,13 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
po->Register("max-active-paths", &max_active_paths,
"Used only when decoding_method is modified_beam_search");
// Expose blank_penalty as the --blank-penalty command-line option.
// Adjacent string literals are concatenated verbatim, so every fragment
// that is followed by more text must end with its own space.
po->Register("blank-penalty", &blank_penalty,
             "The penalty applied on blank symbol during decoding. "
             "Note: It is a positive value. "
             "Increasing value will lead to lower deletion at the cost "
             "of higher insertions. "
             "Currently only applicable for transducer models.");
po->Register(
"hotwords-file", &hotwords_file,
"The file containing hotwords, one words/phrases per line, and for each"
... ... @@ -74,7 +81,8 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "decoding_method=\"" << decoding_method << "\", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ")";
os << "hotwords_score=" << hotwords_score << ", ";
os << "blank_penalty=" << blank_penalty << ")";
return os.str();
}
... ...
... ... @@ -37,6 +37,8 @@ struct OfflineRecognizerConfig {
std::string hotwords_file;
float hotwords_score = 1.5;
float blank_penalty = 0.0;
// only greedy_search is implemented
// TODO(fangjun): Implement modified_beam_search
... ... @@ -46,7 +48,8 @@ struct OfflineRecognizerConfig {
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score)
const std::string &hotwords_file, float hotwords_score,
float blank_penalty)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
... ... @@ -54,7 +57,8 @@ struct OfflineRecognizerConfig {
decoding_method(decoding_method),
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
hotwords_score(hotwords_score) {}
hotwords_score(hotwords_score),
blank_penalty(blank_penalty) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -46,9 +46,12 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
start += n;
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
std::move(cur_decoder_out));
const float *p_logit = logit.GetTensorData<float>();
float *p_logit = logit.GetTensorMutableData<float>();
bool emitted = false;
for (int32_t i = 0; i != n; ++i) {
if (blank_penalty_ > 0.0) {
p_logit[0] -= blank_penalty_; // assuming blank id is 0
}
auto y = static_cast<int32_t>(std::distance(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
... ...
... ... @@ -14,8 +14,10 @@ namespace sherpa_onnx {
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
public:
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)
: model_(model) {}
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
float blank_penalty)
: model_(model),
blank_penalty_(blank_penalty) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
... ... @@ -23,6 +25,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
private:
OfflineTransducerModel *model_; // Not owned
float blank_penalty_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -97,6 +97,10 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
if (blank_penalty_ > 0.0) {
// assuming blank id is 0
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
}
LogSoftmax(p_logit, vocab_size, num_hyps);
// now p_logit contains log_softmax output, we rename it to p_logprob
... ...
... ... @@ -19,11 +19,13 @@ class OfflineTransducerModifiedBeamSearchDecoder
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
OfflineLM *lm,
int32_t max_active_paths,
float lm_scale)
float lm_scale,
float blank_penalty)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale) {}
lm_scale_(lm_scale),
blank_penalty_(blank_penalty) {}
std::vector<OfflineTransducerDecoderResult> Decode(
Ort::Value encoder_out, Ort::Value encoder_out_length,
... ... @@ -35,6 +37,7 @@ class OfflineTransducerModifiedBeamSearchDecoder
int32_t max_active_paths_;
float lm_scale_; // used only when lm_ is not nullptr
float blank_penalty_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -17,13 +17,14 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def(py::init<const OfflineFeatureExtractorConfig &,
const OfflineModelConfig &, const OfflineLMConfig &,
const OfflineCtcFstDecoderConfig &, const std::string &,
int32_t, const std::string &, float>(),
int32_t, const std::string &, float, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5)
py::arg("hotwords_score") = 1.5,
py::arg("blank_penalty") = 0.0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
... ... @@ -32,6 +33,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -48,6 +48,7 @@ class OfflineRecognizer(object):
max_active_paths: int = 4,
hotwords_file: str = "",
hotwords_score: float = 1.5,
blank_penalty: float = 0.0,
debug: bool = False,
provider: str = "cpu",
):
... ... @@ -81,6 +82,8 @@ class OfflineRecognizer(object):
max_active_paths:
Maximum number of active paths to keep. Used only when
decoding_method is modified_beam_search.
blank_penalty:
The penalty applied on blank symbol during decoding.
debug:
True to show debug messages.
provider:
... ... @@ -117,6 +120,7 @@ class OfflineRecognizer(object):
decoding_method=decoding_method,
hotwords_file=hotwords_file,
hotwords_score=hotwords_score,
blank_penalty=blank_penalty,
)
self.recognizer = _Recognizer(recognizer_config)
self.config = recognizer_config
... ...