Karel Vesely
Committed by GitHub

Adding temperature scaling on Joiner logits: (#789)

* Adding temperature scaling on Joiner logits:

- T hard-coded to 2.0
- so far best result NCE 0.122 (still not so high)
    - the BPE scores were rescaled with 0.2 (but then incorrect words also
      get high confidence; visually reasonable histograms come from a 0.5 scale)
    - BPE->WORD score merging is done with the min(.) function
      (prob-product and the arithmetic, geometric, and harmonic means were also tried;
      see the sketch after this list)

- without temperature scaling (i.e. scale 1.0), the best NCE was 0.032 (here product merging was best)
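
For clarity, here is a minimal sketch of the min(.) BPE->WORD merge described above; the grouping of BPE pieces into words via the "▁" word-start marker and all names are illustrative, not the exact implementation:

```python
# Illustrative sketch of the BPE -> WORD confidence merge described above.
# Assumptions: tokens are SentencePiece-style BPE pieces where "▁" marks a
# word start, and `probs` are per-token confidences (e.g. exp of the
# temperature-scaled log-softmax scores). Names are hypothetical.
def merge_bpe_to_word_confidences(tokens, probs):
    words, confs = [], []
    for tok, p in zip(tokens, probs):
        if tok.startswith("▁") or not words:
            words.append(tok.lstrip("▁"))
            confs.append(p)                 # start a new word
        else:
            words[-1] += tok                # continue the current word
            confs[-1] = min(confs[-1], p)   # min(.) merge of piece scores
    return list(zip(words, confs))

# Example: [("dobrý", 0.9), ("den", 0.4)]
print(merge_bpe_to_word_confidences(["▁dobrý", "▁de", "n"], [0.9, 0.8, 0.4]))
```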

Results seem consistent with: https://arxiv.org/abs/2110.15222

Everything was tuned on a very small set of 100 sentences (813 words, 10.2% WER) from a Czech model.

I also experimented with mixing blank posteriors into the BPE confidences,
but found no NCE improvement, so I am not pushing that.

Temperature scaling is also applied to the greedy search confidences.
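
For illustration, a minimal numpy sketch of how the temperature enters the confidence computation (mirroring the C++ changes below); the logit values are made up:

```python
import numpy as np

def log_softmax(x):
    # numerically stable log-softmax
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

logits = np.array([4.0, 1.0, 0.5, -2.0])  # joiner logits for one frame (made up)
T = 2.0                                   # temperature_scale, default in this change

y = int(np.argmax(logits))                        # decoding is unaffected by T
conf = float(np.exp(log_softmax(logits / T))[y])  # confidence of the emitted token
print(y, conf)  # higher T -> flatter distribution -> less over-confident scores
```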

* making `temperature_scale` configurable from outside
... ... @@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty);
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(), unk_id_, config_.blank_penalty);
model_.get(),
unk_id_,
config_.blank_penalty,
config_.temperature_scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
... ... @@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
model_.get(), lm_.get(), config_.max_active_paths,
config_.lm_config.scale, unk_id_, config_.blank_penalty);
model_.get(),
lm_.get(),
config_.max_active_paths,
config_.lm_config.scale,
unk_id_,
config_.blank_penalty,
config_.temperature_scale);
} else if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
model_.get(), unk_id_, config_.blank_penalty);
model_.get(),
unk_id_,
config_.blank_penalty,
config_.temperature_scale);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
... ...
... ... @@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
po->Register("temperature-scale", &temperature_scale,
"Temperature scale for confidence computation in decoding.");
}
bool OnlineRecognizerConfig::Validate() const {
... ... @@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "hotwords_score=" << hotwords_score << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "blank_penalty=" << blank_penalty << ")";
os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ")";
return os.str();
}
... ...
... ... @@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {
float blank_penalty = 0.0;
float temperature_scale = 2.0;
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(
const FeatureExtractorConfig &feat_config,
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
const OnlineModelConfig &model_config,
const OnlineLMConfig &lm_config,
const EndpointConfig &endpoint_config,
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty)
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths,
const std::string &hotwords_file,
float hotwords_score,
float blank_penalty,
float temperature_scale)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
... ... @@ -114,9 +121,10 @@ struct OnlineRecognizerConfig {
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths),
hotwords_score(hotwords_score),
hotwords_file(hotwords_file),
blank_penalty(blank_penalty) {}
hotwords_score(hotwords_score),
blank_penalty(blank_penalty),
temperature_scale(temperature_scale) {}
void Register(ParseOptions *po);
bool Validate() const;
... ...
... ... @@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
// export the per-token log scores
if (y != 0 && y != unk_id_) {
// apply temperature-scaling
for (int32_t n = 0; n < vocab_size; ++n) {
p_logit[n] /= temperature_scale_;
}
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
// save time by doing it only for
// emitted symbols
... ...
... ... @@ -15,8 +15,13 @@ namespace sherpa_onnx {
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
public:
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
int32_t unk_id, float blank_penalty)
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
int32_t unk_id,
float blank_penalty,
float temperature_scale)
: model_(model),
unk_id_(unk_id),
blank_penalty_(blank_penalty),
temperature_scale_(temperature_scale) {}
OnlineTransducerDecoderResult GetEmptyResult() const override;
... ... @@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
OnlineTransducerModel *model_; // Not owned
int32_t unk_id_;
float blank_penalty_;
float temperature_scale_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
float *p_logit = logit.GetTensorMutableData<float>();
// copy raw logits, apply temperature-scaling (for confidences)
// Note: temperature scaling is used only for the confidences,
// the decoding algorithm uses the original logits
int32_t p_logit_items = vocab_size * num_hyps;
std::vector<float> logit_with_temperature(p_logit_items);
{
std::copy(p_logit,
p_logit + p_logit_items,
logit_with_temperature.begin());
for (float& elem : logit_with_temperature) {
elem /= temperature_scale_;
}
LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
}
if (blank_penalty_ > 0.0) {
// assuming blank id is 0
SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
... ... @@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
// score of the transducer
// export the per-token log scores
if (new_token != 0 && new_token != unk_id_) {
const Hypothesis &prev_i = prev[hyp_index];
// subtract 'prev[i]' path scores, which were added before
// getting topk tokens
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
float y_prob = logit_with_temperature[start * vocab_size + k];
new_hyp.ys_probs.push_back(y_prob);
if (lm_) { // export only when LM is used
... ... @@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
cur.push_back(std::move(hyps));
p_logprob += (end - start) * vocab_size;
} // for (int32_t b = 0; b != batch_size; ++b)
}
} // for (int32_t t = 0; t != num_frames; ++t)
for (int32_t b = 0; b != batch_size; ++b) {
auto &hyps = cur[b];
... ...
... ... @@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
OnlineLM *lm,
int32_t max_active_paths,
float lm_scale, int32_t unk_id,
float blank_penalty)
float blank_penalty,
float temperature_scale)
: model_(model),
lm_(lm),
max_active_paths_(max_active_paths),
lm_scale_(lm_scale),
unk_id_(unk_id),
blank_penalty_(blank_penalty) {}
blank_penalty_(blank_penalty),
temperature_scale_(temperature_scale) {}
OnlineTransducerDecoderResult GetEmptyResult() const override;
... ... @@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
float lm_scale_; // used only when lm_ is not nullptr
int32_t unk_id_;
float blank_penalty_;
float temperature_scale_;
};
} // namespace sherpa_onnx
... ...
... ... @@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
int32_t, const std::string &, float, float>(),
py::arg("feat_config"), py::arg("model_config"),
py::init<const FeatureExtractorConfig &,
const OnlineModelConfig &,
const OnlineLMConfig &,
const EndpointConfig &,
const OnlineCtcFstDecoderConfig &,
bool,
const std::string &,
int32_t,
const std::string &,
float,
float,
float>(),
py::arg("feat_config"),
py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
py::arg("enable_endpoint"),
py::arg("decoding_method"),
py::arg("max_active_paths") = 4,
py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0,
py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
... ... @@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def("__str__", &PyClass::ToString);
}
... ...
... ... @@ -58,6 +58,7 @@ class OnlineRecognizer(object):
model_type: str = "",
lm: str = "",
lm_scale: float = 0.1,
temperature_scale: float = 2.0,
):
"""
Please refer to
... ... @@ -123,6 +124,10 @@ class OnlineRecognizer(object):
hotwords_score:
The hotword score of each token for biasing word/phrase. Used only if
hotwords_file is given with modified_beam_search as decoding method.
temperature_scale:
Temperature scaling for output symbol confidence estimation.
It affects only the confidence values; the decoding uses the original
logits without temperature.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
model_type:
... ... @@ -193,6 +198,7 @@ class OnlineRecognizer(object):
hotwords_score=hotwords_score,
hotwords_file=hotwords_file,
blank_penalty=blank_penalty,
temperature_scale=temperature_scale,
)
self.recognizer = _Recognizer(recognizer_config)
... ...
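
For completeness, a hypothetical Python usage sketch; the model-file keyword names (tokens/encoder/decoder/joiner) are assumptions about the existing constructor and are not part of this change, only `temperature_scale` is new:

```python
import sherpa_onnx

def build_recognizer(tokens: str, encoder: str, decoder: str, joiner: str):
    # Hypothetical sketch: the model-file keyword names are assumptions about
    # the existing constructor; the only new argument is temperature_scale,
    # which affects the confidence values but not the decoding result.
    return sherpa_onnx.OnlineRecognizer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        decoding_method="modified_beam_search",
        temperature_scale=2.0,  # new in this change; default is 2.0
    )
```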