mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: This diff adds control flow operators to Caffe2 (starting with If and While): - Added an If operator that executes a then/else subnet. - The branch subnet is executed in a separate, isolated workspace, with some of the blobs transparently forwarded from the outer workspace. - Added a new NetBuilder subclass to construct nets using the new operators. - NetBuilder also keeps track of outer blob names and automatically sets blob bindings between the outer and inner workspaces, implementing a generic convention for handling local/global variables in blocks. Reviewed By: volkhin Differential Revision: D5720644 fbshipit-source-id: a674cde0c789f6a6ffdcd9d80159d1e42e49133f
47 lines
1.3 KiB
C++
47 lines
1.3 KiB
C++
#ifndef CAFFE2_OPERATORS_WHILE_OP_H_
|
|
#define CAFFE2_OPERATORS_WHILE_OP_H_
|
|
|
|
#include "caffe2/core/context.h"
|
|
#include "caffe2/core/logging.h"
|
|
#include "caffe2/core/operator.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
template <class Context>
|
|
class WhileOp final : public Operator<Context> {
|
|
public:
|
|
WhileOp(const OperatorDef& operator_def, Workspace* ws)
|
|
: Operator<Context>(operator_def, ws) {
|
|
CAFFE_ENFORCE(
|
|
this->template HasSingleArgumentOfType<NetDef>("loop_net"),
|
|
"loop_net must be specified in While operator");
|
|
loop_net_def_ =
|
|
this->template GetSingleArgument<NetDef>("loop_net", NetDef());
|
|
loop_net_ = ws->CreateNet(loop_net_def_, true);
|
|
CAFFE_ENFORCE(loop_net_, "Failed to initialize loop subnet");
|
|
|
|
cond_net_ = nullptr;
|
|
bool has_cond_net =
|
|
this->template HasSingleArgumentOfType<NetDef>("cond_net");
|
|
if (has_cond_net) {
|
|
cond_net_def_ =
|
|
this->template GetSingleArgument<NetDef>("cond_net", NetDef());
|
|
cond_net_ = ws->CreateNet(cond_net_def_, true);
|
|
CAFFE_ENFORCE(cond_net_, "Failed to initialize condition subnet");
|
|
}
|
|
}
|
|
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
bool RunOnDevice() override;
|
|
|
|
private:
|
|
NetDef loop_net_def_;
|
|
NetBase* loop_net_;
|
|
|
|
NetDef cond_net_def_;
|
|
NetBase* cond_net_;
|
|
};
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_OPERATORS_WHILE_OP_H_
|