mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
refactor caffe2 operator constructors - 10/9 (#17659)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17659 clangr codemod Reviewed By: ezyang Differential Revision: D14304675 fbshipit-source-id: 45fbd84c50651a70ae29bf46df3322715e99d225
This commit is contained in:
parent
4db3f8f806
commit
f6fda4409b
6 changed files with 24 additions and 16 deletions
|
|
@ -13,8 +13,9 @@ template <class Context, class Engine = DefaultEngine>
|
|||
class BatchMatMulOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BatchMatMulOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit BatchMatMulOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
|
||||
trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
|
||||
broadcast_(this->template GetSingleArgument<int>("broadcast", 0)) {}
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ class BatchMomentsOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
BatchMomentsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit BatchMomentsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<std::string>("order", "NCHW"))) {
|
||||
CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
|
||||
|
|
@ -61,8 +62,9 @@ class BatchMomentsGradientOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
BatchMomentsGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit BatchMomentsGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<std::string>("order", "NCHW"))) {
|
||||
CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ template <typename T, class Context>
|
|||
class BatchSparseToDenseOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BatchSparseToDenseOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit BatchSparseToDenseOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int64_t, "dense_last_dim", dense_last_dim_, -1),
|
||||
OP_SINGLE_ARG(T, "default_value", default_value_, static_cast<T>(0)) {}
|
||||
bool RunOnDevice() override;
|
||||
|
|
@ -29,8 +30,9 @@ template <typename T, class Context>
|
|||
class BatchDenseToSparseOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BatchDenseToSparseOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit BatchDenseToSparseOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
bool RunOnDevice() override;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ template <class Context>
|
|||
class BisectPercentileOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BisectPercentileOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit BisectPercentileOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
pct_raw_(OperatorBase::GetRepeatedArgument<float>(
|
||||
"percentile_raw",
|
||||
vector<float>{})),
|
||||
|
|
|
|||
|
|
@ -9,8 +9,9 @@ template <class Context>
|
|||
class BooleanMaskLengthsOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BooleanMaskLengthsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit BooleanMaskLengthsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(0));
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ template <class Context>
|
|||
class BooleanMaskOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit BooleanMaskOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue