pytorch/caffe2/experiments/operators/fully_connected_op_sparse.h
Jerry Zhang 20d1bff292 Tensor construction codemod - 1/3 (#14828)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14828

Codemod generated with clangr shard mode, 25 files per diff,
motivation: https://github.com/pytorch/pytorch/pull/12407

Reviewed By: bddppq

Differential Revision: D13335160

fbshipit-source-id: a3ae4c5a86bfbdaf2d5aa14e0eef57255e829fd4
2018-12-06 11:47:32 -08:00

149 lines
4.1 KiB
C++

/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#endif // CAFFE2_USE_MKL
namespace caffe2 {
namespace {
template<int N>
using Shape = std::array<int, N>;
template<int N>
const std::vector<int64_t>& shape(Shape<N> vs) {
static thread_local std::vector<int64_t> cache;
cache.resize(vs.size());
for (auto i = 0; i < vs.size(); ++i) {
cache[i] = vs[i];
}
return cache;
}
inline const std::vector<int64_t>& shape(int i) {
return shape<1>(Shape<1>({i}));
}
inline const std::vector<int64_t>& shape(int i, int j) {
return shape<2>(Shape<2>({i, j}));
}
template <typename T, class Context>
void Sparse_mm(const T* acsr, const int* ia, const int* ja,
int m, int k, int n, const T* b, T* c, Context* context);
template<typename T, class Context>
void trans_mat(const T* o, T* t, int m, int n, Context* context);
template <>
void trans_mat<float, CPUContext>(
const float* o,
float* t,
int m,
int n,
CPUContext* /*context*/) {
for(int i = 0; i < m; ++i){
for(int j = 0; j < n; ++j){
t[j*m+i]=o[i*n+j];
}
}
}
// C = A(sparse) * B
// No transpose;
template <>
void Sparse_mm<float, CPUContext>(
const float* acsr,
const int* ia,
const int* ja,
int m,
int k,
int n,
const float* b,
float* c,
CPUContext* /*context*/) {
float alpha = 1.0, beta = 0.;
mkl_scsrmm("N", &m, &n, &k, &alpha, "GLNC",
acsr, ja, ia, ia+1, b, &n, &beta, c, &n);
}
}
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
template <typename T, class Context, class Engine=DefaultEngine>
class FullyConnectedOp_SPARSE final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
FullyConnectedOp_SPARSE(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
~FullyConnectedOp_SPARSE() {}
bool RunOnDevice() override {
const auto& Xt = Input(0); // transposed X
const auto& Wcsr = Input(1);
const auto& iw = Input(2);
const auto& jw = Input(3);
// Notice that we do not need to transpose b
const auto& b = Input(4);
// transposed Y
// here we assume X is k-by-m
CAFFE_ENFORCE_EQ(Xt.dim(), 2);
CAFFE_ENFORCE_EQ(b.dim(), 1);
// batch size
int K = Xt.dim() > 1 ? Xt.dim32(0) : 1;
// Feature dimension
int M = Xt.numel() / K;
// number of outputs.
int N = iw.dim32(0)-1;
CAFFE_ENFORCE_EQ(N, b.dim32(0));
auto* Yt = Output(0, shape(N, M), at::dtype<T>());
// Y' = W * X';
Sparse_mm<T, Context>(
Wcsr.template data<T>(), iw.template data<int>(),
jw.template data<int>(), N, K, M, Xt.template data<T>(),
Yt->template mutable_data<T>(), &context_);
// Add bias term
if (bias_multiplier_.numel() != M) {
// If the helper bias multiplier is not M, reshape and fill it with one.
bias_multiplier_.Resize(shape(M));
math::Set<T, Context>(
M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
&context_);
}
math::Gemm<T, Context, Engine>(
CblasNoTrans, CblasNoTrans, N, M, 1, 1,
b.template data<T>(), bias_multiplier_.template data<T>(), 1,
Yt->template mutable_data<T>(), &context_);
return true;
}
protected:
Tensor bias_multiplier_{Context::GetDeviceType()};
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_