refactor caffe2 operator constructors - 10/9 (#17659)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17659

clangr codemod

Reviewed By: ezyang

Differential Revision: D14304675

fbshipit-source-id: 45fbd84c50651a70ae29bf46df3322715e99d225
This commit is contained in:
Sebastian Messmer 2019-03-06 15:08:44 -08:00 committed by Facebook Github Bot
parent 4db3f8f806
commit f6fda4409b
6 changed files with 24 additions and 16 deletions

View file

@ -13,8 +13,9 @@ template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BatchMatMulOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit BatchMatMulOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
broadcast_(this->template GetSingleArgument<int>("broadcast", 0)) {}

View file

@ -12,8 +12,9 @@ class BatchMomentsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BatchMomentsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit BatchMomentsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
order_(StringToStorageOrder(
this->template GetSingleArgument<std::string>("order", "NCHW"))) {
CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
@ -61,8 +62,9 @@ class BatchMomentsGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BatchMomentsGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit BatchMomentsGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
order_(StringToStorageOrder(
this->template GetSingleArgument<std::string>("order", "NCHW"))) {
CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);

View file

@ -13,8 +13,9 @@ template <typename T, class Context>
class BatchSparseToDenseOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BatchSparseToDenseOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit BatchSparseToDenseOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int64_t, "dense_last_dim", dense_last_dim_, -1),
OP_SINGLE_ARG(T, "default_value", default_value_, static_cast<T>(0)) {}
bool RunOnDevice() override;
@ -29,8 +30,9 @@ template <typename T, class Context>
class BatchDenseToSparseOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BatchDenseToSparseOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit BatchDenseToSparseOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
private:

View file

@ -13,8 +13,9 @@ template <class Context>
class BisectPercentileOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BisectPercentileOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit BisectPercentileOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
pct_raw_(OperatorBase::GetRepeatedArgument<float>(
"percentile_raw",
vector<float>{})),

View file

@ -9,8 +9,9 @@ template <class Context>
class BooleanMaskLengthsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BooleanMaskLengthsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit BooleanMaskLengthsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(0));

View file

@ -12,8 +12,9 @@ template <class Context>
class BooleanMaskOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit BooleanMaskOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
};