pytorch/caffe2/operators/key_split_ops.h
Mingzhe Li 964e30de1d Workaround for Cuda9.2 and GCC7 compilation errors (#10510)
Summary:
Breaking out of #8338

This PR is a workaround for a bug with CUDA9.2 + GCC7.

Here is the error this PR fixed:
.../pytorch/caffe2/operators/elementwise_ops.h: In constructor ‘caffe2::BinaryElementwiseWithArgsOp<InputTypes, Context, Functor, OutputTypeMap>::BinaryElementwiseWithArgsOp(const caffe2::OperatorDef&, caffe2::Workspace*)’:
.../pytorch/caffe2/operators/elementwise_ops.h:106:189: error: ‘GetSingleArgument<bool>’ is not a member of ‘caffe2::BinaryElementwiseWithArgsOp<InputTypes, Context, Functor, OutputTypeMap>’
   BinaryElementwiseWithArgsOp(const OperatorDef& operator_def, Workspace* ws)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10510

Reviewed By: orionr

Differential Revision: D9319742

Pulled By: mingzhe09088

fbshipit-source-id: ce59e3db14539f071f3c20301e77ca36a6fc3f81
2018-08-14 20:54:52 -07:00

53 lines
1.4 KiB
C++

#pragma once
#include <cstdint>
#include <limits>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * KeySplitOp groups element positions of Input(0) by key value.
 *
 * Input(0) holds N = keys.size() keys of type T; every key must lie in
 * [0, categorical_limit). For each category k in [0, categorical_limit) the
 * operator writes Output(k): an int tensor holding, in increasing order, every
 * position i such that keys[i] == k.
 *
 * Arguments:
 *   categorical_limit (int, default 0): number of categories, which is also
 *     the number of outputs written. Must be > 0 (enforced in the ctor).
 */
template <typename T, class Context>
class KeySplitOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  KeySplitOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        // `this->template` is required: GetSingleArgument is a member template
        // of a dependent base class.
        categorical_limit_(
            this->template GetSingleArgument<int>("categorical_limit", 0)) {
    CAFFE_ENFORCE_GT(categorical_limit_, 0);
  }

  bool RunOnDevice() override {
    auto& keys = Input(0);
    const int64_t N = keys.size();
    // Output positions are stored as int, so every index must fit in an int.
    CAFFE_ENFORCE_LE(N, static_cast<int64_t>(std::numeric_limits<int>::max()));
    const T* keys_data = keys.template data<T>();

    // Pass 1: validate every key and count how many fall into each category.
    // std::vector<int>(n) value-initializes, so all counts start at zero.
    std::vector<int> counts(categorical_limit_);
    for (int64_t i = 0; i < N; ++i) {
      counts[ToCategory(keys_data[i])]++;
    }

    // Size each output to its category's count and reset counts so they can
    // serve as per-category write cursors in pass 2.
    std::vector<int*> eids(categorical_limit_);
    for (int k = 0; k < categorical_limit_; ++k) {
      auto* eid = Output(k);
      eid->Resize(counts[k]);
      eids[k] = eid->template mutable_data<int>();
      counts[k] = 0;
    }

    // Pass 2: scatter each position into its category's output. Keys were
    // already validated in pass 1, so the narrowing cast here is safe.
    for (int64_t i = 0; i < N; ++i) {
      const int k = static_cast<int>(static_cast<int64_t>(keys_data[i]));
      eids[k][counts[k]++] = static_cast<int>(i);
    }
    return true;
  }

 private:
  // Converts a raw key to a category index, enforcing 0 <= key < limit.
  // The check is done in 64-bit *before* narrowing to int so that large
  // 64-bit keys cannot wrap around into the valid range.
  int ToCategory(const T raw_key) const {
    const int64_t key = static_cast<int64_t>(raw_key);
    CAFFE_ENFORCE_GT(static_cast<int64_t>(categorical_limit_), key);
    CAFFE_ENFORCE_GE(key, 0);
    return static_cast<int>(key);
  }

  int categorical_limit_;
};
} // namespace caffe2