node_arg.h
4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/basic_types.h"
#include "core/graph/onnx_protobuf.h"
namespace onnxruntime {
// Node argument definition, used for both inputs and outputs. It holds the
// argument name and the argument type (where the type carries both the element
// type and the shape).
//
// Design question: shape arguably should not be part of type. We could align
// the protobuf design with our operator registry interface, which specifies a
// type for each operator but no shape. Shape would instead be inferred by a
// separate shape inference function from the input shapes (or, sometimes, from
// input tensor data).
// With shape as part of type (the current protobuf design):
// 1) we have to split the "TypeProto" into type and shape in this internal
//    representation so that it can be used easily for type inference and for
//    matching against the operator registry.
// 2) SetType() must always be called before SetShape(); otherwise SetShape()
//    will fail, because the shape is stored inside the TypeProto.
// Open for discussion.
//
/**
@class NodeArg
Class representing a data type that is input or output for a Node, including the shape if it is a Tensor.
*/
class NodeArg {
 public:
  /**
  Construct a new NodeArg.
  @param name The name to use. An empty name represents a non-existent (omitted optional) argument; see #Exists.
  @param p_arg_type Optional TypeProto specifying type and shape information. May be nullptr.
  */
  NodeArg(const std::string& name,
          const ONNX_NAMESPACE::TypeProto* p_arg_type);

  // Movable; copy construction/assignment are disabled in the private section below.
  NodeArg(NodeArg&&) = default;
  NodeArg& operator=(NodeArg&& other) = default;

  /** Gets the name. */
  const std::string& Name() const noexcept;

  /** Gets the data type. */
  ONNX_NAMESPACE::DataType Type() const noexcept;

  /** Gets the TypeProto
  @returns TypeProto if type is set. nullptr otherwise. */
  const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept;

  /** Gets the shape if NodeArg is for a Tensor.
  @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */
  const ONNX_NAMESPACE::TensorShapeProto* Shape() const;

  /** Return an indicator.
  @returns true if NodeArg is a normal tensor with a non-empty shape or a scalar with an empty shape. Otherwise, returns false. */
  bool HasTensorOrScalarShape() const;

#if !defined(ORT_MINIMAL_BUILD)
  /** Sets the shape.
  @param shape The shape to store.
  @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called,
  as the shape information is stored as part of TypeProto. */
  void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape);

  /** Clears shape info.
  @remarks If there is a mismatch during shape inferencing that can't be resolved the shape info may be removed. */
  void ClearShape();

  /** Validate and merge type [and shape] info from input_type.
  @param input_type Type (and optional shape) information to merge into this NodeArg.
  @param strict If true, the shape update will fail if there are incompatible values.
                If false, will be lenient and merge only shape info that can be validly processed.
  @param override_types If true, resolve the type of the two inputs (or two outputs) when they differ.
                        NOTE(review): the exact resolution rule lives in the implementation; confirm there.
  @param logger Logger to use for any diagnostic output produced while merging.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type, bool strict, bool override_types, const logging::Logger& logger);

  /** Validate and merge type [and shape] info from node_arg.
  @param node_arg NodeArg whose type (and optional shape) information is merged into this NodeArg.
  @param strict If true, the shape update will fail if there are incompatible values.
                If false, will be lenient and merge only shape info that can be validly processed.
  @param override_types If true, resolve the type of the two inputs (or two outputs) when they differ.
                        NOTE(review): the exact resolution rule lives in the implementation; confirm there.
  @param logger Logger to use for any diagnostic output produced while merging.
  @returns Success unless there is existing type or shape info that can't be successfully updated. */
  common::Status UpdateTypeAndShape(const NodeArg& node_arg, bool strict, bool override_types, const logging::Logger& logger);
#endif  // !defined(ORT_MINIMAL_BUILD)

  /** Gets this NodeArg as a ValueInfoProto. */
  const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

  /** Gets a flag indicating whether this NodeArg exists or not.
  Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
  bool Exists() const noexcept;

 private:
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg);
  friend class Graph;

  // Construct from an existing NodeArgInfo. Private: reachable only from NodeArg itself and Graph (friend).
  NodeArg(NodeArgInfo&& node_arg_info);

#if !defined(ORT_MINIMAL_BUILD)
  // Replace the type, either from a DataType or from a full TypeProto. Private so that only
  // Graph (friend) can change the type; SetShape relies on a type having been set first.
  void SetType(ONNX_NAMESPACE::DataType p_type);
  void SetType(const ONNX_NAMESPACE::TypeProto& type_proto);
#endif

  // Node arg PType. Default-initialized so it is never indeterminate, even on a constructor
  // path that does not assign it explicitly.
  ONNX_NAMESPACE::DataType type_{nullptr};

  // Node arg name, type and shape.
  NodeArgInfo node_arg_info_;

  // Flag indicates whether <*this> node arg exists or not (see Exists()).
  // Default-initialized for the same reason as type_.
  bool exists_{false};
};
} // namespace onnxruntime