HieDean
Committed by GitHub

Replace Clone() with View() (#432)

Co-authored-by: hiedean <hiedean@tju.edu.cn>
@@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( @@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
94 // now cur_encoder_out is of shape (num_hyps, joiner_dim) 94 // now cur_encoder_out is of shape (num_hyps, joiner_dim)
95 95
96 Ort::Value logit = model_->RunJoiner( 96 Ort::Value logit = model_->RunJoiner(
97 - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); 97 + std::move(cur_encoder_out), View(&decoder_out));
98 98
99 float *p_logit = logit.GetTensorMutableData<float>(); 99 float *p_logit = logit.GetTensorMutableData<float>();
100 LogSoftmax(p_logit, vocab_size, num_hyps); 100 LogSoftmax(p_logit, vocab_size, num_hyps);
@@ -67,13 +67,13 @@ class OnlineRnnLM::Impl { @@ -67,13 +67,13 @@ class OnlineRnnLM::Impl {
67 return {std::move(out[0]), std::move(next_states)}; 67 return {std::move(out[0]), std::move(next_states)};
68 } 68 }
69 69
70 - std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() const { 70 + std::pair<Ort::Value, std::vector<Ort::Value>> GetInitStates() {
71 std::vector<Ort::Value> ans; 71 std::vector<Ort::Value> ans;
72 ans.reserve(init_states_.size()); 72 ans.reserve(init_states_.size());
73 - for (const auto &s : init_states_) {  
74 - ans.emplace_back(Clone(allocator_, &s)); 73 + for (auto &s : init_states_) {
  74 + ans.emplace_back(View(&s));
75 } 75 }
76 - return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; 76 + return {View(&init_scores_.value), std::move(ans)};
77 } 77 }
78 78
79 private: 79 private:
@@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode(
99 } 99 }
100 if (is_batch_decoder_out_cached) { 100 if (is_batch_decoder_out_cached) {
101 auto &r = result->front(); 101 auto &r = result->front();
102 - std::vector<int64_t> decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); 102 + std::vector<int64_t> decoder_out_shape =
  103 + r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
103 decoder_out_shape[0] = batch_size; 104 decoder_out_shape[0] = batch_size;
104 - decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size()); 105 + decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
  106 + decoder_out_shape.data(), decoder_out_shape.size());
105 UseCachedDecoderOut(*result, &decoder_out); 107 UseCachedDecoderOut(*result, &decoder_out);
106 } else { 108 } else {
107 Ort::Value decoder_input = model_->BuildDecoderInput(*result); 109 Ort::Value decoder_input = model_->BuildDecoderInput(*result);
@@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
112 Ort::Value cur_encoder_out = 114 Ort::Value cur_encoder_out =
113 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); 115 GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
114 Ort::Value logit = model_->RunJoiner( 116 Ort::Value logit = model_->RunJoiner(
115 - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); 117 + std::move(cur_encoder_out), View(&decoder_out));
116 118
117 const float *p_logit = logit.GetTensorData<float>(); 119 const float *p_logit = logit.GetTensorData<float>();
118 120
@@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
120 cur_encoder_out = 120 cur_encoder_out =
121 Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); 121 Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
122 Ort::Value logit = model_->RunJoiner( 122 Ort::Value logit = model_->RunJoiner(
123 - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); 123 + std::move(cur_encoder_out), View(&decoder_out));
124 124
125 float *p_logit = logit.GetTensorMutableData<float>(); 125 float *p_logit = logit.GetTensorMutableData<float>();
126 LogSoftmax(p_logit, vocab_size, num_hyps); 126 LogSoftmax(p_logit, vocab_size, num_hyps);
@@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl { @@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl {
105 // - attn_cache 105 // - attn_cache
106 // - conv_cache 106 // - conv_cache
107 // - offset 107 // - offset
108 - std::vector<Ort::Value> GetInitStates() const { 108 + std::vector<Ort::Value> GetInitStates() {
109 std::vector<Ort::Value> ans; 109 std::vector<Ort::Value> ans;
110 ans.reserve(3); 110 ans.reserve(3);
111 - ans.push_back(Clone(Allocator(), &attn_cache_));  
112 - ans.push_back(Clone(Allocator(), &conv_cache_)); 111 + ans.push_back(View(&attn_cache_));
  112 + ans.push_back(View(&conv_cache_));
113 113
114 int64_t offset_shape = 1; 114 int64_t offset_shape = 1;
115 115