Fangjun Kuang
Committed by GitHub

Use deep copy in Clone() (#66)

@@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() { @@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() {
171 } 171 }
172 172
173 private fun initModel() { 173 private fun initModel() {
  174 + // Please change getModelConfig() to add new models
  175 + // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
  176 + // for a list of available models
  177 + val type = 0
  178 + println("Select model type ${type}")
174 val config = OnlineRecognizerConfig( 179 val config = OnlineRecognizerConfig(
175 featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), 180 featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
176 - modelConfig = getModelConfig(type = 1)!!, 181 + modelConfig = getModelConfig(type = type)!!,
177 endpointConfig = getEndpointConfig(), 182 endpointConfig = getEndpointConfig(),
178 enableEndpoint = true 183 enableEndpoint = true
179 ) 184 )
@@ -63,7 +63,7 @@ class ViewController: UIViewController { @@ -63,7 +63,7 @@ class ViewController: UIViewController {
63 super.viewDidLoad() 63 super.viewDidLoad()
64 // Do any additional setup after loading the view. 64 // Do any additional setup after loading the view.
65 65
66 - resultLabel.text = "ASR with Next-gen Kaldi\n\nPress the Start button to run!" 66 + resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
67 recordBtn.setTitle("Start", for: .normal) 67 recordBtn.setTitle("Start", for: .normal)
68 initRecognizer() 68 initRecognizer()
69 initRecorder() 69 initRecorder()
@@ -37,7 +37,7 @@ template <typename T /*=float*/> @@ -37,7 +37,7 @@ template <typename T /*=float*/>
37 Ort::Value Cat(OrtAllocator *allocator, 37 Ort::Value Cat(OrtAllocator *allocator,
38 const std::vector<const Ort::Value *> &values, int32_t dim) { 38 const std::vector<const Ort::Value *> &values, int32_t dim) {
39 if (values.size() == 1u) { 39 if (values.size() == 1u) {
40 - return Clone(values[0]); 40 + return Clone(allocator, values[0]);
41 } 41 }
42 42
43 std::vector<int64_t> v0_shape = 43 std::vector<int64_t> v0_shape =
@@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( @@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
100 for (int32_t t = 0; t != num_frames; ++t) { 100 for (int32_t t = 0; t != num_frames; ++t) {
101 Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); 101 Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
102 cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); 102 cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size);
103 - Ort::Value logit =  
104 - model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); 103 + Ort::Value logit = model_->RunJoiner(
  104 + std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
105 const float *p_logit = logit.GetTensorData<float>(); 105 const float *p_logit = logit.GetTensorData<float>();
106 106
107 bool emitted = false; 107 bool emitted = false;
@@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { @@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
53 } 53 }
54 } 54 }
55 55
56 -Ort::Value Clone(const Ort::Value *v) { 56 +Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
57 auto type_and_shape = v->GetTensorTypeAndShapeInfo(); 57 auto type_and_shape = v->GetTensorTypeAndShapeInfo();
58 std::vector<int64_t> shape = type_and_shape.GetShape(); 58 std::vector<int64_t> shape = type_and_shape.GetShape();
59 59
@@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) { @@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) {
61 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); 61 Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
62 62
63 switch (type_and_shape.GetElementType()) { 63 switch (type_and_shape.GetElementType()) {
64 - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:  
65 - return Ort::Value::CreateTensor(  
66 - memory_info,  
67 - const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(),  
68 - type_and_shape.GetElementCount(), shape.data(), shape.size());  
69 - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:  
70 - return Ort::Value::CreateTensor(  
71 - memory_info,  
72 - const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(),  
73 - type_and_shape.GetElementCount(), shape.data(), shape.size());  
74 - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:  
75 - return Ort::Value::CreateTensor(  
76 - memory_info,  
77 - const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(),  
78 - type_and_shape.GetElementCount(), shape.data(), shape.size()); 64 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
  65 + Ort::Value ans = Ort::Value::CreateTensor<int32_t>(
  66 + allocator, shape.data(), shape.size());
  67 + const int32_t *start = v->GetTensorData<int32_t>();
  68 + const int32_t *end = start + type_and_shape.GetElementCount();
  69 + int32_t *dst = ans.GetTensorMutableData<int32_t>();
  70 + std::copy(start, end, dst);
  71 + return ans;
  72 + }
  73 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
  74 + Ort::Value ans = Ort::Value::CreateTensor<int64_t>(
  75 + allocator, shape.data(), shape.size());
  76 + const int64_t *start = v->GetTensorData<int64_t>();
  77 + const int64_t *end = start + type_and_shape.GetElementCount();
  78 + int64_t *dst = ans.GetTensorMutableData<int64_t>();
  79 + std::copy(start, end, dst);
  80 + return ans;
  81 + }
  82 + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
  83 + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, shape.data(),
  84 + shape.size());
  85 + const float *start = v->GetTensorData<float>();
  86 + const float *end = start + type_and_shape.GetElementCount();
  87 + float *dst = ans.GetTensorMutableData<float>();
  88 + std::copy(start, end, dst);
  89 + return ans;
  90 + }
79 default: 91 default:
80 fprintf(stderr, "Unsupported type: %d\n", 92 fprintf(stderr, "Unsupported type: %d\n",
81 static_cast<int32_t>(type_and_shape.GetElementType())); 93 static_cast<int32_t>(type_and_shape.GetElementType()));
@@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, @@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
60 void PrintModelMetadata(std::ostream &os, 60 void PrintModelMetadata(std::ostream &os,
61 const Ort::ModelMetadata &meta_data); // NOLINT 61 const Ort::ModelMetadata &meta_data); // NOLINT
62 62
63 -// Return a shallow copy of v  
64 -Ort::Value Clone(const Ort::Value *v); 63 +// Return a deep copy of v
  64 +Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
65 65
66 // Print a 1-D tensor to stderr 66 // Print a 1-D tensor to stderr
67 void Print1D(Ort::Value *v); 67 void Print1D(Ort::Value *v);
@@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value, @@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
26 int32_t n = static_cast<int32_t>(shape[dim]); 26 int32_t n = static_cast<int32_t>(shape[dim]);
27 if (n == 1) { 27 if (n == 1) {
28 std::vector<Ort::Value> ans; 28 std::vector<Ort::Value> ans;
29 - ans.push_back(Clone(value)); 29 + ans.push_back(Clone(allocator, value));
30 return ans; 30 return ans;
31 } 31 }
32 32