正在显示
9 个修改的文件
包含
105 行增加
和
126 行删除
| @@ -34,13 +34,11 @@ struct Hypothesis { | @@ -34,13 +34,11 @@ struct Hypothesis { | ||
| 34 | // LM log prob if any. | 34 | // LM log prob if any. |
| 35 | double lm_log_prob = 0; | 35 | double lm_log_prob = 0; |
| 36 | 36 | ||
| 37 | - int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM | 37 | + // the nn lm score for next token given the current ys |
| 38 | + CopyableOrtValue nn_lm_scores; | ||
| 39 | + // the nn lm states | ||
| 38 | std::vector<CopyableOrtValue> nn_lm_states; | 40 | std::vector<CopyableOrtValue> nn_lm_states; |
| 39 | 41 | ||
| 40 | - // TODO(fangjun): Make it configurable | ||
| 41 | - // the minimum of tokens in a chunk for streaming RNN LM | ||
| 42 | - int32_t lm_rescore_min_chunk = 2; // a const | ||
| 43 | - | ||
| 44 | int32_t num_trailing_blanks = 0; | 42 | int32_t num_trailing_blanks = 0; |
| 45 | 43 | ||
| 46 | Hypothesis() = default; | 44 | Hypothesis() = default; |
| @@ -13,80 +13,8 @@ | @@ -13,80 +13,8 @@ | ||
| 13 | 13 | ||
| 14 | namespace sherpa_onnx { | 14 | namespace sherpa_onnx { |
| 15 | 15 | ||
| 16 | -static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) { | ||
| 17 | - std::vector<CopyableOrtValue> ans; | ||
| 18 | - ans.reserve(values.size()); | ||
| 19 | - | ||
| 20 | - for (auto &v : values) { | ||
| 21 | - ans.emplace_back(std::move(v)); | ||
| 22 | - } | ||
| 23 | - | ||
| 24 | - return ans; | ||
| 25 | -} | ||
| 26 | - | ||
| 27 | -static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) { | ||
| 28 | - std::vector<Ort::Value> ans; | ||
| 29 | - ans.reserve(values.size()); | ||
| 30 | - | ||
| 31 | - for (auto &v : values) { | ||
| 32 | - ans.emplace_back(std::move(v.value)); | ||
| 33 | - } | ||
| 34 | - | ||
| 35 | - return ans; | ||
| 36 | -} | ||
| 37 | - | ||
| 38 | std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) { | 16 | std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) { |
| 39 | return std::make_unique<OnlineRnnLM>(config); | 17 | return std::make_unique<OnlineRnnLM>(config); |
| 40 | } | 18 | } |
| 41 | 19 | ||
| 42 | -void OnlineLM::ComputeLMScore(float scale, int32_t context_size, | ||
| 43 | - std::vector<Hypotheses> *hyps) { | ||
| 44 | - Ort::AllocatorWithDefaultOptions allocator; | ||
| 45 | - | ||
| 46 | - for (auto &hyp : *hyps) { | ||
| 47 | - for (auto &h_m : hyp) { | ||
| 48 | - auto &h = h_m.second; | ||
| 49 | - auto &ys = h.ys; | ||
| 50 | - const int32_t token_num_in_chunk = | ||
| 51 | - ys.size() - context_size - h.cur_scored_pos - 1; | ||
| 52 | - | ||
| 53 | - if (token_num_in_chunk < 1) { | ||
| 54 | - continue; | ||
| 55 | - } | ||
| 56 | - | ||
| 57 | - if (h.nn_lm_states.empty()) { | ||
| 58 | - h.nn_lm_states = Convert(GetInitStates()); | ||
| 59 | - } | ||
| 60 | - | ||
| 61 | - if (token_num_in_chunk >= h.lm_rescore_min_chunk) { | ||
| 62 | - std::array<int64_t, 2> x_shape{1, token_num_in_chunk}; | ||
| 63 | - // shape of x and y are same | ||
| 64 | - Ort::Value x = Ort::Value::CreateTensor<int64_t>( | ||
| 65 | - allocator, x_shape.data(), x_shape.size()); | ||
| 66 | - Ort::Value y = Ort::Value::CreateTensor<int64_t>( | ||
| 67 | - allocator, x_shape.data(), x_shape.size()); | ||
| 68 | - int64_t *p_x = x.GetTensorMutableData<int64_t>(); | ||
| 69 | - int64_t *p_y = y.GetTensorMutableData<int64_t>(); | ||
| 70 | - std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, | ||
| 71 | - p_x); | ||
| 72 | - std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(), | ||
| 73 | - p_y); | ||
| 74 | - | ||
| 75 | - // streaming forward by NN LM | ||
| 76 | - auto out = Rescore(std::move(x), std::move(y), | ||
| 77 | - Convert(std::move(h.nn_lm_states))); | ||
| 78 | - | ||
| 79 | - // update NN LM score in hyp | ||
| 80 | - const float *p_nll = out.first.GetTensorData<float>(); | ||
| 81 | - h.lm_log_prob = -scale * (*p_nll); | ||
| 82 | - | ||
| 83 | - // update NN LM states in hyp | ||
| 84 | - h.nn_lm_states = Convert(std::move(out.second)); | ||
| 85 | - | ||
| 86 | - h.cur_scored_pos += token_num_in_chunk; | ||
| 87 | - } | ||
| 88 | - } | ||
| 89 | - } | ||
| 90 | -} | ||
| 91 | - | ||
| 92 | } // namespace sherpa_onnx | 20 | } // namespace sherpa_onnx |
| @@ -21,29 +21,27 @@ class OnlineLM { | @@ -21,29 +21,27 @@ class OnlineLM { | ||
| 21 | 21 | ||
| 22 | static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); | 22 | static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config); |
| 23 | 23 | ||
| 24 | - virtual std::vector<Ort::Value> GetInitStates() = 0; | 24 | + virtual std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() = 0; |
| 25 | 25 | ||
| 26 | - /** Rescore a batch of sentences. | 26 | + /** ScoreToken a batch of sentences. |
| 27 | * | 27 | * |
| 28 | - * @param x A 2-D tensor of shape (N, L) with data type int64. | ||
| 29 | - * @param y A 2-D tensor of shape (N, L) with data type int64. | 28 | + * @param x A 2-D tensor of shape (N, 1) with data type int64. |
| 30 | * @param states It contains the states for the LM model | 29 | * @param states It contains the states for the LM model |
| 31 | * @return Return a pair containingo | 30 | * @return Return a pair containingo |
| 32 | - * - negative loglike | 31 | + * - log_prob of NN LM |
| 33 | * - updated states | 32 | * - updated states |
| 34 | * | 33 | * |
| 35 | - * Caution: It returns negative log likelihood (nll), not log likelihood | ||
| 36 | */ | 34 | */ |
| 37 | - virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( | ||
| 38 | - Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0; | ||
| 39 | - | ||
| 40 | - // This function updates hyp.lm_lob_prob of hyps. | ||
| 41 | - // | ||
| 42 | - // @param scale LM score | ||
| 43 | - // @param context_size Context size of the transducer decoder model | ||
| 44 | - // @param hyps It is changed in-place. | ||
| 45 | - void ComputeLMScore(float scale, int32_t context_size, | ||
| 46 | - std::vector<Hypotheses> *hyps); | 35 | + virtual std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( |
| 36 | + Ort::Value x, std::vector<Ort::Value> states) = 0; | ||
| 37 | + | ||
| 38 | + /** This function updates lm_lob_prob and nn_lm_scores of hyp | ||
| 39 | + * | ||
| 40 | + * @param scale LM score | ||
| 41 | + * @param hyps It is changed in-place. | ||
| 42 | + * | ||
| 43 | + */ | ||
| 44 | + virtual void ComputeLMScore(float scale, Hypothesis *hyp) = 0; | ||
| 47 | }; | 45 | }; |
| 48 | 46 | ||
| 49 | } // namespace sherpa_onnx | 47 | } // namespace sherpa_onnx |
| @@ -26,10 +26,33 @@ class OnlineRnnLM::Impl { | @@ -26,10 +26,33 @@ class OnlineRnnLM::Impl { | ||
| 26 | Init(config); | 26 | Init(config); |
| 27 | } | 27 | } |
| 28 | 28 | ||
| 29 | - std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( | ||
| 30 | - Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) { | ||
| 31 | - std::array<Ort::Value, 4> inputs = { | ||
| 32 | - std::move(x), std::move(y), std::move(states[0]), std::move(states[1])}; | 29 | + void ComputeLMScore(float scale, Hypothesis *hyp) { |
| 30 | + if (hyp->nn_lm_states.empty()) { | ||
| 31 | + auto init_states = GetInitStates(); | ||
| 32 | + hyp->nn_lm_scores.value = std::move(init_states.first); | ||
| 33 | + hyp->nn_lm_states = Convert(std::move(init_states.second)); | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob | ||
| 37 | + const float *nn_lm_scores = hyp->nn_lm_scores.value.GetTensorData<float>(); | ||
| 38 | + hyp->lm_log_prob = nn_lm_scores[hyp->ys.back()] * scale; | ||
| 39 | + | ||
| 40 | + // get lm scores for next tokens given the hyp->ys[:] and save to | ||
| 41 | + // nn_lm_scores | ||
| 42 | + std::array<int64_t, 2> x_shape{1, 1}; | ||
| 43 | + Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), | ||
| 44 | + x_shape.size()); | ||
| 45 | + *x.GetTensorMutableData<int64_t>() = hyp->ys.back(); | ||
| 46 | + auto lm_out = | ||
| 47 | + ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); | ||
| 48 | + hyp->nn_lm_scores.value = std::move(lm_out.first); | ||
| 49 | + hyp->nn_lm_states = Convert(std::move(lm_out.second)); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( | ||
| 53 | + Ort::Value x, std::vector<Ort::Value> states) { | ||
| 54 | + std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states[0]), | ||
| 55 | + std::move(states[1])}; | ||
| 33 | 56 | ||
| 34 | auto out = | 57 | auto out = |
| 35 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), | 58 | sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), |
| @@ -43,15 +66,13 @@ class OnlineRnnLM::Impl { | @@ -43,15 +66,13 @@ class OnlineRnnLM::Impl { | ||
| 43 | return {std::move(out[0]), std::move(next_states)}; | 66 | return {std::move(out[0]), std::move(next_states)}; |
| 44 | } | 67 | } |
| 45 | 68 | ||
| 46 | - std::vector<Ort::Value> GetInitStates() const { | 69 | + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const { |
| 47 | std::vector<Ort::Value> ans; | 70 | std::vector<Ort::Value> ans; |
| 48 | ans.reserve(init_states_.size()); | 71 | ans.reserve(init_states_.size()); |
| 49 | - | ||
| 50 | for (const auto &s : init_states_) { | 72 | for (const auto &s : init_states_) { |
| 51 | ans.emplace_back(Clone(allocator_, &s)); | 73 | ans.emplace_back(Clone(allocator_, &s)); |
| 52 | } | 74 | } |
| 53 | - | ||
| 54 | - return ans; | 75 | + return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; |
| 55 | } | 76 | } |
| 56 | 77 | ||
| 57 | private: | 78 | private: |
| @@ -86,19 +107,16 @@ class OnlineRnnLM::Impl { | @@ -86,19 +107,16 @@ class OnlineRnnLM::Impl { | ||
| 86 | Fill<float>(&h, 0); | 107 | Fill<float>(&h, 0); |
| 87 | Fill<float>(&c, 0); | 108 | Fill<float>(&c, 0); |
| 88 | std::array<int64_t, 2> x_shape{1, 1}; | 109 | std::array<int64_t, 2> x_shape{1, 1}; |
| 89 | - // shape of x and y are same | ||
| 90 | Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), | 110 | Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), |
| 91 | x_shape.size()); | 111 | x_shape.size()); |
| 92 | - Ort::Value y = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(), | ||
| 93 | - x_shape.size()); | ||
| 94 | *x.GetTensorMutableData<int64_t>() = sos_id_; | 112 | *x.GetTensorMutableData<int64_t>() = sos_id_; |
| 95 | - *y.GetTensorMutableData<int64_t>() = sos_id_; | ||
| 96 | 113 | ||
| 97 | std::vector<Ort::Value> states; | 114 | std::vector<Ort::Value> states; |
| 98 | states.push_back(std::move(h)); | 115 | states.push_back(std::move(h)); |
| 99 | states.push_back(std::move(c)); | 116 | states.push_back(std::move(c)); |
| 100 | - auto pair = Rescore(std::move(x), std::move(y), std::move(states)); | 117 | + auto pair = ScoreToken(std::move(x), std::move(states)); |
| 101 | 118 | ||
| 119 | + init_scores_.value = std::move(pair.first); | ||
| 102 | init_states_ = std::move(pair.second); | 120 | init_states_ = std::move(pair.second); |
| 103 | } | 121 | } |
| 104 | 122 | ||
| @@ -116,6 +134,7 @@ class OnlineRnnLM::Impl { | @@ -116,6 +134,7 @@ class OnlineRnnLM::Impl { | ||
| 116 | std::vector<std::string> output_names_; | 134 | std::vector<std::string> output_names_; |
| 117 | std::vector<const char *> output_names_ptr_; | 135 | std::vector<const char *> output_names_ptr_; |
| 118 | 136 | ||
| 137 | + CopyableOrtValue init_scores_; | ||
| 119 | std::vector<Ort::Value> init_states_; | 138 | std::vector<Ort::Value> init_states_; |
| 120 | 139 | ||
| 121 | int32_t rnn_num_layers_ = 2; | 140 | int32_t rnn_num_layers_ = 2; |
| @@ -128,13 +147,17 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | @@ -128,13 +147,17 @@ OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) | ||
| 128 | 147 | ||
| 129 | OnlineRnnLM::~OnlineRnnLM() = default; | 148 | OnlineRnnLM::~OnlineRnnLM() = default; |
| 130 | 149 | ||
| 131 | -std::vector<Ort::Value> OnlineRnnLM::GetInitStates() { | 150 | +std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::GetInitStates() { |
| 132 | return impl_->GetInitStates(); | 151 | return impl_->GetInitStates(); |
| 133 | } | 152 | } |
| 134 | 153 | ||
| 135 | -std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore( | ||
| 136 | - Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) { | ||
| 137 | - return impl_->Rescore(std::move(x), std::move(y), std::move(states)); | 154 | +std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::ScoreToken( |
| 155 | + Ort::Value x, std::vector<Ort::Value> states) { | ||
| 156 | + return impl_->ScoreToken(std::move(x), std::move(states)); | ||
| 157 | +} | ||
| 158 | + | ||
| 159 | +void OnlineRnnLM::ComputeLMScore(float scale, Hypothesis *hyp) { | ||
| 160 | + return impl_->ComputeLMScore(scale, hyp); | ||
| 138 | } | 161 | } |
| 139 | 162 | ||
| 140 | } // namespace sherpa_onnx | 163 | } // namespace sherpa_onnx |
| @@ -22,21 +22,27 @@ class OnlineRnnLM : public OnlineLM { | @@ -22,21 +22,27 @@ class OnlineRnnLM : public OnlineLM { | ||
| 22 | 22 | ||
| 23 | explicit OnlineRnnLM(const OnlineLMConfig &config); | 23 | explicit OnlineRnnLM(const OnlineLMConfig &config); |
| 24 | 24 | ||
| 25 | - std::vector<Ort::Value> GetInitStates() override; | 25 | + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() override; |
| 26 | 26 | ||
| 27 | - /** Rescore a batch of sentences. | 27 | + /** ScoreToken a batch of sentences. |
| 28 | * | 28 | * |
| 29 | * @param x A 2-D tensor of shape (N, L) with data type int64. | 29 | * @param x A 2-D tensor of shape (N, L) with data type int64. |
| 30 | - * @param y A 2-D tensor of shape (N, L) with data type int64. | ||
| 31 | * @param states It contains the states for the LM model | 30 | * @param states It contains the states for the LM model |
| 32 | * @return Return a pair containingo | 31 | * @return Return a pair containingo |
| 33 | - * - negative loglike | 32 | + * - log_prob of NN LM |
| 34 | * - updated states | 33 | * - updated states |
| 35 | * | 34 | * |
| 36 | - * Caution: It returns negative log likelihood (nll), not log likelihood | ||
| 37 | */ | 35 | */ |
| 38 | - std::pair<Ort::Value, std::vector<Ort::Value>> Rescore( | ||
| 39 | - Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override; | 36 | + std::pair<Ort::Value, std::vector<Ort::Value>> ScoreToken( |
| 37 | + Ort::Value x, std::vector<Ort::Value> states) override; | ||
| 38 | + | ||
| 39 | + /** This function updates lm_lob_prob and nn_lm_scores of hyp | ||
| 40 | + * | ||
| 41 | + * @param scale LM score | ||
| 42 | + * @param hyps It is changed in-place. | ||
| 43 | + * | ||
| 44 | + */ | ||
| 45 | + void ComputeLMScore(float scale, Hypothesis *hyp) override; | ||
| 40 | 46 | ||
| 41 | private: | 47 | private: |
| 42 | class Impl; | 48 | class Impl; |
| @@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -121,7 +121,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 121 | 121 | ||
| 122 | // add log_prob of each hypothesis to p_logprob before taking top_k | 122 | // add log_prob of each hypothesis to p_logprob before taking top_k |
| 123 | for (int32_t i = 0; i != num_hyps; ++i) { | 123 | for (int32_t i = 0; i != num_hyps; ++i) { |
| 124 | - float log_prob = prev[i].log_prob; | 124 | + float log_prob = prev[i].log_prob + prev[i].lm_log_prob; |
| 125 | for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { | 125 | for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { |
| 126 | *p_logprob += log_prob; | 126 | *p_logprob += log_prob; |
| 127 | } | 127 | } |
| @@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -141,14 +141,18 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 141 | int32_t new_token = k % vocab_size; | 141 | int32_t new_token = k % vocab_size; |
| 142 | 142 | ||
| 143 | Hypothesis new_hyp = prev[hyp_index]; | 143 | Hypothesis new_hyp = prev[hyp_index]; |
| 144 | + const float prev_lm_log_prob = new_hyp.lm_log_prob; | ||
| 144 | if (new_token != 0) { | 145 | if (new_token != 0) { |
| 145 | new_hyp.ys.push_back(new_token); | 146 | new_hyp.ys.push_back(new_token); |
| 146 | new_hyp.timestamps.push_back(t + frame_offset); | 147 | new_hyp.timestamps.push_back(t + frame_offset); |
| 147 | new_hyp.num_trailing_blanks = 0; | 148 | new_hyp.num_trailing_blanks = 0; |
| 149 | + if (lm_) { | ||
| 150 | + lm_->ComputeLMScore(lm_scale_, &new_hyp); | ||
| 151 | + } | ||
| 148 | } else { | 152 | } else { |
| 149 | ++new_hyp.num_trailing_blanks; | 153 | ++new_hyp.num_trailing_blanks; |
| 150 | } | 154 | } |
| 151 | - new_hyp.log_prob = p_logprob[k]; | 155 | + new_hyp.log_prob = p_logprob[k] - prev_lm_log_prob; |
| 152 | hyps.Add(std::move(new_hyp)); | 156 | hyps.Add(std::move(new_hyp)); |
| 153 | } // for (auto k : topk) | 157 | } // for (auto k : topk) |
| 154 | cur.push_back(std::move(hyps)); | 158 | cur.push_back(std::move(hyps)); |
| @@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -156,10 +160,6 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 156 | } // for (int32_t b = 0; b != batch_size; ++b) | 160 | } // for (int32_t b = 0; b != batch_size; ++b) |
| 157 | } | 161 | } |
| 158 | 162 | ||
| 159 | - if (lm_) { | ||
| 160 | - lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur); | ||
| 161 | - } | ||
| 162 | - | ||
| 163 | for (int32_t b = 0; b != batch_size; ++b) { | 163 | for (int32_t b = 0; b != batch_size; ++b) { |
| 164 | auto &hyps = cur[b]; | 164 | auto &hyps = cur[b]; |
| 165 | auto best_hyp = hyps.GetMostProbable(true); | 165 | auto best_hyp = hyps.GetMostProbable(true); |
| @@ -245,4 +245,26 @@ CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) { | @@ -245,4 +245,26 @@ CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) { | ||
| 245 | return *this; | 245 | return *this; |
| 246 | } | 246 | } |
| 247 | 247 | ||
| 248 | +std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) { | ||
| 249 | + std::vector<CopyableOrtValue> ans; | ||
| 250 | + ans.reserve(values.size()); | ||
| 251 | + | ||
| 252 | + for (auto &v : values) { | ||
| 253 | + ans.emplace_back(std::move(v)); | ||
| 254 | + } | ||
| 255 | + | ||
| 256 | + return ans; | ||
| 257 | +} | ||
| 258 | + | ||
| 259 | +std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) { | ||
| 260 | + std::vector<Ort::Value> ans; | ||
| 261 | + ans.reserve(values.size()); | ||
| 262 | + | ||
| 263 | + for (auto &v : values) { | ||
| 264 | + ans.emplace_back(std::move(v.value)); | ||
| 265 | + } | ||
| 266 | + | ||
| 267 | + return ans; | ||
| 268 | +} | ||
| 269 | + | ||
| 248 | } // namespace sherpa_onnx | 270 | } // namespace sherpa_onnx |
| @@ -97,8 +97,8 @@ struct CopyableOrtValue { | @@ -97,8 +97,8 @@ struct CopyableOrtValue { | ||
| 97 | 97 | ||
| 98 | CopyableOrtValue() = default; | 98 | CopyableOrtValue() = default; |
| 99 | 99 | ||
| 100 | - /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT | ||
| 101 | - : value(std::move(v)) {} | 100 | + /*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT |
| 101 | + : value(std::move(v)) {} | ||
| 102 | 102 | ||
| 103 | CopyableOrtValue(const CopyableOrtValue &other); | 103 | CopyableOrtValue(const CopyableOrtValue &other); |
| 104 | 104 | ||
| @@ -109,6 +109,10 @@ struct CopyableOrtValue { | @@ -109,6 +109,10 @@ struct CopyableOrtValue { | ||
| 109 | CopyableOrtValue &operator=(CopyableOrtValue &&other); | 109 | CopyableOrtValue &operator=(CopyableOrtValue &&other); |
| 110 | }; | 110 | }; |
| 111 | 111 | ||
| 112 | +std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values); | ||
| 113 | + | ||
| 114 | +std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values); | ||
| 115 | + | ||
| 112 | } // namespace sherpa_onnx | 116 | } // namespace sherpa_onnx |
| 113 | 117 | ||
| 114 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ | 118 | #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ |
| @@ -94,7 +94,7 @@ for a list of pre-trained models to download. | @@ -94,7 +94,7 @@ for a list of pre-trained models to download. | ||
| 94 | auto s = recognizer.CreateStream(); | 94 | auto s = recognizer.CreateStream(); |
| 95 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | 95 | s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); |
| 96 | 96 | ||
| 97 | - std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate)); | 97 | + std::vector<float> tail_paddings(static_cast<int>(0.5 * sampling_rate)); |
| 98 | // Note: We can call AcceptWaveform() multiple times. | 98 | // Note: We can call AcceptWaveform() multiple times. |
| 99 | s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); | 99 | s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size()); |
| 100 | 100 |
-
请 注册 或 登录 后发表评论