Module.hpp
4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
//
// Module.hpp
// MNN
//
// Created by MNN on 2019/11/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_Train_Module_hpp
#define MNN_Train_Module_hpp
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/MNNForwardType.h>
namespace MNN {
namespace Express {
struct SubGraph;
class MNN_PUBLIC Module {
public:
Module() = default;
virtual ~Module() = default;
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) = 0;
Express::VARP forward(Express::VARP input);
std::vector<Express::VARP> parameters() const;
bool loadParameters(const std::vector<Express::VARP>& parameters);
void setIsTraining(const bool isTraining);
bool getIsTraining();
void clearCache();
const std::string& name() const {
return mName;
};
void setName(std::string name) {
mName = std::move(name);
}
const std::string type() const {
return mType;
}
void setType(std::string type) {
mType = std::move(type);
}
// Return the parameter index
int addParameter(Express::VARP parameter);
void setParameter(Express::VARP parameter, int index);
static Module* createEmpty(const std::vector<Express::VARP>& parameters);
struct BackendInfo {
MNNForwardType type = MNN_FORWARD_CPU;
BackendConfig* config = nullptr;
};
struct Config {
// Load module as dynamic, default static
bool dynamic = false;
// for static mode, if the shape is mutable, set true, otherwise set false to avoid resizeSession freqencily
bool shapeMutable = true;
// Pre-rearrange weights or not. Disabled by default.
// The weights will be rearranged in a general way, so the best implementation
// may not be adopted if `rearrange` is enabled.
bool rearrange = false;
BackendInfo* backend = nullptr;
};
static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr);
static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Config* config = nullptr);
// Shared RuntimeManager
static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Config* config = nullptr);
static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Config* config = nullptr);
static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {});
static Module* clone(const Module* module, const bool shareParams = false);
class CloneContext {
public:
CloneContext() = default;
explicit CloneContext(const bool shareParams)
: mShareParams(shareParams) {}
virtual ~CloneContext() = default;
const bool shareParams() const { return mShareParams; }
EXPRP getOrClone(const EXPRP expr);
VARP getOrClone(const VARP var);
private:
bool mShareParams = false;
std::unordered_map<const Expr*, EXPRP> mExprMap;
std::unordered_map<const Variable*, VARP> mVarMap;
};
virtual Module* clone(CloneContext* ctx) const {
return nullptr;
}
protected:
void registerModel(const std::vector<std::shared_ptr<Module>>& children);
virtual void onClearCache() {
}
Module* cloneBaseTo(CloneContext* ctx, Module* module) const;
private:
void _collectParameters(std::vector<Express::VARP>& result) const;
std::vector<std::shared_ptr<Module>> mChildren;
std::vector<Express::VARP> mParameters;
bool mIsTraining = true;
std::string mName;
std::string mType;
};
struct SubGraph {
std::vector<std::string> inputs;
std::vector<std::string> outputs;
std::shared_ptr<Module> m;
};
} // namespace Train
} // namespace MNN
#endif