Committed by
GitHub
Replace Clone() with View() (#432)
Co-authored-by: hiedean <hiedean@tju.edu.cn>
正在显示
5 个修改的文件
包含
14 行增加
和
12 行删除
| @@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | @@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 94 | // now cur_encoder_out is of shape (num_hyps, joiner_dim) | 94 | // now cur_encoder_out is of shape (num_hyps, joiner_dim) |
| 95 | 95 | ||
| 96 | Ort::Value logit = model_->RunJoiner( | 96 | Ort::Value logit = model_->RunJoiner( |
| 97 | - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | 97 | + std::move(cur_encoder_out), View(&decoder_out)); |
| 98 | 98 | ||
| 99 | float *p_logit = logit.GetTensorMutableData<float>(); | 99 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 100 | LogSoftmax(p_logit, vocab_size, num_hyps); | 100 | LogSoftmax(p_logit, vocab_size, num_hyps); |
| @@ -67,13 +67,13 @@ class OnlineRnnLM::Impl { | @@ -67,13 +67,13 @@ class OnlineRnnLM::Impl { | ||
| 67 | return {std::move(out[0]), std::move(next_states)}; | 67 | return {std::move(out[0]), std::move(next_states)}; |
| 68 | } | 68 | } |
| 69 | 69 | ||
| 70 | - std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const { | 70 | + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() { |
| 71 | std::vector<Ort::Value> ans; | 71 | std::vector<Ort::Value> ans; |
| 72 | ans.reserve(init_states_.size()); | 72 | ans.reserve(init_states_.size()); |
| 73 | - for (const auto &s : init_states_) { | ||
| 74 | - ans.emplace_back(Clone(allocator_, &s)); | 73 | + for (auto &s : init_states_) { |
| 74 | + ans.emplace_back(View(&s)); | ||
| 75 | } | 75 | } |
| 76 | - return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; | 76 | + return {View(&init_scores_.value), std::move(ans)}; |
| 77 | } | 77 | } |
| 78 | 78 | ||
| 79 | private: | 79 | private: |
| @@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 99 | } | 99 | } |
| 100 | if (is_batch_decoder_out_cached) { | 100 | if (is_batch_decoder_out_cached) { |
| 101 | auto &r = result->front(); | 101 | auto &r = result->front(); |
| 102 | - std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); | 102 | + std::vector<int64_t> decoder_out_shape = |
| 103 | + r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); | ||
| 103 | decoder_out_shape[0] = batch_size; | 104 | decoder_out_shape[0] = batch_size; |
| 104 | - decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size()); | 105 | + decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), |
| 106 | + decoder_out_shape.data(), decoder_out_shape.size()); | ||
| 105 | UseCachedDecoderOut(*result, &decoder_out); | 107 | UseCachedDecoderOut(*result, &decoder_out); |
| 106 | } else { | 108 | } else { |
| 107 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); | 109 | Ort::Value decoder_input = model_->BuildDecoderInput(*result); |
| @@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 112 | Ort::Value cur_encoder_out = | 114 | Ort::Value cur_encoder_out = |
| 113 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); | 115 | GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); |
| 114 | Ort::Value logit = model_->RunJoiner( | 116 | Ort::Value logit = model_->RunJoiner( |
| 115 | - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | 117 | + std::move(cur_encoder_out), View(&decoder_out)); |
| 116 | 118 | ||
| 117 | const float *p_logit = logit.GetTensorData<float>(); | 119 | const float *p_logit = logit.GetTensorData<float>(); |
| 118 | 120 |
| @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( | ||
| 120 | cur_encoder_out = | 120 | cur_encoder_out = |
| 121 | Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); | 121 | Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); |
| 122 | Ort::Value logit = model_->RunJoiner( | 122 | Ort::Value logit = model_->RunJoiner( |
| 123 | - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | 123 | + std::move(cur_encoder_out), View(&decoder_out)); |
| 124 | 124 | ||
| 125 | float *p_logit = logit.GetTensorMutableData<float>(); | 125 | float *p_logit = logit.GetTensorMutableData<float>(); |
| 126 | LogSoftmax(p_logit, vocab_size, num_hyps); | 126 | LogSoftmax(p_logit, vocab_size, num_hyps); |
| @@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl { | @@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl { | ||
| 105 | // - attn_cache | 105 | // - attn_cache |
| 106 | // - conv_cache | 106 | // - conv_cache |
| 107 | // - offset | 107 | // - offset |
| 108 | - std::vector<Ort::Value> GetInitStates() const { | 108 | + std::vector<Ort::Value> GetInitStates() { |
| 109 | std::vector<Ort::Value> ans; | 109 | std::vector<Ort::Value> ans; |
| 110 | ans.reserve(3); | 110 | ans.reserve(3); |
| 111 | - ans.push_back(Clone(Allocator(), &attn_cache_)); | ||
| 112 | - ans.push_back(Clone(Allocator(), &conv_cache_)); | 111 | + ans.push_back(View(&attn_cache_)); |
| 112 | + ans.push_back(View(&conv_cache_)); | ||
| 113 | 113 | ||
| 114 | int64_t offset_shape = 1; | 114 | int64_t offset_shape = 1; |
| 115 | 115 |
-
请 注册 或 登录 后发表评论