schema_registry.h
6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/constants.h"
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/graph/onnx_protobuf.h"
#include "core/platform/ort_mutex.h"
#include <mutex>
#include <deque>
#include "sstream"
namespace onnxruntime {
// Nested lookup table of operator schemas. Going by the alias name, the levels are:
//   operator name -> domain -> operator-set version -> OpSchema.
// The two outer levels are hash maps keyed by std::string; the innermost level is an ordered
// std::map keyed by ONNX_NAMESPACE::OperatorSetVersion.
using OpName_Domain_Version_Schema_Map = std::unordered_map<
std::string,
std::unordered_map<std::string, std::map<ONNX_NAMESPACE::OperatorSetVersion, ONNX_NAMESPACE::OpSchema>>>;
/**
@struct SchemaRegistryVersion
An onnxruntime schema registry is a supplement to the built-in ONNX schema.
Every schema registry represents a collection of schema deltas from baseline_opset_version to opset_version.

Both fields default to 0 so that a default-constructed instance never carries indeterminate values.
Aggregate initialization (e.g. SchemaRegistryVersion{7, 9}) keeps working as before.
*/
struct SchemaRegistryVersion {
  /// Opset version the registry's deltas start from.
  int baseline_opset_version = 0;
  /// Opset version the registry's deltas go up to.
  int opset_version = 0;
};
// Maps a domain name to a single opset version number.
using DomainToVersionMap = std::unordered_map<std::string, int>;
// Maps a domain name to the {baseline_opset_version, opset_version} pair held in SchemaRegistryVersion.
using DomainToVersionRangeMap = std::unordered_map<std::string, SchemaRegistryVersion>;
/**
@class IOnnxRuntimeOpSchemaCollection
Abstract schema collection layered on top of ONNX_NAMESPACE::ISchemaRegistry.
Implementations supply GetLatestOpsetVersions() and GetSchemaAndHistory(); the three-argument
GetSchema() lookup is implemented here in terms of GetSchemaAndHistory() and cannot be overridden.
*/
class IOnnxRuntimeOpSchemaCollection : public ONNX_NAMESPACE::ISchemaRegistry {
 public:
  /**
  Gets a domain -> opset version map for this collection.
  @param is_onnx_only Selects between ONNX-only and all-domain results; the exact filtering is up to
                      the implementation.
  */
  virtual DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const = 0;

  // Keep the remaining inherited GetSchema overloads visible alongside the one declared below.
  using ISchemaRegistry::GetSchema;

  /**
  Looks up the schema for an operator, limited to a maximum opset version.
  @param key Operator type.
  @param maxInclusiveVersion Maximum opset version allowed, inclusive.
  @param domain The domain of the operator.
  @return The schema reported by GetSchemaAndHistory(), or nullptr when it reports none.
  */
  const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& key, const int maxInclusiveVersion,
                                            const std::string& domain) const final {
    // Seed both out-values so an implementation that finds nothing can leave them untouched.
    int unchanged_since = std::numeric_limits<int>::max();
    const ONNX_NAMESPACE::OpSchema* found = nullptr;

    GetSchemaAndHistory(key, maxInclusiveVersion, domain, &found, &unchanged_since);

    // Debug-only sanity checks on whatever the implementation handed back.
    if (found != nullptr) {
      assert(found->SinceVersion() <= maxInclusiveVersion);
      assert(unchanged_since == found->SinceVersion());
    }

    return found;
  }

  /**
  Finds the schema for an operator together with its version history.
  @param key Operator type.
  @param maxInclusiveVersion Maximum opset version allowed, inclusive.
  @param domain The domain of the operator.
  @param[out] latest_schema Receives the schema that was found, if any.
  @param[out] earliest_opset_where_unchanged Receives the earliest opset version at which the
              operator is known to be unchanged.
  */
  virtual void GetSchemaAndHistory(
      const std::string& key,
      int maxInclusiveVersion,
      const std::string& domain,
      const ONNX_NAMESPACE::OpSchema** latest_schema,
      int* earliest_opset_where_unchanged) const = 0;
};
/**
@class OnnxRuntimeOpSchemaRegistry
OnnxRuntimeOpSchemaRegistry is used to supplement the built-in ONNX schemas.
Each OnnxRuntimeOpSchemaRegistry must register the complete opset delta from a baseline version to the max
opset version. (Please note that the baseline opsets are not included in the delta.)
For example, if ONNXRuntime is built with ONNX 1.2, which is at opset7, then to use ONNX opset8 and opset9
a user could create an OnnxRuntimeOpSchemaRegistry and configure it as
{baseline_opset_version = 7, opset_version = 9},
which means this OnnxRuntimeOpSchemaRegistry contains the complete delta from opset7 to opset9.
*/
class OnnxRuntimeOpSchemaRegistry : public IOnnxRuntimeOpSchemaCollection {
public:
OnnxRuntimeOpSchemaRegistry() = default;
// Sets the baseline and max opset version for `domain`.
// NOTE(review): the implementation is not in this header; presumably the pair is stored in
// domain_version_range_map_ and a non-OK Status signals a conflict - confirm in the .cc file.
common::Status SetBaselineAndOpsetVersionForDomain(
const std::string& domain,
int baseline_opset_version,
int opset_version);
// Returns a domain -> opset version map for this registry (see IOnnxRuntimeOpSchemaCollection).
DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const override;
// Registers a set of schemas for `domain` covering baseline_opset_version..opset_version.
// An OnnxRuntimeOpSchemaRegistry must register the complete delta for an opset.
// `schemas` is taken by non-const reference.
// NOTE(review): presumably the schemas are moved out of the vector - confirm in the implementation.
common::Status RegisterOpSet(
std::vector<ONNX_NAMESPACE::OpSchema>& schemas,
const std::string& domain,
int baseline_opset_version,
int opset_version);
// Make the inherited GetSchema overloads available through this class.
using IOnnxRuntimeOpSchemaCollection::GetSchema;
// Looks up the schema for operator `key` in `domain`, limited to maxInclusiveVersion, and reports it
// together with its history through the two out-parameters.
void GetSchemaAndHistory(const std::string& key, int maxInclusiveVersion, const std::string& domain,
const ONNX_NAMESPACE::OpSchema** latest_schema,
int* earliest_opset_where_unchanged) const override;
// Returns true when the schema map holds no entries.
bool empty() const {
return map_.empty();
}
private:
// Registration helpers that take ownership of a single schema (rvalue reference).
// NOTE(review): the split between the two is not visible here; presumably RegisterOpSchema wraps
// RegisterOpSchemaInternal with validation and/or locking - confirm in the .cc file.
common::Status RegisterOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema);
common::Status RegisterOpSchemaInternal(ONNX_NAMESPACE::OpSchema&& op_schema);
// NOTE(review): presumably guards mutation of the maps below during registration - confirm which
// members it protects.
OrtMutex mutex_;
// Registered schemas, nested as described on OpName_Domain_Version_Schema_Map.
OpName_Domain_Version_Schema_Map map_;
// Per-domain {baseline_opset_version, opset_version} range.
DomainToVersionRangeMap domain_version_range_map_;
};
/**
@class SchemaRegistryManager
SchemaRegistryManager provides a view based on the built-in ONNX schema and a list of
OnnxRuntimeOpSchemaRegistry instances as supplements.
The user needs to make sure each customized schema registry is valid, otherwise the behavior is undefined.
@todo We may add more consistency checks later.
*/
class SchemaRegistryManager : public onnxruntime::IOnnxRuntimeOpSchemaCollection {
public:
/**
Register a new schema registry instance.
@param registry The registry to add; shared ownership is taken via std::shared_ptr.
@remarks The schema registry priority is the reverse of registration order, i.e. the last registry added will be
searched first for a matching OpSchema.
*/
void RegisterRegistry(std::shared_ptr<IOnnxRuntimeOpSchemaCollection> registry);
/** Gets the latest opset versions.
@param is_onnx_only If true, return the latest ONNX schemas. If false, return the latest schemas for all domains.
*/
DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const override;
/** Gets the last released opset versions.
@param is_onnx_only If true, return ONNX schemas only. If false, return the schemas for all domains.
*/
DomainToVersionMap GetLastReleasedOpsetVersions(bool is_onnx_only) const ;
/**
Gets the OpSchema and its history.
Searches custom schema registries starting with the last one added.
If the OpSchema is not found, the default ONNX schema registry is searched.
@param key Operator type.
@param max_inclusive_version Maximum opset version allowed, inclusive.
@param domain The domain of the operator.
@param[out] latest_schema Returns the latest OpSchema if found. nullptr otherwise.
@param[out] earliest_opset_where_unchanged The earliest opset version preceding max_inclusive_version where the
operator is known to be unchanged.
*/
void GetSchemaAndHistory(const std::string& key, int max_inclusive_version, const std::string& domain,
const ONNX_NAMESPACE::OpSchema** latest_schema,
int* earliest_opset_where_unchanged) const override;
private:
// Fills `domain_version_map` (an in/out parameter) from the registered registries.
// NOTE(review): the merge policy and how is_onnx_only filters domains are not visible in this
// header - confirm in the implementation.
void GetDomainToVersionMapForRegistries(DomainToVersionMap& domain_version_map, bool is_onnx_only) const;
// Registries added through RegisterRegistry().
// NOTE(review): presumably new entries are pushed to the front so that iteration order matches the
// documented "last added is searched first" priority - confirm in the implementation.
std::deque<std::shared_ptr<IOnnxRuntimeOpSchemaCollection>> registries;
};
} // namespace onnxruntime