graph_nodes.h
6.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
namespace onnxruntime {
class Node;
/**
Class to filter out null entries from either a vector of unique_ptr<Node> or a vector of [const] Node* and
provide an iterator interface that returns [const] Node& for the valid entries.

An optional filter function can be supplied to skip additional nodes: a non-null node is skipped when the
filter returns true for its index.

@remarks The container is held by pointer and iterators hold a pointer to this instance's filter function,
so iterators must not outlive the ValidNodes instance (or the container) they were obtained from.
*/
template <typename TNodesContainer>
class ValidNodes {
 public:
  template <typename TIterator>
  class NodeIterator;

  // Optional filtering function. A node for which this returns true (given the node's index) is skipped.
  using NodeFilterFunc = std::function<bool(NodeIndex)>;

  /**
  Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesContainer.
  @param[in] nodes Nodes to iterate, skipping null entries. Must outlive this instance.
  */
  explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(&nodes) {}

  /**
  Construct a ValidNodes instance that additionally skips nodes rejected by a filter.
  @param[in] nodes Nodes to iterate, skipping null entries. Must outlive this instance.
  @param[in] filter_node_fn Called with each non-null node's index; returning true skips that node.
                            Owned by this instance.
  */
  explicit ValidNodes(TNodesContainer& nodes, NodeFilterFunc&& filter_node_fn) noexcept
      : nodes_(&nodes), filter_node_fn_{std::move(filter_node_fn)} {}

  using ConstNodeIterator = NodeIterator<typename TNodesContainer::const_iterator>;
  using MutableNodeIterator = NodeIterator<typename TNodesContainer::iterator>;
  using ConstReverseNodeIterator = NodeIterator<typename TNodesContainer::const_reverse_iterator>;

  ConstNodeIterator cbegin() const noexcept {
    return {nodes_->cbegin(), nodes_->cend(), filter_node_fn_};
  }

  ConstNodeIterator cend() const noexcept {
    return {nodes_->cend(), nodes_->cend(), filter_node_fn_};
  }

  ConstNodeIterator begin() const noexcept {
    return cbegin();
  }

  ConstNodeIterator end() const noexcept {
    return cend();
  }

  ConstReverseNodeIterator rbegin() const noexcept {
    return {nodes_->crbegin(), nodes_->crend(), filter_node_fn_};
  }

  ConstReverseNodeIterator rend() const noexcept {
    return {nodes_->crend(), nodes_->crend(), filter_node_fn_};
  }

  // We only allow mutable access if the container is non-const.
  // The functions must be templates for enable_if to work at this level, but T2 is required to be TNodesContainer.
  template <typename T2 = TNodesContainer>
  typename std::enable_if<!std::is_const<T2>::value, MutableNodeIterator>::type begin() noexcept {
    static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
    return MutableNodeIterator(nodes_->begin(), nodes_->end(), filter_node_fn_);
  }

  template <typename T2 = TNodesContainer>
  typename std::enable_if<!std::is_const<T2>::value, MutableNodeIterator>::type end() noexcept {
    static_assert(std::is_same<T2, TNodesContainer>::value, "Explicit specialization is not allowed");
    return MutableNodeIterator(nodes_->end(), nodes_->end(), filter_node_fn_);
  }

  /**
  True if iteration would yield no nodes, i.e. every entry is null or rejected by the filter.
  Checking the raw container size would report "not empty" for a container holding only null entries.
  */
  bool empty() const noexcept { return cbegin() == cend(); }

  /**
  @class NodeIterator
  Iterator to provide const and non-const access to valid Node instances in a Graph.
  @remarks Skips null entries and entries rejected by the filter function.
  */
  template <typename TIterator>
  class NodeIterator {
    // Element type produced by the underlying iterator (e.g. [const] unique_ptr<Node> or [const] Node*).
    // Can't use TIterator::value_type as that is always non-const.
    using IterType = typename std::remove_reference<typename std::iterator_traits<TIterator>::reference>::type;
    // Type obtained by dereferencing that element; this is const for a container of const Node*.
    using PointeeType = typename std::remove_reference<decltype(**std::declval<TIterator>())>::type;
    // Return const Node if either the element or what it points to is const, else Node.
    using T = typename std::conditional<std::is_const<IterType>::value || std::is_const<PointeeType>::value,
                                        const Node,
                                        Node>::type;

   public:
    using iterator_category = std::input_iterator_tag;
    using value_type = T;
    using difference_type = typename std::iterator_traits<TIterator>::difference_type;
    using pointer = T*;
    using reference = T&;
    using const_reference = const T&;

    /**
    Construct a NodeIterator and move to the first valid node at or after current,
    stopping at end if none are found.
    @param[in] filter_fn Filter owned by the parent ValidNodes; stored by pointer so the iterator stays copyable.
    */
    NodeIterator(const TIterator current, const TIterator end, const NodeFilterFunc& filter_fn) noexcept
        : current_{current}, end_{end}, apply_filter_{filter_fn != nullptr}, filter_func_{&filter_fn} {
      SkipInvalid();
    }

    bool operator==(const NodeIterator& other) const noexcept {
      return current_ == other.current_;
    }

    bool operator!=(const NodeIterator& other) const noexcept {
      return current_ != other.current_;
    }

    /** Advance to the next valid node. Incrementing an iterator that is already at end is a no-op. */
    NodeIterator& operator++() {
      if (current_ != end_) {
        ++current_;
        SkipInvalid();
      }
      return *this;
    }

    NodeIterator operator++(int) {
      NodeIterator tmp{*this};
      ++(*this);
      return tmp;
    }

    /** Return the current Node&. This is const if the iterator came from a const container/const iterator. */
    reference operator*() const {
      // A non-end iterator always refers to a non-null entry; calling this at end_ is a caller error.
      return **current_;
    }

    // addressof(**it) works for both unique_ptr<Node> and raw [const] Node* elements (unlike ->get()).
    pointer operator->() const {
      return std::addressof(**current_);
    }

   private:
    // True if the current (non-end) entry is non-null and not rejected by the filter.
    bool IsValid() const {
      return *current_ != nullptr && (!apply_filter_ || !(*filter_func_)((*current_)->Index()));
    }

    // Move forward until a valid entry or end_ is reached.
    void SkipInvalid() {
      while (current_ != end_ && !IsValid()) {
        ++current_;
      }
    }

    TIterator current_;
    TIterator end_;
    bool apply_filter_;                  // whether filter_func_ holds a callable
    const NodeFilterFunc* filter_func_;  // stored as pointer so the iterator is copyable
  };

 private:
  gsl::not_null<TNodesContainer*> nodes_;  // always set by ctor
  // No filtering if not set. This instance owns the filter func if set.
  NodeFilterFunc filter_node_fn_;
};
/**
Iterable view over every non-null Node in a Graph's node vector.
Because the wrapped vector is non-const, a non-const GraphNodes also exposes mutable (Node&) iteration.
*/
class GraphNodes : public ValidNodes<std::vector<std::unique_ptr<Node>>> {
 public:
  // Left non-explicit to preserve the original interface (implicit conversion from the node vector).
  GraphNodes(std::vector<std::unique_ptr<Node>>& nodes)
      : ValidNodes<std::vector<std::unique_ptr<Node>>>{nodes} {}
};
// Read-only counterpart of GraphNodes: iteration always yields const Node&, and an optional
// filter can hide further nodes (a node is skipped when the filter returns true for its index).
class ConstGraphNodes : public ValidNodes<const std::vector<std::unique_ptr<Node>>> {
 public:
  // View every non-null node in the vector.
  ConstGraphNodes(const std::vector<std::unique_ptr<Node>>& nodes)
      : ValidNodes<const std::vector<std::unique_ptr<Node>>>{nodes} {}

  // View every non-null node that the filter does not reject; the filter is moved into this instance.
  ConstGraphNodes(const std::vector<std::unique_ptr<Node>>& nodes, NodeFilterFunc&& filter_func)
      : ValidNodes<const std::vector<std::unique_ptr<Node>>>{nodes, std::move(filter_func)} {}
};
} // namespace onnxruntime