Karel Vesely
Committed by GitHub

Adding temperature scaling on Joiner logits: (#789)

* Adding temperature scaling on Joiner logits:

- T hard-coded to 2.0
- so far best result NCE 0.122 (still not so high)
    - the BPE scores were rescaled with 0.2 (but then also incorrect words
      get high confidence, visually reasonable histograms are for 0.5 scale)
    - BPE->WORD score merging done by min(.) function
      (tried also prob-product, and also arithmetic, geometric, harmonic mean)

- without temperature scaling (i.e. scale 1.0), the best NCE was 0.032 (here product merging was best)

Results seem consistent with: https://arxiv.org/abs/2110.15222

Everything tuned on a very-small set of 100 sentences with 813 words and 10.2% WER, a Czech model.

I also experimented with blank posteriors mixed into the BPE confidences,
but no NCE improvement found, so not pushing that.

Temperature scling added also to the Greedy search confidences.

* making `temperature_scale` configurable from outside
@@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -103,11 +103,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
103 } 103 }
104 104
105 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 105 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
106 - model_.get(), lm_.get(), config_.max_active_paths,  
107 - config_.lm_config.scale, unk_id_, config_.blank_penalty); 106 + model_.get(),
  107 + lm_.get(),
  108 + config_.max_active_paths,
  109 + config_.lm_config.scale,
  110 + unk_id_,
  111 + config_.blank_penalty,
  112 + config_.temperature_scale);
  113 +
108 } else if (config.decoding_method == "greedy_search") { 114 } else if (config.decoding_method == "greedy_search") {
109 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 115 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
110 - model_.get(), unk_id_, config_.blank_penalty); 116 + model_.get(),
  117 + unk_id_,
  118 + config_.blank_penalty,
  119 + config_.temperature_scale);
  120 +
111 } else { 121 } else {
112 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 122 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
113 config.decoding_method.c_str()); 123 config.decoding_method.c_str());
@@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { @@ -141,11 +151,21 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
141 } 151 }
142 152
143 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>( 153 decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
144 - model_.get(), lm_.get(), config_.max_active_paths,  
145 - config_.lm_config.scale, unk_id_, config_.blank_penalty); 154 + model_.get(),
  155 + lm_.get(),
  156 + config_.max_active_paths,
  157 + config_.lm_config.scale,
  158 + unk_id_,
  159 + config_.blank_penalty,
  160 + config_.temperature_scale);
  161 +
146 } else if (config.decoding_method == "greedy_search") { 162 } else if (config.decoding_method == "greedy_search") {
147 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>( 163 decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoder>(
148 - model_.get(), unk_id_, config_.blank_penalty); 164 + model_.get(),
  165 + unk_id_,
  166 + config_.blank_penalty,
  167 + config_.temperature_scale);
  168 +
149 } else { 169 } else {
150 SHERPA_ONNX_LOGE("Unsupported decoding method: %s", 170 SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
151 config.decoding_method.c_str()); 171 config.decoding_method.c_str());
@@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { @@ -96,6 +96,8 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
96 po->Register("decoding-method", &decoding_method, 96 po->Register("decoding-method", &decoding_method,
97 "decoding method," 97 "decoding method,"
98 "now support greedy_search and modified_beam_search."); 98 "now support greedy_search and modified_beam_search.");
  99 + po->Register("temperature-scale", &temperature_scale,
  100 + "Temperature scale for confidence computation in decoding.");
99 } 101 }
100 102
101 bool OnlineRecognizerConfig::Validate() const { 103 bool OnlineRecognizerConfig::Validate() const {
@@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const { @@ -142,7 +144,8 @@ std::string OnlineRecognizerConfig::ToString() const {
142 os << "hotwords_score=" << hotwords_score << ", "; 144 os << "hotwords_score=" << hotwords_score << ", ";
143 os << "hotwords_file=\"" << hotwords_file << "\", "; 145 os << "hotwords_file=\"" << hotwords_file << "\", ";
144 os << "decoding_method=\"" << decoding_method << "\", "; 146 os << "decoding_method=\"" << decoding_method << "\", ";
145 - os << "blank_penalty=" << blank_penalty << ")"; 147 + os << "blank_penalty=" << blank_penalty << ", ";
  148 + os << "temperature_scale=" << temperature_scale << ")";
146 149
147 return os.str(); 150 return os.str();
148 } 151 }
@@ -96,16 +96,23 @@ struct OnlineRecognizerConfig { @@ -96,16 +96,23 @@ struct OnlineRecognizerConfig {
96 96
97 float blank_penalty = 0.0; 97 float blank_penalty = 0.0;
98 98
  99 + float temperature_scale = 2.0;
  100 +
99 OnlineRecognizerConfig() = default; 101 OnlineRecognizerConfig() = default;
100 102
101 OnlineRecognizerConfig( 103 OnlineRecognizerConfig(
102 const FeatureExtractorConfig &feat_config, 104 const FeatureExtractorConfig &feat_config,
103 - const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, 105 + const OnlineModelConfig &model_config,
  106 + const OnlineLMConfig &lm_config,
104 const EndpointConfig &endpoint_config, 107 const EndpointConfig &endpoint_config,
105 const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, 108 const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
106 - bool enable_endpoint, const std::string &decoding_method,  
107 - int32_t max_active_paths, const std::string &hotwords_file,  
108 - float hotwords_score, float blank_penalty) 109 + bool enable_endpoint,
  110 + const std::string &decoding_method,
  111 + int32_t max_active_paths,
  112 + const std::string &hotwords_file,
  113 + float hotwords_score,
  114 + float blank_penalty,
  115 + float temperature_scale)
109 : feat_config(feat_config), 116 : feat_config(feat_config),
110 model_config(model_config), 117 model_config(model_config),
111 lm_config(lm_config), 118 lm_config(lm_config),
@@ -114,9 +121,10 @@ struct OnlineRecognizerConfig { @@ -114,9 +121,10 @@ struct OnlineRecognizerConfig {
114 enable_endpoint(enable_endpoint), 121 enable_endpoint(enable_endpoint),
115 decoding_method(decoding_method), 122 decoding_method(decoding_method),
116 max_active_paths(max_active_paths), 123 max_active_paths(max_active_paths),
117 - hotwords_score(hotwords_score),  
118 hotwords_file(hotwords_file), 124 hotwords_file(hotwords_file),
119 - blank_penalty(blank_penalty) {} 125 + hotwords_score(hotwords_score),
  126 + blank_penalty(blank_penalty),
  127 + temperature_scale(temperature_scale) {}
120 128
121 void Register(ParseOptions *po); 129 void Register(ParseOptions *po);
122 bool Validate() const; 130 bool Validate() const;
@@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -144,6 +144,10 @@ void OnlineTransducerGreedySearchDecoder::Decode(
144 144
145 // export the per-token log scores 145 // export the per-token log scores
146 if (y != 0 && y != unk_id_) { 146 if (y != 0 && y != unk_id_) {
  147 + // apply temperature-scaling
  148 + for (int32_t n = 0; n < vocab_size; ++n) {
  149 + p_logit[n] /= temperature_scale_;
  150 + }
147 LogSoftmax(p_logit, vocab_size); // renormalize probabilities, 151 LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
148 // save time by doing it only for 152 // save time by doing it only for
149 // emitted symbols 153 // emitted symbols
@@ -15,8 +15,13 @@ namespace sherpa_onnx { @@ -15,8 +15,13 @@ namespace sherpa_onnx {
15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { 15 class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
16 public: 16 public:
17 OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, 17 OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
18 - int32_t unk_id, float blank_penalty)  
19 - : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} 18 + int32_t unk_id,
  19 + float blank_penalty,
  20 + float temperature_scale)
  21 + : model_(model),
  22 + unk_id_(unk_id),
  23 + blank_penalty_(blank_penalty),
  24 + temperature_scale_(temperature_scale) {}
20 25
21 OnlineTransducerDecoderResult GetEmptyResult() const override; 26 OnlineTransducerDecoderResult GetEmptyResult() const override;
22 27
@@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { @@ -29,6 +34,7 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
29 OnlineTransducerModel *model_; // Not owned 34 OnlineTransducerModel *model_; // Not owned
30 int32_t unk_id_; 35 int32_t unk_id_;
31 float blank_penalty_; 36 float blank_penalty_;
  37 + float temperature_scale_;
32 }; 38 };
33 39
34 } // namespace sherpa_onnx 40 } // namespace sherpa_onnx
@@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -129,6 +129,22 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
129 model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out)); 129 model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
130 130
131 float *p_logit = logit.GetTensorMutableData<float>(); 131 float *p_logit = logit.GetTensorMutableData<float>();
  132 +
  133 + // copy raw logits, apply temperature-scaling (for confidences)
  134 + // Note: temperature scaling is used only for the confidences,
  135 + // the decoding algorithm uses the original logits
  136 + int32_t p_logit_items = vocab_size * num_hyps;
  137 + std::vector<float> logit_with_temperature(p_logit_items);
  138 + {
  139 + std::copy(p_logit,
  140 + p_logit + p_logit_items,
  141 + logit_with_temperature.begin());
  142 + for (float& elem : logit_with_temperature) {
  143 + elem /= temperature_scale_;
  144 + }
  145 + LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps);
  146 + }
  147 +
132 if (blank_penalty_ > 0.0) { 148 if (blank_penalty_ > 0.0) {
133 // assuming blank id is 0 149 // assuming blank id is 0
134 SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); 150 SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_);
@@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -188,10 +204,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
188 // score of the transducer 204 // score of the transducer
189 // export the per-token log scores 205 // export the per-token log scores
190 if (new_token != 0 && new_token != unk_id_) { 206 if (new_token != 0 && new_token != unk_id_) {
191 - const Hypothesis &prev_i = prev[hyp_index];  
192 - // subtract 'prev[i]' path scores, which were added before  
193 - // getting topk tokens  
194 - float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob; 207 + float y_prob = logit_with_temperature[start * vocab_size + k];
195 new_hyp.ys_probs.push_back(y_prob); 208 new_hyp.ys_probs.push_back(y_prob);
196 209
197 if (lm_) { // export only when LM is used 210 if (lm_) { // export only when LM is used
@@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -213,7 +226,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
213 cur.push_back(std::move(hyps)); 226 cur.push_back(std::move(hyps));
214 p_logprob += (end - start) * vocab_size; 227 p_logprob += (end - start) * vocab_size;
215 } // for (int32_t b = 0; b != batch_size; ++b) 228 } // for (int32_t b = 0; b != batch_size; ++b)
216 - } 229 + } // for (int32_t t = 0; t != num_frames; ++t)
217 230
218 for (int32_t b = 0; b != batch_size; ++b) { 231 for (int32_t b = 0; b != batch_size; ++b) {
219 auto &hyps = cur[b]; 232 auto &hyps = cur[b];
@@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -22,13 +22,15 @@ class OnlineTransducerModifiedBeamSearchDecoder
22 OnlineLM *lm, 22 OnlineLM *lm,
23 int32_t max_active_paths, 23 int32_t max_active_paths,
24 float lm_scale, int32_t unk_id, 24 float lm_scale, int32_t unk_id,
25 - float blank_penalty) 25 + float blank_penalty,
  26 + float temperature_scale)
26 : model_(model), 27 : model_(model),
27 lm_(lm), 28 lm_(lm),
28 max_active_paths_(max_active_paths), 29 max_active_paths_(max_active_paths),
29 lm_scale_(lm_scale), 30 lm_scale_(lm_scale),
30 unk_id_(unk_id), 31 unk_id_(unk_id),
31 - blank_penalty_(blank_penalty) {} 32 + blank_penalty_(blank_penalty),
  33 + temperature_scale_(temperature_scale) {}
32 34
33 OnlineTransducerDecoderResult GetEmptyResult() const override; 35 OnlineTransducerDecoderResult GetEmptyResult() const override;
34 36
@@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder @@ -50,6 +52,7 @@ class OnlineTransducerModifiedBeamSearchDecoder
50 float lm_scale_; // used only when lm_ is not nullptr 52 float lm_scale_; // used only when lm_ is not nullptr
51 int32_t unk_id_; 53 int32_t unk_id_;
52 float blank_penalty_; 54 float blank_penalty_;
  55 + float temperature_scale_;
53 }; 56 };
54 57
55 } // namespace sherpa_onnx 58 } // namespace sherpa_onnx
@@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -50,17 +50,30 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
50 using PyClass = OnlineRecognizerConfig; 50 using PyClass = OnlineRecognizerConfig;
51 py::class_<PyClass>(*m, "OnlineRecognizerConfig") 51 py::class_<PyClass>(*m, "OnlineRecognizerConfig")
52 .def( 52 .def(
53 - py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,  
54 - const OnlineLMConfig &, const EndpointConfig &,  
55 - const OnlineCtcFstDecoderConfig &, bool, const std::string &,  
56 - int32_t, const std::string &, float, float>(),  
57 - py::arg("feat_config"), py::arg("model_config"), 53 + py::init<const FeatureExtractorConfig &,
  54 + const OnlineModelConfig &,
  55 + const OnlineLMConfig &,
  56 + const EndpointConfig &,
  57 + const OnlineCtcFstDecoderConfig &,
  58 + bool,
  59 + const std::string &,
  60 + int32_t,
  61 + const std::string &,
  62 + float,
  63 + float,
  64 + float>(),
  65 + py::arg("feat_config"),
  66 + py::arg("model_config"),
58 py::arg("lm_config") = OnlineLMConfig(), 67 py::arg("lm_config") = OnlineLMConfig(),
59 py::arg("endpoint_config") = EndpointConfig(), 68 py::arg("endpoint_config") = EndpointConfig(),
60 py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), 69 py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
61 - py::arg("enable_endpoint"), py::arg("decoding_method"),  
62 - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",  
63 - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0) 70 + py::arg("enable_endpoint"),
  71 + py::arg("decoding_method"),
  72 + py::arg("max_active_paths") = 4,
  73 + py::arg("hotwords_file") = "",
  74 + py::arg("hotwords_score") = 0,
  75 + py::arg("blank_penalty") = 0.0,
  76 + py::arg("temperature_scale") = 2.0)
64 .def_readwrite("feat_config", &PyClass::feat_config) 77 .def_readwrite("feat_config", &PyClass::feat_config)
65 .def_readwrite("model_config", &PyClass::model_config) 78 .def_readwrite("model_config", &PyClass::model_config)
66 .def_readwrite("lm_config", &PyClass::lm_config) 79 .def_readwrite("lm_config", &PyClass::lm_config)
@@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { @@ -72,6 +85,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
72 .def_readwrite("hotwords_file", &PyClass::hotwords_file) 85 .def_readwrite("hotwords_file", &PyClass::hotwords_file)
73 .def_readwrite("hotwords_score", &PyClass::hotwords_score) 86 .def_readwrite("hotwords_score", &PyClass::hotwords_score)
74 .def_readwrite("blank_penalty", &PyClass::blank_penalty) 87 .def_readwrite("blank_penalty", &PyClass::blank_penalty)
  88 + .def_readwrite("temperature_scale", &PyClass::temperature_scale)
75 .def("__str__", &PyClass::ToString); 89 .def("__str__", &PyClass::ToString);
76 } 90 }
77 91
@@ -58,6 +58,7 @@ class OnlineRecognizer(object): @@ -58,6 +58,7 @@ class OnlineRecognizer(object):
58 model_type: str = "", 58 model_type: str = "",
59 lm: str = "", 59 lm: str = "",
60 lm_scale: float = 0.1, 60 lm_scale: float = 0.1,
  61 + temperature_scale: float = 2.0,
61 ): 62 ):
62 """ 63 """
63 Please refer to 64 Please refer to
@@ -123,6 +124,10 @@ class OnlineRecognizer(object): @@ -123,6 +124,10 @@ class OnlineRecognizer(object):
123 hotwords_score: 124 hotwords_score:
124 The hotword score of each token for biasing word/phrase. Used only if 125 The hotword score of each token for biasing word/phrase. Used only if
125 hotwords_file is given with modified_beam_search as decoding method. 126 hotwords_file is given with modified_beam_search as decoding method.
  127 + temperature_scale:
  128 + Temperature scaling for output symbol confidence estiamation.
  129 + It affects only confidence values, the decoding uses the original
  130 + logits without temperature.
126 provider: 131 provider:
127 onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 132 onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
128 model_type: 133 model_type:
@@ -193,6 +198,7 @@ class OnlineRecognizer(object): @@ -193,6 +198,7 @@ class OnlineRecognizer(object):
193 hotwords_score=hotwords_score, 198 hotwords_score=hotwords_score,
194 hotwords_file=hotwords_file, 199 hotwords_file=hotwords_file,
195 blank_penalty=blank_penalty, 200 blank_penalty=blank_penalty,
  201 + temperature_scale=temperature_scale,
196 ) 202 )
197 203
198 self.recognizer = _Recognizer(recognizer_config) 204 self.recognizer = _Recognizer(recognizer_config)