add shallow fusion (#147)

Author: PF Luo (committed via GitHub)
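Shallow fusion here means that, during decoding, the beam search ranks each hypothesis by the transducer's own log probability plus a scaled neural-LM log probability, instead of rescoring whole chunks of tokens after the fact. The sketch below is illustrative only (the struct and function names are invented, not taken from this PR); in the diff itself the LM scale is already folded into lm_log_prob inside ComputeLMScore(), so the decoder simply adds the two fields.

// Illustrative sketch, not code from this PR.
struct ScoreParts {
  float log_prob;     // accumulated transducer (acoustic) score
  float lm_log_prob;  // neural-LM contribution, already multiplied by the LM scale
};

// Ranking score of a hypothesis under shallow fusion.
inline float RankingScore(const ScoreParts &s) {
  return s.log_prob + s.lm_log_prob;
}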

@@ -34,13 +34,11 @@ struct Hypothesis {
   // LM log prob if any.
   double lm_log_prob = 0;

-  int32_t cur_scored_pos = 0;  // cur scored tokens by RNN LM
+  // the nn lm score for next token given the current ys
+  CopyableOrtValue nn_lm_scores;
+  // the nn lm states
   std::vector<CopyableOrtValue> nn_lm_states;

-  // TODO(fangjun): Make it configurable
-  // the minimum of tokens in a chunk for streaming RNN LM
-  int32_t lm_rescore_min_chunk = 2;  // a const
-
   int32_t num_trailing_blanks = 0;

   Hypothesis() = default;
@@ -13,80 +13,8 @@

 namespace sherpa_onnx {

-static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
-  std::vector<CopyableOrtValue> ans;
-  ans.reserve(values.size());
-
-  for (auto &v : values) {
-    ans.emplace_back(std::move(v));
-  }
-
-  return ans;
-}
-
-static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
-  std::vector<Ort::Value> ans;
-  ans.reserve(values.size());
-
-  for (auto &v : values) {
-    ans.emplace_back(std::move(v.value));
-  }
-
-  return ans;
-}
-
 std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
   return std::make_unique<OnlineRnnLM>(config);
 }

-void OnlineLM::ComputeLMScore(float scale, int32_t context_size,
-                              std::vector<Hypotheses> *hyps) {
-  Ort::AllocatorWithDefaultOptions allocator;
-
-  for (auto &hyp : *hyps) {
-    for (auto &h_m : hyp) {
-      auto &h = h_m.second;
-      auto &ys = h.ys;
-      const int32_t token_num_in_chunk =
-          ys.size() - context_size - h.cur_scored_pos - 1;
-
-      if (token_num_in_chunk < 1) {
-        continue;
-      }
-
-      if (h.nn_lm_states.empty()) {
-        h.nn_lm_states = Convert(GetInitStates());
-      }
-
-      if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
-        std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
-        // shape of x and y are same
-        Ort::Value x = Ort::Value::CreateTensor<int64_t>(
-            allocator, x_shape.data(), x_shape.size());
-        Ort::Value y = Ort::Value::CreateTensor<int64_t>(
-            allocator, x_shape.data(), x_shape.size());
-        int64_t *p_x = x.GetTensorMutableData<int64_t>();
-        int64_t *p_y = y.GetTensorMutableData<int64_t>();
-        std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
-                  p_x);
-        std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(),
-                  p_y);
-
-        // streaming forward by NN LM
-        auto out = Rescore(std::move(x), std::move(y),
-                           Convert(std::move(h.nn_lm_states)));
-
-        // update NN LM score in hyp
-        const float *p_nll = out.first.GetTensorData<float>();
-        h.lm_log_prob = -scale * (*p_nll);
-
-        // update NN LM states in hyp
-        h.nn_lm_states = Convert(std::move(out.second));
-
-        h.cur_scored_pos += token_num_in_chunk;
-      }
-    }
-  }
-}
-
 }  // namespace sherpa_onnx
@@ -21,29 +21,27 @@ class OnlineLM {

   static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);

-  virtual std::vector<Ort::Value> GetInitStates() = 0;
+  virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0;

-  /** Rescore a batch of sentences.
+  /** ScoreToken a batch of sentences.
    *
-   * @param x A 2-D tensor of shape (N, L) with data type int64.
-   * @param y A 2-D tensor of shape (N, L) with data type int64.
+   * @param x A 2-D tensor of shape (N, 1) with data type int64.
    * @param states It contains the states for the LM model
    * @return Return a pair containingo
-   *           - negative loglike
+   *           - log_prob of NN LM
    *           - updated states
    *
-   * Caution: It returns negative log likelihood (nll), not log likelihood
    */
-  virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
-      Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
-
-  // This function updates hyp.lm_lob_prob of hyps.
-  //
-  // @param scale LM score
-  // @param context_size Context size of the transducer decoder model
-  // @param hyps It is changed in-place.
-  void ComputeLMScore(float scale, int32_t context_size,
-                      std::vector<Hypotheses> *hyps);
+  virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
+      Ort::Value x, std::vector<Ort::Value> states) = 0;
+
+  /** This function updates lm_lob_prob and nn_lm_scores of hyp
+   *
+   * @param scale LM score
+   * @param hyps It is changed in-place.
+   *
+   */
+  virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0;
 };

 }  // namespace sherpa_onnx
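With this change the base class exposes a per-token interface: GetInitStates() returns the LM output for the start-of-sentence token together with the initial LM states, ScoreToken() advances the LM by exactly one token, and ComputeLMScore() applies both to a single Hypothesis. A hypothetical caller might look like the sketch below (the free function and its parameters are illustrative; in the decoder the call goes through the members lm_ and lm_scale_):

// Hypothetical usage, for illustration only.
void OnNewToken(OnlineLM *lm, float lm_scale, Hypothesis *hyp) {
  // For the RNN implementation this scores hyp->ys.back() using the
  // distribution cached in hyp->nn_lm_scores, then runs the LM one step to
  // refresh hyp->nn_lm_scores and hyp->nn_lm_states.
  lm->ComputeLMScore(lm_scale, hyp);
}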
@@ -26,10 +26,33 @@ class OnlineRnnLM::Impl {
     Init(config);
   }

-  std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
-      Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
-    std::array<Ort::Value, 4> inputs = {
-        std::move(x), std::move(y), std::move(states[0]), std::move(states[1])};
+  void ComputeLMScore(float scale, Hypothesis *hyp) {
+    if (hyp->nn_lm_states.empty()) {
+      auto init_states = GetInitStates();
+      hyp->nn_lm_scores.value = std::move(init_states.first);
+      hyp->nn_lm_states = Convert(std::move(init_states.second));
+    }
+
+    // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob
+    const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>();
+    hyp->lm_log_prob = nn_lm_scores[hyp->ys.back()] * scale;
+
+    // get lm scores for next tokens given the hyp->ys[:] and save to
+    // nn_lm_scores
+    std::array<int64_t, 2> x_shape{1, 1};
+    Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
+                                                     x_shape.size());
+    *x.GetTensorMutableData<int64_t>() = hyp->ys.back();
+    auto lm_out =
+        ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
+    hyp->nn_lm_scores.value = std::move(lm_out.first);
+    hyp->nn_lm_states = Convert(std::move(lm_out.second));
+  }
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
+      Ort::Value x, std::vector<Ort::Value> states) {
+    std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]),
+                                        std::move(states[1])};

     auto out =
         sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
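The new ComputeLMScore() is incremental and per-hypothesis. A rough statement of the invariant it maintains, written as comments rather than code from the PR:

// After each call, for a hypothesis whose token history is ys:
//   hyp->nn_lm_scores holds log P(next token | ys), produced by the last ScoreToken()
//   hyp->nn_lm_states holds the LM state after consuming ys
// When a token t is appended, the previously cached distribution (conditioned
// on ys without t) is indexed once:
//   hyp->lm_log_prob = scale * nn_lm_scores[t]
// and a single ScoreToken(t, states) call re-establishes the invariant.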
@@ -43,15 +66,13 @@ class OnlineRnnLM::Impl {
     return {std::move(out[0]), std::move(next_states)};
   }

-  std::vector<Ort::Value> GetInitStates() const {
+  std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const {
     std::vector<Ort::Value> ans;
     ans.reserve(init_states_.size());
-
     for (const auto &s : init_states_) {
       ans.emplace_back(Clone(allocator_, &s));
     }
-
-    return ans;
+    return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)};
   }

  private:
@@ -86,19 +107,16 @@ class OnlineRnnLM::Impl {
     Fill<float>(&h, 0);
     Fill<float>(&c, 0);
     std::array<int64_t, 2> x_shape{1, 1};
-    // shape of x and y are same
     Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
                                                      x_shape.size());
-    Ort::Value y = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
-                                                     x_shape.size());
     *x.GetTensorMutableData<int64_t>() = sos_id_;
-    *y.GetTensorMutableData<int64_t>() = sos_id_;

     std::vector<Ort::Value> states;
     states.push_back(std::move(h));
     states.push_back(std::move(c));
-    auto pair = Rescore(std::move(x), std::move(y), std::move(states));
+    auto pair = ScoreToken(std::move(x), std::move(states));

+    init_scores_.value = std::move(pair.first);
     init_states_ = std::move(pair.second);
   }

@@ -116,6 +134,7 @@ class OnlineRnnLM::Impl {
   std::vector<std::string> output_names_;
   std::vector<const char *> output_names_ptr_;

+  CopyableOrtValue init_scores_;
   std::vector<Ort::Value> init_states_;

   int32_t rnn_num_layers_ = 2;
@@ -128,13 +147,17 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)

 OnlineRnnLM::~OnlineRnnLM() = default;

-std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
+std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() {
   return impl_->GetInitStates();
 }

-std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore(
-    Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
-  return impl_->Rescore(std::move(x), std::move(y), std::move(states));
+std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken(
+    Ort::Value x, std::vector<Ort::Value> states) {
+  return impl_->ScoreToken(std::move(x), std::move(states));
+}
+
+void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) {
+  return impl_->ComputeLMScore(scale, hyp);
 }

 }  // namespace sherpa_onnx
@@ -22,21 +22,27 @@ class OnlineRnnLM : public OnlineLM {

   explicit OnlineRnnLM(const OnlineLMConfig &config);

-  std::vector<Ort::Value> GetInitStates() override;
+  std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override;

-  /** Rescore a batch of sentences.
+  /** ScoreToken a batch of sentences.
    *
    * @param x A 2-D tensor of shape (N, L) with data type int64.
-   * @param y A 2-D tensor of shape (N, L) with data type int64.
    * @param states It contains the states for the LM model
    * @return Return a pair containingo
-   *           - negative loglike
+   *           - log_prob of NN LM
    *           - updated states
    *
-   * Caution: It returns negative log likelihood (nll), not log likelihood
    */
-  std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
-      Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override;
+  std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken(
+      Ort::Value x, std::vector<Ort::Value> states) override;
+
+  /** This function updates lm_lob_prob and nn_lm_scores of hyp
+   *
+   * @param scale LM score
+   * @param hyps It is changed in-place.
+   *
+   */
+  void ComputeLMScore(float scale, Hypothesis *hyp) override;

  private:
   class Impl;
@@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(

     // add log_prob of each hypothesis to p_logprob before taking top_k
     for (int32_t i = 0; i != num_hyps; ++i) {
-      float log_prob = prev[i].log_prob;
+      float log_prob = prev[i].log_prob + prev[i].lm_log_prob;
       for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
         *p_logprob += log_prob;
       }
@@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
         int32_t new_token = k % vocab_size;

         Hypothesis new_hyp = prev[hyp_index];
+        const float prev_lm_log_prob = new_hyp.lm_log_prob;
         if (new_token != 0) {
           new_hyp.ys.push_back(new_token);
           new_hyp.timestamps.push_back(t + frame_offset);
           new_hyp.num_trailing_blanks = 0;
+          if (lm_) {
+            lm_->ComputeLMScore(lm_scale_, &new_hyp);
+          }
         } else {
           ++new_hyp.num_trailing_blanks;
         }
-        new_hyp.log_prob = p_logprob[k];
+        new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob;
         hyps.Add(std::move(new_hyp));
       }  // for (auto k : topk)
       cur.push_back(std::move(hyps));
@@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
     }  // for (int32_t b = 0; b != batch_size; ++b)
   }

-  if (lm_) {
-    lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
-  }
-
   for (int32_t b = 0; b != batch_size; ++b) {
     auto &hyps = cur[b];
     auto best_hyp = hyps.GetMostProbable(true);
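How the decoder hunks fit together: before the top-k selection, every candidate expansion of a hypothesis is offset by log_prob + lm_log_prob, so the LM influences which expansions survive; after selection, the previous lm_log_prob is subtracted back out, leaving only the transducer part in new_hyp.log_prob, and ComputeLMScore() then stores a fresh LM term for the newly appended token. A worked example with made-up numbers:

// Made-up numbers, for illustration only.
//   prev.log_prob    = -5.0   // accumulated transducer score
//   prev.lm_log_prob = -0.4   // scaled LM score cached from the previous step
//   token log prob   = -1.2   // transducer score of the candidate token
// value used for top-k selection:  -5.0 + (-0.4) + (-1.2) = -6.6
// stored new_hyp.log_prob:         -6.6 - (-0.4)          = -6.2
// lm_->ComputeLMScore(lm_scale_, &new_hyp) then sets new_hyp.lm_log_prob
// for the token that was just appended.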
@@ -245,4 +245,26 @@ CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) {
   return *this;
 }

+std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
+  std::vector<CopyableOrtValue> ans;
+  ans.reserve(values.size());
+
+  for (auto &v : values) {
+    ans.emplace_back(std::move(v));
+  }
+
+  return ans;
+}
+
+std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
+  std::vector<Ort::Value> ans;
+  ans.reserve(values.size());
+
+  for (auto &v : values) {
+    ans.emplace_back(std::move(v.value));
+  }
+
+  return ans;
+}
+
 }  // namespace sherpa_onnx
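The Convert() helpers that used to be private to the LM implementation are now shared utilities, since both the RNN LM and Hypothesis need to move Ort::Value objects in and out of the copyable wrapper. A minimal round trip, assuming only the declarations added in the header change that follows (the function names Store and Restore are illustrative):

// Illustrative round trip, not code from the PR.
std::vector<CopyableOrtValue> Store(std::vector<Ort::Value> states) {
  return Convert(std::move(states));  // wrap so a Hypothesis can copy them
}

std::vector<Ort::Value> Restore(std::vector<CopyableOrtValue> states) {
  return Convert(std::move(states));  // unwrap before feeding back to the LM
}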
@@ -97,8 +97,8 @@ struct CopyableOrtValue {

   CopyableOrtValue() = default;

-  /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT
-      : value(std::move(v)) {}
+  /*explicit*/ CopyableOrtValue(Ort::Value v)  // NOLINT
+      : value(std::move(v)) {}

   CopyableOrtValue(const CopyableOrtValue &other);

@@ -109,6 +109,10 @@ struct CopyableOrtValue {
   CopyableOrtValue &operator=(CopyableOrtValue &&other);
 };

+std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values);
+
+std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values);
+
 }  // namespace sherpa_onnx

 #endif  // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
@@ -94,7 +94,7 @@ for a list of pre-trained models to download.
 auto s = recognizer.CreateStream();
 s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

-std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate));
+std::vector<float> tail_paddings(static_cast<int>(0.5 * sampling_rate));
 // Note: We can call AcceptWaveform() multiple times.
 s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
