refactor caffe2 operator constructors - 7/9 (#17088)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17088

clangr codemod

also manually moved the constructor of a class from the .cpp file to the .h file.

Reviewed By: ezyang

Differential Revision: D14078531

fbshipit-source-id: 2adb4ac0ce523742da6cce3bc3b6c177b816c299
This commit is contained in:
Sebastian Messmer 2019-02-28 14:04:06 -08:00 committed by Facebook Github Bot
parent 42512242cc
commit 8db403b9dc
25 changed files with 136 additions and 111 deletions

View file

@ -14,9 +14,8 @@ namespace int8 {
class Int8SoftmaxOp final : public Operator<CPUContext> {
public:
Int8SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws),
ws_(ws) {}
explicit Int8SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CPUContext>(operator_def, ws), ws_(ws) {}
~Int8SoftmaxOp() {
if (this->qnnpackOperator_ != nullptr) {

View file

@ -11,8 +11,9 @@ namespace caffe2 {
template <typename T, class Context, bool FIRSTDIMS>
class MaxReduceDimsOp final : public Operator<Context> {
public:
MaxReduceDimsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit MaxReduceDimsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
num_reduce_dims_(
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
@ -78,8 +79,9 @@ class MaxReduceDimsOp final : public Operator<Context> {
template <typename T, class Context, bool FIRSTDIMS>
class MaxReduceDimsGradientOp final : public Operator<Context> {
public:
MaxReduceDimsGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit MaxReduceDimsGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
num_reduce_dims_(
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}

View file

@ -11,8 +11,9 @@ namespace caffe2 {
template <class Context, bool FIRSTDIMS, bool NORMALIZE>
class SumReduceDimsOp final : public Operator<Context> {
public:
SumReduceDimsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SumReduceDimsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
num_reduce_dims_(
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
@ -86,8 +87,9 @@ class SumReduceDimsOp final : public Operator<Context> {
template <class Context, bool FIRSTDIMS, bool NORMALIZE>
class SumReduceDimsGradientOp final : public Operator<Context> {
public:
SumReduceDimsGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SumReduceDimsGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
num_reduce_dims_(
this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}

View file

@ -17,8 +17,9 @@ class ReduceOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ReduceOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ReduceOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axes_(this->template GetRepeatedArgument<int>("axes")),
OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
@ -84,8 +85,9 @@ class ReduceGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ReduceGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ReduceGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axes_(this->template GetRepeatedArgument<int>("axes")) {}
bool RunOnDevice() override {

View file

@ -14,10 +14,11 @@ class SumElementsOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SumElementsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
average_(this->template GetSingleArgument<bool>("average", false)) {}
SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
explicit SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
: Operator<Context>(operator_def, ws), average_(average) {}
~SumElementsOp() {}
@ -51,8 +52,9 @@ class SumElementsIntOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SumElementsIntOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit SumElementsIntOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
~SumElementsIntOp() {}
bool RunOnDevice() override {
@ -74,13 +76,11 @@ class SumElementsGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit SumElementsGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
average_(this->template GetSingleArgument<bool>("average", false)) {}
SumElementsGradientOp(
const OperatorDef& operator_def,
Workspace* ws,
bool average)
explicit SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws, bool average)
: Operator<Context>(operator_def, ws), average_(average) {}
~SumElementsGradientOp() {}

View file

@ -12,8 +12,9 @@ template <class Context>
class ReplaceNaNOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ReplaceNaNOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit ReplaceNaNOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));

View file

@ -14,8 +14,9 @@ template <typename F, class Context>
class ReshapeOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ReshapeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ReshapeOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
new_shape_(this->template GetRepeatedArgument<int64_t>("shape")) {}
bool RunOnDevice() override {

View file

@ -9,8 +9,9 @@ namespace caffe2 {
template <typename T, class Context>
class ResizeNearestOp final : public Operator<Context> {
public:
ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ResizeNearestOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
width_scale_(1),
height_scale_(1),
order_(StringToStorageOrder(
@ -45,8 +46,9 @@ class ResizeNearestOp final : public Operator<Context> {
template <typename T, class Context>
class ResizeNearestGradientOp final : public Operator<Context> {
public:
ResizeNearestGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ResizeNearestGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
width_scale_(1),
height_scale_(1),
order_(StringToStorageOrder(

View file

@ -9,8 +9,9 @@ namespace caffe2 {
template <class Context>
class RMACRegionsOp final : public Operator<Context> {
public:
RMACRegionsOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit RMACRegionsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
scales_(this->template GetSingleArgument<int>("scales", 3)),
overlap_(this->template GetSingleArgument<float>("overlap", 0.4f)) {}

View file

@ -17,7 +17,7 @@ class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
explicit RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
prefix_ = this->template GetSingleArgument<std::string>("prefix", "rnn");
ws_ = ws;

View file

@ -182,7 +182,7 @@ template <class Context>
class RecurrentNetworkOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
explicit RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
sharedWs_(ws),
enable_rnn_executor_(this->template GetSingleArgument<bool>(
@ -416,7 +416,7 @@ template <class Context>
class RecurrentNetworkGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
RecurrentNetworkGradientOp(const OperatorDef& operator_def, Workspace* ws)
explicit RecurrentNetworkGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
sharedWs_(ws),
enable_rnn_executor_(this->template GetSingleArgument<bool>(
@ -425,8 +425,8 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
timestep_(this->template GetSingleArgument<std::string>(
"timestep",
"timestep")),
gradInputs_(this->template GetRepeatedArgument<int32_t>(
"outputs_with_grads")) {
gradInputs_(
this->template GetRepeatedArgument<int32_t>("outputs_with_grads")) {
CAFFE_ENFORCE(ws);
stepNetDef_ = detail::extractNetDef(operator_def, "backward_step_net");
@ -849,8 +849,9 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
template <class Context>
class AccumulateInputGradientOp : public Operator<Context> {
public:
AccumulateInputGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit AccumulateInputGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
offset_(this->template GetSingleArgument<int>("offset", -1)) {
CAFFE_ENFORCE(offset_ >= 0, "Offset not set");
}
@ -893,8 +894,9 @@ class AccumulateInputGradientOp : public Operator<Context> {
template <class Context>
class RNNApplyLinkOp : public Operator<Context> {
public:
RNNApplyLinkOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit RNNApplyLinkOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
offset_(this->template GetSingleArgument<int>("offset", -1)),
window_(this->template GetSingleArgument<int>("window", -1)) {
CAFFE_ENFORCE(offset_ >= 0, "offset not set");

View file

@ -33,17 +33,6 @@ TensorDescriptors<T>::~TensorDescriptors() {
}
}
template <typename T>
RecurrentBaseOp<T>::RecurrentBaseOp(
const OperatorDef& operator_def,
Workspace* ws)
: Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
}
template <typename T>
RecurrentBaseOp<T>::~RecurrentBaseOp() {
CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));

View file

@ -32,7 +32,13 @@ template <typename T>
class RecurrentBaseOp : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
template<class... Args> explicit RecurrentBaseOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...), cudnn_wrapper_(&context_) {
CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
}
virtual ~RecurrentBaseOp();
protected:
@ -84,8 +90,9 @@ template <typename T>
class RecurrentOp : public RecurrentBaseOp<T> {
public:
USE_RECURRENT_BASE_FUNCTIONS
RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
: RecurrentBaseOp<T>(operator_def, ws) {}
template <class... Args>
explicit RecurrentOp(Args&&... args)
: RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
@ -100,8 +107,9 @@ template <typename T, RecurrentParamOpMode mode>
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
public:
USE_RECURRENT_BASE_FUNCTIONS
RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
: RecurrentBaseOp<T>(operator_def, ws) {}
template <class... Args>
explicit RecurrentParamAccessOp(Args&&... args)
: RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
};
@ -110,8 +118,9 @@ template <typename T>
class RecurrentGradientOp : public RecurrentBaseOp<T> {
public:
USE_RECURRENT_BASE_FUNCTIONS
RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
: RecurrentBaseOp<T>(operator_def, ws) {}
template <class... Args>
explicit RecurrentGradientOp(Args&&... args)
: RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;

View file

@ -12,8 +12,9 @@ namespace caffe2 {
template <typename T, class Context>
class RoIAlignGradientOp final : public Operator<Context> {
public:
RoIAlignGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit RoIAlignGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
spatial_scale_(
this->template GetSingleArgument<float>("spatial_scale", 1.)),
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),

View file

@ -12,8 +12,9 @@ namespace caffe2 {
template <typename T, class Context>
class RoIAlignOp final : public Operator<Context> {
public:
RoIAlignOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit RoIAlignOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
spatial_scale_(

View file

@ -12,8 +12,9 @@ namespace caffe2 {
template <typename T, class Context>
class RoIAlignRotatedGradientOp final : public Operator<Context> {
public:
RoIAlignRotatedGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit RoIAlignRotatedGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
spatial_scale_(
this->template GetSingleArgument<float>("spatial_scale", 1.)),
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),

View file

@ -12,8 +12,9 @@ namespace caffe2 {
template <typename T, class Context>
class RoIAlignRotatedOp final : public Operator<Context> {
public:
RoIAlignRotatedOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit RoIAlignRotatedOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
spatial_scale_(

View file

@ -11,9 +11,11 @@ namespace caffe2 {
template <typename T, class Context>
class RoIPoolOp final : public Operator<Context> {
public:
RoIPoolOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
is_test_(this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
template <class... Args>
explicit RoIPoolOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
is_test_(
this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
order_(StringToStorageOrder(
this->template GetSingleArgument<string>("order", "NCHW"))),
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
@ -44,8 +46,9 @@ class RoIPoolOp final : public Operator<Context> {
template <typename T, class Context>
class RoIPoolGradientOp final : public Operator<Context> {
public:
RoIPoolGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws),
template <class... Args>
explicit RoIPoolGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
spatial_scale_(
this->template GetSingleArgument<float>("spatial_scale", 1.)),
pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),

View file

@ -11,8 +11,9 @@ template <class Context>
class ScaleOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ScaleOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ScaleOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
scale_(this->template GetSingleArgument<float>("scale", 1.0)) {}
template <typename T>

View file

@ -282,8 +282,9 @@ class AbstractReduceFrontOrBackOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
AbstractReduceFrontOrBackOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit AbstractReduceFrontOrBackOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "num_reduce_dim", num_reduce_dims_, 1) {}
bool RunOnDevice() override {
@ -353,10 +354,9 @@ class AbstractReduceFrontOrBackGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
AbstractReduceFrontOrBackGradientOp(
const OperatorDef& operator_def,
Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit AbstractReduceFrontOrBackGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "num_reduce_dim", num_reduce_dims_, 1) {}
bool RunOnDevice() override {
@ -988,8 +988,9 @@ class AbstractUnsortedSegmentOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
AbstractUnsortedSegmentOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit AbstractUnsortedSegmentOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "num_segments", num_segments_, -1) {}
bool RunOnDevice() override {

View file

@ -13,8 +13,9 @@ class SeluOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SeluOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
template <class... Args>
explicit SeluOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {
alpha_ = this->template GetSingleArgument<T>(
"alpha", 1.6732632423543772848170429916717f);
lambda_ = this->template GetSingleArgument<T>(
@ -35,8 +36,9 @@ template <typename T, class Context>
class SeluGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
SeluGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
template <class... Args>
explicit SeluGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {
alpha_ = this->template GetSingleArgument<T>(
"alpha", 1.6732632423543772848170429916717f);
lambda_ = this->template GetSingleArgument<T>(

View file

@ -11,8 +11,9 @@ template <class Context>
class GatherPaddingOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
GatherPaddingOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit GatherPaddingOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
startPaddingWidth_(
this->template GetSingleArgument<int>("padding_width", 1)),
endPaddingWidth_(
@ -106,8 +107,9 @@ template <class Context>
class RemovePaddingOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
RemovePaddingOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit RemovePaddingOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
startPaddingWidth_(
this->template GetSingleArgument<int>("padding_width", 1)),
endPaddingWidth_(
@ -146,8 +148,9 @@ template <class Context>
class AddPaddingOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
AddPaddingOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit AddPaddingOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
startPaddingWidth_(
this->template GetSingleArgument<int>("padding_width", 1)),
endPaddingWidth_(
@ -247,8 +250,9 @@ template <class Context>
class PadEmptySamplesOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
PadEmptySamplesOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
template <class... Args>
explicit PadEmptySamplesOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
};

View file

@ -13,8 +13,9 @@ template <class Context>
class ShapeOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ShapeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
template <class... Args>
explicit ShapeOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
axes_(OperatorBase ::GetRepeatedArgument<int>("axes")) {}
bool RunOnDevice() override {

View file

@ -16,14 +16,13 @@ namespace caffe2 {
template <class Context>
class SinusoidPositionEncodingOp : public Operator<Context> {
public:
SinusoidPositionEncodingOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
embedding_size_(this->template GetSingleArgument<int>(
"embedding_size",
100)),
template <class... Args>
explicit SinusoidPositionEncodingOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
embedding_size_(
this->template GetSingleArgument<int>("embedding_size", 100)),
alpha_(this->template GetSingleArgument<float>("alpha", 10000)),
amplitude_(
this->template GetSingleArgument<float>("amplitude", 1)) {}
amplitude_(this->template GetSingleArgument<float>("amplitude", 1)) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {

View file

@ -235,8 +235,8 @@ template<>
class SliceOp<CUDAContext> : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
SliceOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws),
template<class... Args> explicit SliceOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...),
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
statically_inited_(false) {}
@ -296,8 +296,8 @@ template <>
class SliceGradientOp<CUDAContext> : public Operator<CUDAContext> {
public:
USE_OPERATOR_FUNCTIONS(CUDAContext);
SliceGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<CUDAContext>(operator_def, ws),
template<class... Args> explicit SliceGradientOp(Args&&... args)
: Operator<CUDAContext>(std::forward<Args>(args)...),
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
statically_inited_(false) {}