mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
refactor caffe2 operator constructors - 7/9 (#17088)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17088 clangr codemod also manually moved the constructor of a class from the .cpp file to the .h file. Reviewed By: ezyang Differential Revision: D14078531 fbshipit-source-id: 2adb4ac0ce523742da6cce3bc3b6c177b816c299
This commit is contained in:
parent
42512242cc
commit
8db403b9dc
25 changed files with 136 additions and 111 deletions
|
|
@ -14,9 +14,8 @@ namespace int8 {
|
|||
|
||||
class Int8SoftmaxOp final : public Operator<CPUContext> {
|
||||
public:
|
||||
Int8SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws),
|
||||
ws_(ws) {}
|
||||
explicit Int8SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CPUContext>(operator_def, ws), ws_(ws) {}
|
||||
|
||||
~Int8SoftmaxOp() {
|
||||
if (this->qnnpackOperator_ != nullptr) {
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context, bool FIRSTDIMS>
|
||||
class MaxReduceDimsOp final : public Operator<Context> {
|
||||
public:
|
||||
MaxReduceDimsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit MaxReduceDimsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
num_reduce_dims_(
|
||||
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
|
||||
|
||||
|
|
@ -78,8 +79,9 @@ class MaxReduceDimsOp final : public Operator<Context> {
|
|||
template <typename T, class Context, bool FIRSTDIMS>
|
||||
class MaxReduceDimsGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
MaxReduceDimsGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit MaxReduceDimsGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
num_reduce_dims_(
|
||||
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ namespace caffe2 {
|
|||
template <class Context, bool FIRSTDIMS, bool NORMALIZE>
|
||||
class SumReduceDimsOp final : public Operator<Context> {
|
||||
public:
|
||||
SumReduceDimsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SumReduceDimsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
num_reduce_dims_(
|
||||
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
|
||||
|
||||
|
|
@ -86,8 +87,9 @@ class SumReduceDimsOp final : public Operator<Context> {
|
|||
template <class Context, bool FIRSTDIMS, bool NORMALIZE>
|
||||
class SumReduceDimsGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
SumReduceDimsGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SumReduceDimsGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
num_reduce_dims_(
|
||||
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,8 +17,9 @@ class ReduceOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
ReduceOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ReduceOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
axes_(this->template GetRepeatedArgument<int>("axes")),
|
||||
OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
|
||||
|
||||
|
|
@ -84,8 +85,9 @@ class ReduceGradientOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
ReduceGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ReduceGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
axes_(this->template GetRepeatedArgument<int>("axes")) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -14,10 +14,11 @@ class SumElementsOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SumElementsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
average_(this->template GetSingleArgument<bool>("average", false)) {}
|
||||
SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
|
||||
explicit SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
|
||||
: Operator<Context>(operator_def, ws), average_(average) {}
|
||||
~SumElementsOp() {}
|
||||
|
||||
|
|
@ -51,8 +52,9 @@ class SumElementsIntOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
SumElementsIntOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit SumElementsIntOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
~SumElementsIntOp() {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
@ -74,13 +76,11 @@ class SumElementsGradientOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit SumElementsGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
average_(this->template GetSingleArgument<bool>("average", false)) {}
|
||||
SumElementsGradientOp(
|
||||
const OperatorDef& operator_def,
|
||||
Workspace* ws,
|
||||
bool average)
|
||||
explicit SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws, bool average)
|
||||
: Operator<Context>(operator_def, ws), average_(average) {}
|
||||
~SumElementsGradientOp() {}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ template <class Context>
|
|||
class ReplaceNaNOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ReplaceNaNOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit ReplaceNaNOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ template <typename F, class Context>
|
|||
class ReshapeOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ReshapeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ReshapeOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
new_shape_(this->template GetRepeatedArgument<int64_t>("shape")) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -9,8 +9,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class ResizeNearestOp final : public Operator<Context> {
|
||||
public:
|
||||
ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ResizeNearestOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
width_scale_(1),
|
||||
height_scale_(1),
|
||||
order_(StringToStorageOrder(
|
||||
|
|
@ -45,8 +46,9 @@ class ResizeNearestOp final : public Operator<Context> {
|
|||
template <typename T, class Context>
|
||||
class ResizeNearestGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
ResizeNearestGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ResizeNearestGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
width_scale_(1),
|
||||
height_scale_(1),
|
||||
order_(StringToStorageOrder(
|
||||
|
|
|
|||
|
|
@ -9,8 +9,9 @@ namespace caffe2 {
|
|||
template <class Context>
|
||||
class RMACRegionsOp final : public Operator<Context> {
|
||||
public:
|
||||
RMACRegionsOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit RMACRegionsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
scales_(this->template GetSingleArgument<int>("scales", 3)),
|
||||
overlap_(this->template GetSingleArgument<float>("overlap", 0.4f)) {}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
prefix_ = this->template GetSingleArgument<std::string>("prefix", "rnn");
|
||||
ws_ = ws;
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ template <class Context>
|
|||
class RecurrentNetworkOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
sharedWs_(ws),
|
||||
enable_rnn_executor_(this->template GetSingleArgument<bool>(
|
||||
|
|
@ -416,7 +416,7 @@ template <class Context>
|
|||
class RecurrentNetworkGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
RecurrentNetworkGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
explicit RecurrentNetworkGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
sharedWs_(ws),
|
||||
enable_rnn_executor_(this->template GetSingleArgument<bool>(
|
||||
|
|
@ -425,8 +425,8 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
|||
timestep_(this->template GetSingleArgument<std::string>(
|
||||
"timestep",
|
||||
"timestep")),
|
||||
gradInputs_(this->template GetRepeatedArgument<int32_t>(
|
||||
"outputs_with_grads")) {
|
||||
gradInputs_(
|
||||
this->template GetRepeatedArgument<int32_t>("outputs_with_grads")) {
|
||||
CAFFE_ENFORCE(ws);
|
||||
|
||||
stepNetDef_ = detail::extractNetDef(operator_def, "backward_step_net");
|
||||
|
|
@ -849,8 +849,9 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
|
|||
template <class Context>
|
||||
class AccumulateInputGradientOp : public Operator<Context> {
|
||||
public:
|
||||
AccumulateInputGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit AccumulateInputGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
offset_(this->template GetSingleArgument<int>("offset", -1)) {
|
||||
CAFFE_ENFORCE(offset_ >= 0, "Offset not set");
|
||||
}
|
||||
|
|
@ -893,8 +894,9 @@ class AccumulateInputGradientOp : public Operator<Context> {
|
|||
template <class Context>
|
||||
class RNNApplyLinkOp : public Operator<Context> {
|
||||
public:
|
||||
RNNApplyLinkOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit RNNApplyLinkOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
offset_(this->template GetSingleArgument<int>("offset", -1)),
|
||||
window_(this->template GetSingleArgument<int>("window", -1)) {
|
||||
CAFFE_ENFORCE(offset_ >= 0, "offset not set");
|
||||
|
|
|
|||
|
|
@ -33,17 +33,6 @@ TensorDescriptors<T>::~TensorDescriptors() {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RecurrentBaseOp<T>::RecurrentBaseOp(
|
||||
const OperatorDef& operator_def,
|
||||
Workspace* ws)
|
||||
: Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
|
||||
CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
|
||||
CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
|
||||
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RecurrentBaseOp<T>::~RecurrentBaseOp() {
|
||||
CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));
|
||||
|
|
|
|||
|
|
@ -32,7 +32,13 @@ template <typename T>
|
|||
class RecurrentBaseOp : public Operator<CUDAContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
|
||||
template<class... Args> explicit RecurrentBaseOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...), cudnn_wrapper_(&context_) {
|
||||
CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
|
||||
CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
|
||||
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
|
||||
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
|
||||
}
|
||||
virtual ~RecurrentBaseOp();
|
||||
|
||||
protected:
|
||||
|
|
@ -84,8 +90,9 @@ template <typename T>
|
|||
class RecurrentOp : public RecurrentBaseOp<T> {
|
||||
public:
|
||||
USE_RECURRENT_BASE_FUNCTIONS
|
||||
RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: RecurrentBaseOp<T>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit RecurrentOp(Args&&... args)
|
||||
: RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
|
|
@ -100,8 +107,9 @@ template <typename T, RecurrentParamOpMode mode>
|
|||
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
|
||||
public:
|
||||
USE_RECURRENT_BASE_FUNCTIONS
|
||||
RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: RecurrentBaseOp<T>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit RecurrentParamAccessOp(Args&&... args)
|
||||
: RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
};
|
||||
|
|
@ -110,8 +118,9 @@ template <typename T>
|
|||
class RecurrentGradientOp : public RecurrentBaseOp<T> {
|
||||
public:
|
||||
USE_RECURRENT_BASE_FUNCTIONS
|
||||
RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: RecurrentBaseOp<T>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit RecurrentGradientOp(Args&&... args)
|
||||
: RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class RoIAlignGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
RoIAlignGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit RoIAlignGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
spatial_scale_(
|
||||
this->template GetSingleArgument<float>("spatial_scale", 1.)),
|
||||
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class RoIAlignOp final : public Operator<Context> {
|
||||
public:
|
||||
RoIAlignOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit RoIAlignOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
spatial_scale_(
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class RoIAlignRotatedGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
RoIAlignRotatedGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit RoIAlignRotatedGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
spatial_scale_(
|
||||
this->template GetSingleArgument<float>("spatial_scale", 1.)),
|
||||
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class RoIAlignRotatedOp final : public Operator<Context> {
|
||||
public:
|
||||
RoIAlignRotatedOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit RoIAlignRotatedOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
spatial_scale_(
|
||||
|
|
|
|||
|
|
@ -11,9 +11,11 @@ namespace caffe2 {
|
|||
template <typename T, class Context>
|
||||
class RoIPoolOp final : public Operator<Context> {
|
||||
public:
|
||||
RoIPoolOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
is_test_(this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
|
||||
template <class... Args>
|
||||
explicit RoIPoolOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
is_test_(
|
||||
this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
|
||||
order_(StringToStorageOrder(
|
||||
this->template GetSingleArgument<string>("order", "NCHW"))),
|
||||
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
|
||||
|
|
@ -44,8 +46,9 @@ class RoIPoolOp final : public Operator<Context> {
|
|||
template <typename T, class Context>
|
||||
class RoIPoolGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
RoIPoolGradientOp(const OperatorDef& def, Workspace* ws)
|
||||
: Operator<Context>(def, ws),
|
||||
template <class... Args>
|
||||
explicit RoIPoolGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
spatial_scale_(
|
||||
this->template GetSingleArgument<float>("spatial_scale", 1.)),
|
||||
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ template <class Context>
|
|||
class ScaleOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ScaleOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ScaleOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
scale_(this->template GetSingleArgument<float>("scale", 1.0)) {}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -282,8 +282,9 @@ class AbstractReduceFrontOrBackOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
AbstractReduceFrontOrBackOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit AbstractReduceFrontOrBackOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int, "num_reduce_dim", num_reduce_dims_, 1) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
@ -353,10 +354,9 @@ class AbstractReduceFrontOrBackGradientOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
AbstractReduceFrontOrBackGradientOp(
|
||||
const OperatorDef& operator_def,
|
||||
Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit AbstractReduceFrontOrBackGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int, "num_reduce_dim", num_reduce_dims_, 1) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
@ -988,8 +988,9 @@ class AbstractUnsortedSegmentOp : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
AbstractUnsortedSegmentOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit AbstractUnsortedSegmentOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(int, "num_segments", num_segments_, -1) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ class SeluOp final : public Operator<Context> {
|
|||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
SeluOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
template <class... Args>
|
||||
explicit SeluOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {
|
||||
alpha_ = this->template GetSingleArgument<T>(
|
||||
"alpha", 1.6732632423543772848170429916717f);
|
||||
lambda_ = this->template GetSingleArgument<T>(
|
||||
|
|
@ -35,8 +36,9 @@ template <typename T, class Context>
|
|||
class SeluGradientOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
SeluGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
template <class... Args>
|
||||
explicit SeluGradientOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {
|
||||
alpha_ = this->template GetSingleArgument<T>(
|
||||
"alpha", 1.6732632423543772848170429916717f);
|
||||
lambda_ = this->template GetSingleArgument<T>(
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ template <class Context>
|
|||
class GatherPaddingOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
GatherPaddingOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit GatherPaddingOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
startPaddingWidth_(
|
||||
this->template GetSingleArgument<int>("padding_width", 1)),
|
||||
endPaddingWidth_(
|
||||
|
|
@ -106,8 +107,9 @@ template <class Context>
|
|||
class RemovePaddingOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
RemovePaddingOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit RemovePaddingOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
startPaddingWidth_(
|
||||
this->template GetSingleArgument<int>("padding_width", 1)),
|
||||
endPaddingWidth_(
|
||||
|
|
@ -146,8 +148,9 @@ template <class Context>
|
|||
class AddPaddingOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
AddPaddingOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit AddPaddingOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
startPaddingWidth_(
|
||||
this->template GetSingleArgument<int>("padding_width", 1)),
|
||||
endPaddingWidth_(
|
||||
|
|
@ -247,8 +250,9 @@ template <class Context>
|
|||
class PadEmptySamplesOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
PadEmptySamplesOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {}
|
||||
template <class... Args>
|
||||
explicit PadEmptySamplesOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...) {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ template <class Context>
|
|||
class ShapeOp : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ShapeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
template <class... Args>
|
||||
explicit ShapeOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
axes_(OperatorBase ::GetRepeatedArgument<int>("axes")) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -16,14 +16,13 @@ namespace caffe2 {
|
|||
template <class Context>
|
||||
class SinusoidPositionEncodingOp : public Operator<Context> {
|
||||
public:
|
||||
SinusoidPositionEncodingOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws),
|
||||
embedding_size_(this->template GetSingleArgument<int>(
|
||||
"embedding_size",
|
||||
100)),
|
||||
template <class... Args>
|
||||
explicit SinusoidPositionEncodingOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
embedding_size_(
|
||||
this->template GetSingleArgument<int>("embedding_size", 100)),
|
||||
alpha_(this->template GetSingleArgument<float>("alpha", 10000)),
|
||||
amplitude_(
|
||||
this->template GetSingleArgument<float>("amplitude", 1)) {}
|
||||
amplitude_(this->template GetSingleArgument<float>("amplitude", 1)) {}
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
|
|
|||
|
|
@ -235,8 +235,8 @@ template<>
|
|||
class SliceOp<CUDAContext> : public Operator<CUDAContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
SliceOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CUDAContext>(operator_def, ws),
|
||||
template<class... Args> explicit SliceOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...),
|
||||
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
|
||||
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
|
||||
statically_inited_(false) {}
|
||||
|
|
@ -296,8 +296,8 @@ template <>
|
|||
class SliceGradientOp<CUDAContext> : public Operator<CUDAContext> {
|
||||
public:
|
||||
USE_OPERATOR_FUNCTIONS(CUDAContext);
|
||||
SliceGradientOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<CUDAContext>(operator_def, ws),
|
||||
template<class... Args> explicit SliceGradientOp(Args&&... args)
|
||||
: Operator<CUDAContext>(std::forward<Args>(args)...),
|
||||
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
|
||||
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
|
||||
statically_inited_(false) {}
|
||||
|
|
|
|||
Loading…
Reference in a new issue