pytorch/caffe2/operators/elementwise_linear_op.h
Yury Zemlyanskiy e8c274cf16 Optimize memory usage for MI-LSTM
Summary:
Use ElementwiseLinearOps instead of manual Mul + Sum. That saves intermediate blobs.

For NMT use case

Before: https://our.intern.facebook.com/intern/fblearner/details/18060753
Time per step: 0.072
Memory usage (per each of 2 GPUs): 9041 MiB

After: https://our.intern.facebook.com/intern/fblearner/details/18107583
Time per step: 0.0715
Memory (per each GPU): 8560MiB

Reviewed By: akyrola

Differential Revision: D5038785

fbshipit-source-id: 4bc8155dbd0c87729e17236d68d62ca530aadb53
2017-05-10 16:53:43 -07:00

39 lines
1.1 KiB
C++

#ifndef CAFFE2_OPERATORS_ELEMENTWISE_LINEAR_OP_H_
#define CAFFE2_OPERATORS_ELEMENTWISE_LINEAR_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Forward operator for ElementwiseLinear. Reads the "axis" argument
// (default 1) at construction; presumably it splits the input into an
// outer/batch part and an inner part scaled per-element — confirm against
// the per-device RunOnDevice implementation, which lives in the
// corresponding .cc/.cu file.
template <typename T, class Context, class Engine = DefaultEngine>
class ElementwiseLinearOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  ElementwiseLinearOp(const OperatorDef& def, Workspace* workspace)
      : Operator<Context>(def, workspace),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}

  // Device-specific implementation; defined out-of-line per Context.
  bool RunOnDevice() override;

 protected:
  // Value of the "axis" operator argument; defaults to 1.
  int axis_;
};
// Gradient operator paired with ElementwiseLinearOp. Mirrors the forward
// op's configuration: the same "axis" argument (default 1) is captured at
// construction. The per-device kernel is declared here and defined in the
// corresponding .cc/.cu file.
template <typename T, class Context, class Engine = DefaultEngine>
class ElementwiseLinearGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  ElementwiseLinearGradientOp(const OperatorDef& def, Workspace* workspace)
      : Operator<Context>(def, workspace),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}

  // Device-specific implementation; defined out-of-line per Context.
  bool RunOnDevice() override;

 protected:
  // Value of the "axis" operator argument; defaults to 1.
  int axis_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ELEMENTWISE_LINEAR_OP_H_