// sherpa-onnx/csrc/onnx-utils.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#ifdef _MSC_VER
// For ToWide() below
#include <codecvt>
#include <locale>
#endif

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "onnxruntime_cxx_api.h"  // NOLINT
namespace sherpa_onnx {

/**
 * Get the input names of a model.
 *
 * @param sess An onnxruntime session.
 * @param input_names On return, it contains the input names of the model.
 * @param input_names_ptr On return, input_names_ptr[i] contains
 *                        input_names[i].c_str()
 *
 * Because input_names_ptr holds c_str() pointers into the strings of
 * input_names, input_names must outlive input_names_ptr and must not be
 * modified while input_names_ptr is still in use.
 */
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
                   std::vector<const char *> *input_names_ptr);

/**
 * Get the output names of a model.
 *
 * @param sess An onnxruntime session.
 * @param output_names On return, it contains the output names of the model.
 * @param output_names_ptr On return, output_names_ptr[i] contains
 *                         output_names[i].c_str()
 *
 * Because output_names_ptr holds c_str() pointers into the strings of
 * output_names, output_names must outlive output_names_ptr and must not be
 * modified while output_names_ptr is still in use.
 */
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
                    std::vector<const char *> *output_names_ptr);

/**
 * Get one output frame of the encoder.
 *
 * @param allocator Allocator of onnxruntime.
 * @param encoder_out Encoder output tensor.
 * @param t Index of the frame to extract.
 *
 * NOTE(review): the expected layout of encoder_out (presumably
 * (batch, num_frames, dim)) and the shape of the returned tensor are defined
 * by the implementation, which is not in this header — confirm there.
 */
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
                              int32_t t);

// Look up the custom metadata entry named `key` in meta_data.
// NOTE(review): what is returned when `key` is absent is decided by the
// implementation (not visible here) — confirm before relying on it.
std::string LookupCustomModelMetaData(const Ort::ModelMetadata &meta_data,
                                      const char *key, OrtAllocator *allocator);

// Print the given model metadata to `os`.
void PrintModelMetadata(std::ostream &os,
                        const Ort::ModelMetadata &meta_data);  // NOLINT

// Return a deep copy of v.
// NOTE(review): the new buffer is presumably obtained from `allocator`.
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);

// Return a shallow copy of v. Being shallow, the result presumably shares
// v's underlying buffer, so v must outlive it — confirm in the implementation.
Ort::Value View(Ort::Value *v);

// Debugging helpers returning the sum / mean of the elements of v.
// NOTE(review): n defaults to -1; presumably -1 means "all elements" and a
// non-negative n restricts the computation to the first n elements — confirm.
float ComputeSum(const Ort::Value *v, int32_t n = -1);
float ComputeMean(const Ort::Value *v, int32_t n = -1);

// Print a 1-D tensor to stderr.
// The template is only declared here, so only element types that the
// implementation explicitly instantiates can be used (others fail to link).
template <typename T = float>
void Print1D(const Ort::Value *v);

// Print a 2-D tensor to stderr.
// Same restriction on T as Print1D: only explicitly instantiated types link.
template <typename T = float>
void Print2D(const Ort::Value *v);

// Print a 3-D tensor to stderr
void Print3D(const Ort::Value *v);

// Print a 4-D tensor to stderr
void Print4D(const Ort::Value *v);

// Print the shape of tensor v.
void PrintShape(const Ort::Value *v);

/**
 * Set every element of *tensor to `value`.
 *
 * @param tensor A tensor whose element type must be T. The element type is
 *               not verified here; a mismatch reinterprets the raw buffer.
 * @param value  The value written to each element.
 */
template <typename T = float>
void Fill(Ort::Value *tensor, T value) {
  // Ask the value for its tensor shape info directly; going through
  // GetTypeInfo() first would create an extra OrtTypeInfo for no benefit.
  const auto n = tensor->GetTensorTypeAndShapeInfo().GetElementCount();
  T *p = tensor->GetTensorMutableData<T>();
  std::fill(p, p + n, value);
}

// Return the contents of `filename` as raw bytes.
// NOTE(review): behavior for a missing or unreadable file is defined by the
// implementation — confirm there.
std::vector<char> ReadFile(const std::string &filename);

#if __ANDROID_API__ >= 9
// Same as above, but reads `filename` through the Android asset manager `mgr`.
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
#endif

#if __OHOS__
// Same as above, but reads `filename` through the OHOS resource manager `mgr`.
std::vector<char> ReadFile(NativeResourceManager *mgr,
                           const std::string &filename);
#endif

// TODO(fangjun): Document it
// NOTE(review): judging only by the parameter names, this presumably repeats
// slices of cur_encoder_out according to hyps_num_split; confirm against the
// implementation before writing the real documentation.
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                  const std::vector<int32_t> &hyps_num_split);

/**
 * A wrapper that gives Ort::Value (which is move-only) copy operations, so it
 * can be used where copyable types are required.
 *
 * The copy operations are defined in the implementation; per the Clone()/View()
 * helpers above they may be deep or shallow — confirm there before relying on
 * either.
 */
struct CopyableOrtValue {
  Ort::Value value{nullptr};

  CopyableOrtValue() = default;

  // Intentionally implicit (hence the commented-out `explicit` and NOLINT) so
  // that an Ort::Value converts to a CopyableOrtValue without ceremony.
  /*explicit*/ CopyableOrtValue(Ort::Value v)  // NOLINT
      : value(std::move(v)) {}

  CopyableOrtValue(const CopyableOrtValue &other);

  CopyableOrtValue &operator=(const CopyableOrtValue &other);

  CopyableOrtValue(CopyableOrtValue &&other) noexcept;

  CopyableOrtValue &operator=(CopyableOrtValue &&other) noexcept;
};

// Convert between vectors of Ort::Value and CopyableOrtValue. The argument is
// taken by value so callers can std::move their vector in.
std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values);

std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values);

}  // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_