pytorch/caffe2/operators/while_op.h
Ilia Cherniavskii a0204331a8 Control flow operators
Summary:
This diff adds control flow operators to Caffe2 (starting with If and While):
 - Added an If operator that executes a then/else subnet
 - The branch subnet is executed in a separate, isolated workspace, with some of the blobs transparently forwarded from the outer workspace
 - Added a new NetBuilder subclass to construct nets using the new operators
 - NetBuilder also keeps track of outer blob names and automatically sets up blob bindings between the outer and inner workspaces, implementing a generic convention for handling local/global variables in blocks

Reviewed By: volkhin

Differential Revision: D5720644

fbshipit-source-id: a674cde0c789f6a6ffdcd9d80159d1e42e49133f
2017-08-28 20:04:43 -07:00

47 lines
1.3 KiB
C++

#ifndef CAFFE2_OPERATORS_WHILE_OP_H_
#define CAFFE2_OPERATORS_WHILE_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// Operator that drives a `while` loop over Caffe2 subnets.
//
// Arguments:
//   loop_net (NetDef, required): the loop body subnet.
//   cond_net (NetDef, optional): the condition subnet.
//
// Both subnets are instantiated exactly once, in the constructor, via
// Workspace::CreateNet on the workspace the operator is constructed with.
// The iteration logic itself lives in RunOnDevice(), which is defined outside
// this header.
template <class Context>
class WhileOp final : public Operator<Context> {
 public:
  WhileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("loop_net"),
        "loop_net must be specified in While operator");
    loop_net_def_ =
        this->template GetSingleArgument<NetDef>("loop_net", NetDef());
    // overwrite=true: any net already registered in `ws` under the same name
    // is replaced.
    loop_net_ = ws->CreateNet(loop_net_def_, true);
    CAFFE_ENFORCE(loop_net_, "Failed to initialize loop subnet");

    if (this->template HasSingleArgumentOfType<NetDef>("cond_net")) {
      cond_net_def_ =
          this->template GetSingleArgument<NetDef>("cond_net", NetDef());
      // Both subnets are registered by name in the same workspace with
      // overwrite enabled. A condition net sharing the loop net's name would
      // therefore replace (and destroy) the loop net, leaving loop_net_
      // dangling. Reject that configuration before creating the cond net.
      CAFFE_ENFORCE(
          cond_net_def_.name() != loop_net_def_.name(),
          "cond_net and loop_net must have different names in While operator");
      cond_net_ = ws->CreateNet(cond_net_def_, true);
      CAFFE_ENFORCE(cond_net_, "Failed to initialize condition subnet");
    }
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Runs the loop; defined out of line.
  bool RunOnDevice() override;

 private:
  // Definition taken from the "loop_net" argument.
  NetDef loop_net_def_;
  // Loop body net returned by Workspace::CreateNet. Non-owning: this class
  // never deletes it (presumably the workspace owns it -- confirm).
  NetBase* loop_net_{nullptr};
  // Definition taken from the "cond_net" argument; left default-constructed
  // when that argument is absent.
  NetDef cond_net_def_;
  // Condition net, or nullptr when no "cond_net" argument was given.
  // Non-owning, like loop_net_.
  NetBase* cond_net_{nullptr};
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_WHILE_OP_H_