onnx-utils.h
1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// sherpa-onnx/csrc/onnx-utils.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
#ifdef _MSC_VER
// For ToWide() below
#include <codecvt>
#include <locale>
#endif
#include <ostream>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
#ifdef _MSC_VER
// See
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
/**
 * Convert a UTF-8 encoded narrow string to a UTF-16 encoded wide string.
 *
 * Declared `inline` (rather than `static`) because it is defined in a header:
 * this avoids a separate private copy, and unused-function warnings, in every
 * translation unit that includes this file.
 *
 * Note: std::wstring_convert and <codecvt> are deprecated since C++17 but are
 * still provided by the standard library implementations this code targets.
 *
 * @param s The input string, expected to contain valid UTF-8.
 * @return The converted wide string.
 * @throws std::range_error if `s` cannot be converted (e.g. invalid UTF-8),
 *         as specified for std::wstring_convert::from_bytes when no
 *         error string was passed to the converter's constructor.
 */
inline std::wstring ToWide(const std::string &s) {
  std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
  return converter.from_bytes(s);
}
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
#else
#define SHERPA_MAYBE_WIDE(s) s
#endif
/**
 * Get the input names of a model.
 *
 * @param sess An onnxruntime session.
 * @param input_names On return, it contains the input names of the model.
 * @param input_names_ptr On return, input_names_ptr[i] contains
 *                        input_names[i].c_str(). Because these pointers
 *                        point into the strings held by input_names, they
 *                        stay valid only while *input_names is neither
 *                        modified nor destroyed.
 */
void GetInputNames(Ort::Session *sess, std::vector<std::string> *input_names,
std::vector<const char *> *input_names_ptr);
/**
 * Get the output names of a model.
 *
 * @param sess An onnxruntime session.
 * @param output_names On return, it contains the output names of the model.
 * @param output_names_ptr On return, output_names_ptr[i] contains
 *                         output_names[i].c_str(). Because these pointers
 *                         point into the strings held by output_names, they
 *                         stay valid only while *output_names is neither
 *                         modified nor destroyed.
 */
void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
std::vector<const char *> *output_names_ptr);
/**
 * Print the metadata of a model to an output stream.
 *
 * NOTE(review): only the declaration is visible in this header; which
 * metadata fields are written and in what format is decided by the
 * definition -- confirm there before relying on the output layout.
 *
 * @param os The stream the metadata is written to.
 * @param meta_data The onnxruntime model metadata to print.
 */
void PrintModelMetadata(std::ostream &os,
const Ort::ModelMetadata &meta_data); // NOLINT
/**
 * Return a shallow copy of v.
 *
 * NOTE(review): "shallow" presumably means the returned value refers to the
 * same underlying data as *v instead of duplicating it, in which case *v (or
 * whatever owns its data) must outlive the returned value -- confirm against
 * the definition, which is not visible in this header.
 *
 * @param v Pointer to the value to copy. It is taken as a non-const pointer.
 * @return A new Ort::Value that is a shallow copy of *v.
 */
Ort::Value Clone(Ort::Value *v);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_