正在显示
7 个修改的文件
包含
41 行增加
和
24 行删除
| @@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() { | @@ -171,9 +171,14 @@ class MainActivity : AppCompatActivity() { | ||
| 171 | } | 171 | } |
| 172 | 172 | ||
| 173 | private fun initModel() { | 173 | private fun initModel() { |
| 174 | + // Please change getModelConfig() to add new models | ||
| 175 | + // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||
| 176 | + // for a list of available models | ||
| 177 | + val type = 0 | ||
| 178 | + println("Select model type ${type}") | ||
| 174 | val config = OnlineRecognizerConfig( | 179 | val config = OnlineRecognizerConfig( |
| 175 | featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), | 180 | featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), |
| 176 | - modelConfig = getModelConfig(type = 1)!!, | 181 | + modelConfig = getModelConfig(type = type)!!, |
| 177 | endpointConfig = getEndpointConfig(), | 182 | endpointConfig = getEndpointConfig(), |
| 178 | enableEndpoint = true | 183 | enableEndpoint = true |
| 179 | ) | 184 | ) |
| @@ -63,7 +63,7 @@ class ViewController: UIViewController { | @@ -63,7 +63,7 @@ class ViewController: UIViewController { | ||
| 63 | super.viewDidLoad() | 63 | super.viewDidLoad() |
| 64 | // Do any additional setup after loading the view. | 64 | // Do any additional setup after loading the view. |
| 65 | 65 | ||
| 66 | - resultLabel.text = "ASR with Next-gen Kaldi\n\nPress the Start button to run!" | 66 | + resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!" |
| 67 | recordBtn.setTitle("Start", for: .normal) | 67 | recordBtn.setTitle("Start", for: .normal) |
| 68 | initRecognizer() | 68 | initRecognizer() |
| 69 | initRecorder() | 69 | initRecorder() |
| @@ -37,7 +37,7 @@ template <typename T /*=float*/> | @@ -37,7 +37,7 @@ template <typename T /*=float*/> | ||
| 37 | Ort::Value Cat(OrtAllocator *allocator, | 37 | Ort::Value Cat(OrtAllocator *allocator, |
| 38 | const std::vector<const Ort::Value *> &values, int32_t dim) { | 38 | const std::vector<const Ort::Value *> &values, int32_t dim) { |
| 39 | if (values.size() == 1u) { | 39 | if (values.size() == 1u) { |
| 40 | - return Clone(values[0]); | 40 | + return Clone(allocator, values[0]); |
| 41 | } | 41 | } |
| 42 | 42 | ||
| 43 | std::vector<int64_t> v0_shape = | 43 | std::vector<int64_t> v0_shape = |
| @@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( | @@ -100,8 +100,8 @@ void OnlineTransducerGreedySearchDecoder::Decode( | ||
| 100 | for (int32_t t = 0; t != num_frames; ++t) { | 100 | for (int32_t t = 0; t != num_frames; ++t) { |
| 101 | Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); | 101 | Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); |
| 102 | cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); | 102 | cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, batch_size); |
| 103 | - Ort::Value logit = | ||
| 104 | - model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); | 103 | + Ort::Value logit = model_->RunJoiner( |
| 104 | + std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); | ||
| 105 | const float *p_logit = logit.GetTensorData<float>(); | 105 | const float *p_logit = logit.GetTensorData<float>(); |
| 106 | 106 | ||
| 107 | bool emitted = false; | 107 | bool emitted = false; |
| @@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | @@ -53,7 +53,7 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { | ||
| 53 | } | 53 | } |
| 54 | } | 54 | } |
| 55 | 55 | ||
| 56 | -Ort::Value Clone(const Ort::Value *v) { | 56 | +Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) { |
| 57 | auto type_and_shape = v->GetTensorTypeAndShapeInfo(); | 57 | auto type_and_shape = v->GetTensorTypeAndShapeInfo(); |
| 58 | std::vector<int64_t> shape = type_and_shape.GetShape(); | 58 | std::vector<int64_t> shape = type_and_shape.GetShape(); |
| 59 | 59 | ||
| @@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) { | @@ -61,21 +61,33 @@ Ort::Value Clone(const Ort::Value *v) { | ||
| 61 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | 61 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); |
| 62 | 62 | ||
| 63 | switch (type_and_shape.GetElementType()) { | 63 | switch (type_and_shape.GetElementType()) { |
| 64 | - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: | ||
| 65 | - return Ort::Value::CreateTensor( | ||
| 66 | - memory_info, | ||
| 67 | - const_cast<Ort::Value *>(v)->GetTensorMutableData<int32_t>(), | ||
| 68 | - type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 69 | - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: | ||
| 70 | - return Ort::Value::CreateTensor( | ||
| 71 | - memory_info, | ||
| 72 | - const_cast<Ort::Value *>(v)->GetTensorMutableData<int64_t>(), | ||
| 73 | - type_and_shape.GetElementCount(), shape.data(), shape.size()); | ||
| 74 | - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: | ||
| 75 | - return Ort::Value::CreateTensor( | ||
| 76 | - memory_info, | ||
| 77 | - const_cast<Ort::Value *>(v)->GetTensorMutableData<float>(), | ||
| 78 | - type_and_shape.GetElementCount(), shape.data(), shape.size()); | 64 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { |
| 65 | + Ort::Value ans = Ort::Value::CreateTensor<int32_t>( | ||
| 66 | + allocator, shape.data(), shape.size()); | ||
| 67 | + const int32_t *start = v->GetTensorData<int32_t>(); | ||
| 68 | + const int32_t *end = start + type_and_shape.GetElementCount(); | ||
| 69 | + int32_t *dst = ans.GetTensorMutableData<int32_t>(); | ||
| 70 | + std::copy(start, end, dst); | ||
| 71 | + return ans; | ||
| 72 | + } | ||
| 73 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { | ||
| 74 | + Ort::Value ans = Ort::Value::CreateTensor<int64_t>( | ||
| 75 | + allocator, shape.data(), shape.size()); | ||
| 76 | + const int64_t *start = v->GetTensorData<int64_t>(); | ||
| 77 | + const int64_t *end = start + type_and_shape.GetElementCount(); | ||
| 78 | + int64_t *dst = ans.GetTensorMutableData<int64_t>(); | ||
| 79 | + std::copy(start, end, dst); | ||
| 80 | + return ans; | ||
| 81 | + } | ||
| 82 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { | ||
| 83 | + Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, shape.data(), | ||
| 84 | + shape.size()); | ||
| 85 | + const float *start = v->GetTensorData<float>(); | ||
| 86 | + const float *end = start + type_and_shape.GetElementCount(); | ||
| 87 | + float *dst = ans.GetTensorMutableData<float>(); | ||
| 88 | + std::copy(start, end, dst); | ||
| 89 | + return ans; | ||
| 90 | + } | ||
| 79 | default: | 91 | default: |
| 80 | fprintf(stderr, "Unsupported type: %d\n", | 92 | fprintf(stderr, "Unsupported type: %d\n", |
| 81 | static_cast<int32_t>(type_and_shape.GetElementType())); | 93 | static_cast<int32_t>(type_and_shape.GetElementType())); |
| @@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | @@ -60,8 +60,8 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names, | ||
| 60 | void PrintModelMetadata(std::ostream &os, | 60 | void PrintModelMetadata(std::ostream &os, |
| 61 | const Ort::ModelMetadata &meta_data); // NOLINT | 61 | const Ort::ModelMetadata &meta_data); // NOLINT |
| 62 | 62 | ||
| 63 | -// Return a shallow copy of v | ||
| 64 | -Ort::Value Clone(const Ort::Value *v); | 63 | +// Return a deep copy of v |
| 64 | +Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); | ||
| 65 | 65 | ||
| 66 | // Print a 1-D tensor to stderr | 66 | // Print a 1-D tensor to stderr |
| 67 | void Print1D(Ort::Value *v); | 67 | void Print1D(Ort::Value *v); |
| @@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value, | @@ -26,7 +26,7 @@ std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value, | ||
| 26 | int32_t n = static_cast<int32_t>(shape[dim]); | 26 | int32_t n = static_cast<int32_t>(shape[dim]); |
| 27 | if (n == 1) { | 27 | if (n == 1) { |
| 28 | std::vector<Ort::Value> ans; | 28 | std::vector<Ort::Value> ans; |
| 29 | - ans.push_back(Clone(value)); | 29 | + ans.push_back(Clone(allocator, value)); |
| 30 | return ans; | 30 | return ans; |
| 31 | } | 31 | } |
| 32 | 32 |
-
请 注册 或 登录 后发表评论