PluginShapeInference.hpp
1.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
//
// PluginShapeInference.hpp
// MNN
//
// Created by MNN on 2020/04/05.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_
#define MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_
#include <functional>
#include <string>
#include <unordered_map>
#include <MNN/plugin/PluginContext.hpp>
namespace MNN {
namespace plugin {
// Abstract base class for plugin shape-inference kernels. It declares a single
// pure virtual operation, compute(), which concrete kernels must implement.
class MNN_PUBLIC InferShapeKernel {
public:
// Virtual destructor so a derived kernel can be destroyed through a pointer
// to this base class.
virtual ~InferShapeKernel() = default;
// Pure virtual entry point of the kernel; receives an InferShapeContext pointer.
// NOTE(review): presumably returns true on success and false on failure, and
// presumably expects a non-null ctx — confirm against implementations/callers.
virtual bool compute(InferShapeContext* ctx) = 0;
};
// Static-only interface of the name -> kernel-factory registry. Only the
// declarations live in this header; the definitions are provided elsewhere.
class MNN_PUBLIC InferShapeKernelRegister {
public:
    // A factory is any callable that takes no arguments and yields an
    // InferShapeKernel pointer. std::function is used rather than the plain
    // function-pointer form `InferShapeKernel* (*)()`.
    using Factory = std::function<InferShapeKernel*()>;

    // Gives access to the map from kernel name to its factory.
    static std::unordered_map<std::string, Factory>* getFactoryMap();

    // Associates `factory` with `name`.
    // NOTE(review): the meaning of the bool result (e.g. behaviour on a
    // duplicate name) is not visible here — confirm in the implementation.
    static bool add(const std::string& name, Factory factory);

    // Looks up a kernel by `name`.
    // NOTE(review): ownership of the returned pointer and the result for an
    // unknown name are not visible here — confirm in the implementation.
    static InferShapeKernel* get(const std::string& name);
};
template <typename PluginKernel>
struct InferShapeKernelRegistrar {
InferShapeKernelRegistrar(const std::string& name) {
InferShapeKernelRegister::add(name, []() { // NOLINT
return new PluginKernel; // NOLINT
});
}
};
// REGISTER_PLUGIN_OP(name, inferShapeKernel)
//
// Defines, inside an unnamed namespace, a variable named
// _plugin_infer_shape_<name>_ that is initialized with
// InferShapeKernelRegistrar<inferShapeKernel>(#name). Constructing that
// registrar calls InferShapeKernelRegister::add with the stringified `name`.
//
// - `name` is both token-pasted into the variable identifier and stringified,
//   so it must be a bare identifier (no quotes, no punctuation).
// - `inferShapeKernel` is the kernel type passed as the registrar's template
//   argument.
// - The expansion opens a namespace, so the macro can only be used at
//   namespace scope, not inside a function or class.
// - Because the trailing "} // namespace" closes the expansion, an invocation
//   does not need a terminating semicolon.
//
// NOTE(review): __attribute__((unused)) is GCC/Clang attribute syntax —
// confirm that every supported toolchain accepts it.
// NOTE(review): no comments are placed inside the macro body on purpose: a
// `//` comment on a backslash-continued line would swallow the following line.
#define REGISTER_PLUGIN_OP(name, inferShapeKernel) \
namespace { \
static auto _plugin_infer_shape_##name##_ __attribute__((unused)) = \
InferShapeKernelRegistrar<inferShapeKernel>(#name); \
} // namespace
} // namespace plugin
} // namespace MNN
#endif // MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_