onnx-utils.h
3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// sherpa-onnx/csrc/onnx-utils.h
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023 Pingfeng Luo
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#ifdef _MSC_VER
// For ToWide() below
#include <codecvt>
#include <locale>
#endif
#include <cassert>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
/**
* Get the input names of a model.
*
* @param sess An onnxruntime session.
* @param input_names. On return, it contains the input names of the model.
* @param input_names_ptr. On return, input_names_ptr[i] contains
* input_names[i].c_str()
*/
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr);
/**
* Get the output names of a model.
*
* @param sess An onnxruntime session.
* @param output_names. On return, it contains the output names of the model.
* @param output_names_ptr. On return, output_names_ptr[i] contains
* output_names[i].c_str()
*/
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
std::vector<const char *> *output_names_ptr);
/**
* Get the output frame of Encoder
*
* @param allocator allocator of onnxruntime
* @param encoder_out encoder out tensor
* @param t frame_index
*
*/
Ort::Value GetEncoderOutFrame(OrtAllocator *allocator, Ort::Value *encoder_out,
int32_t t);
void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT
// Return a deep copy of v
Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
// Return a shallow copy
Ort::Value View(Ort::Value *v);
// Print a 1-D tensor to stderr
void Print1D(Ort::Value *v);
// Print a 2-D tensor to stderr
template <typename T = float>
void Print2D(Ort::Value *v);
// Print a 3-D tensor to stderr
void Print3D(Ort::Value *v);
// Print a 4-D tensor to stderr
void Print4D(Ort::Value *v);
template <typename T = float>
void Fill(Ort::Value *tensor, T value) {
auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
auto p = tensor->GetTensorMutableData<T>();
std::fill(p, p + n, value);
}
std::vector<char> ReadFile(const std::string &filename);
#if __ANDROID_API__ >= 9
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
#endif
// TODO(fangjun): Document it
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split);
struct CopyableOrtValue {
Ort::Value value{nullptr};
CopyableOrtValue() = default;
/*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT
: value(std::move(v)) {}
CopyableOrtValue(const CopyableOrtValue &other);
CopyableOrtValue &operator=(const CopyableOrtValue &other);
CopyableOrtValue(CopyableOrtValue &&other);
CopyableOrtValue &operator=(CopyableOrtValue &&other);
};
std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values);
std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_