// Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00).
//
// Upstream change note: avoid exposing defines that conflict with google
// logging, since this blocks external usage of libtorch in certain cases.
// The 'interesting' changes are in c10/util/logging_is_not_google_glog.h and
// c10/util/logging_is_google_glog.h; the rest are mechanical changes via sed.
// Fixes https://github.com/pytorch/pytorch/issues/81415
// Pull Request: https://github.com/pytorch/pytorch/pull/82032
// Approved by: https://github.com/soumith, https://github.com/miladm
#pragma once

#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

#include "c10/util/irange.h"

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
|
|
|
|
template <class SIndex, class Context>
|
|
bool SliceImpl(
|
|
Tensor* output,
|
|
const Tensor& data,
|
|
const Tensor& starts,
|
|
const Tensor& ends,
|
|
Context* context,
|
|
Tensor* gdata = nullptr,
|
|
const Tensor* go = nullptr) {
|
|
bool backward = output == nullptr;
|
|
|
|
auto* starts_data = starts.template data<SIndex>();
|
|
auto* ends_data = ends.template data<SIndex>();
|
|
|
|
CAFFE_ENFORCE_EQ(starts.dim(), 1);
|
|
CAFFE_ENFORCE_EQ(ends.dim(), 1);
|
|
CAFFE_ENFORCE_GE(data.dim(), starts.numel());
|
|
CAFFE_ENFORCE_EQ(starts.numel(), ends.numel());
|
|
|
|
std::vector<SIndex> starts_idx(data.dim());
|
|
std::vector<SIndex> ends_idx(data.dim());
|
|
std::vector<SIndex> dst_sizes(data.dim());
|
|
|
|
for (const auto i : c10::irange(data.dim())) {
|
|
if (i >= starts.numel()) {
|
|
starts_idx[i] = 0;
|
|
ends_idx[i] = data.size(i);
|
|
dst_sizes[i] = data.size(i);
|
|
continue;
|
|
}
|
|
if (data.size(i) > 0) {
|
|
auto start = starts_data[i];
|
|
auto end = ends_data[i];
|
|
if (start < 0) {
|
|
start = data.size(i) + 1 + start;
|
|
}
|
|
if (end < 0) {
|
|
end = data.size(i) + 1 + end;
|
|
}
|
|
if (start > data.size(i)) {
|
|
start = data.size(i);
|
|
}
|
|
if (end > data.size(i)) {
|
|
end = data.size(i);
|
|
}
|
|
CAFFE_ENFORCE_GE(start, 0);
|
|
CAFFE_ENFORCE_GE(end, 0);
|
|
CAFFE_ENFORCE_GE(end, start);
|
|
starts_idx[i] = start;
|
|
ends_idx[i] = end;
|
|
dst_sizes[i] = end - start;
|
|
} else {
|
|
starts_idx[i] = 0;
|
|
ends_idx[i] = 0;
|
|
dst_sizes[i] = 0;
|
|
}
|
|
}
|
|
|
|
if (data.numel() <= 0) {
|
|
// When the input is empty, we do not need to do copy.
|
|
if (!backward) {
|
|
output->Resize(dst_sizes);
|
|
output->raw_mutable_data(data.dtype());
|
|
} else {
|
|
gdata->ResizeLike(data);
|
|
gdata->raw_mutable_data(go->dtype());
|
|
}
|
|
return true;
|
|
}
|
|
// for now only supports slicing in 1 dimension
|
|
int dim = -1;
|
|
for (const auto i : c10::irange(data.dim())) {
|
|
if (starts_idx[i] > 0 || ends_idx[i] < data.size(i)) {
|
|
CAFFE_ENFORCE_EQ(
|
|
dim, -1, "Currently only possible to slice in 1 dimension.");
|
|
dim = i;
|
|
}
|
|
}
|
|
if (dim == -1) {
|
|
if (!backward) {
|
|
output->CopyFrom(data, true /*async*/);
|
|
} else {
|
|
gdata->CopyFrom(*go, true /*async*/);
|
|
}
|
|
return true;
|
|
}
|
|
size_t unit = std::accumulate(
|
|
data.sizes().begin() + dim + 1,
|
|
data.sizes().end(),
|
|
1,
|
|
std::multiplies<SIndex>());
|
|
size_t num_blocks = std::accumulate(
|
|
data.sizes().begin(),
|
|
data.sizes().begin() + dim,
|
|
1,
|
|
std::multiplies<SIndex>());
|
|
if (!backward) {
|
|
output->Resize(dst_sizes);
|
|
} else {
|
|
gdata->ResizeLike(data);
|
|
}
|
|
|
|
size_t itemsize = data.dtype().itemsize();
|
|
|
|
if (!backward) {
|
|
char* src_bytes = (char*)data.raw_data();
|
|
char* dst_bytes = (char*)output->raw_mutable_data(data.dtype());
|
|
|
|
size_t src_nbytes = data.nbytes();
|
|
size_t dst_nbytes = output->nbytes();
|
|
|
|
size_t src_block_size = unit * data.size(dim);
|
|
size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
|
size_t src_offset = unit * starts_idx[dim];
|
|
|
|
if (num_blocks == 0 || dst_block_size == 0) {
|
|
return true;
|
|
}
|
|
|
|
size_t src_block_size_bytes = itemsize * src_block_size;
|
|
size_t dst_block_size_bytes = itemsize * dst_block_size;
|
|
|
|
char* src_offset_bytes = src_bytes + itemsize * src_offset;
|
|
char* dst_offset_bytes = dst_bytes;
|
|
for (const auto i : c10::irange(num_blocks)) {
|
|
char* local_src_offset_bytes =
|
|
src_offset_bytes + i * src_block_size_bytes;
|
|
char* local_dst_offset_bytes =
|
|
dst_offset_bytes + i * dst_block_size_bytes;
|
|
TORCH_DCHECK_LE(
|
|
static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
|
|
static_cast<void*>(src_bytes + src_nbytes));
|
|
TORCH_DCHECK_LE(
|
|
static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
|
|
static_cast<void*>(dst_bytes + dst_nbytes));
|
|
context->CopyItemsSameDevice(
|
|
data.dtype(),
|
|
dst_block_size,
|
|
(void*)local_src_offset_bytes,
|
|
(void*)local_dst_offset_bytes);
|
|
}
|
|
} else {
|
|
char* src_bytes = (char*)go->raw_data();
|
|
char* dst_bytes = (char*)gdata->raw_mutable_data(go->dtype());
|
|
|
|
size_t src_nbytes = go->nbytes();
|
|
size_t dst_nbytes = gdata->nbytes();
|
|
|
|
size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
|
|
size_t dst_block_size = unit * data.size(dim);
|
|
size_t dst_offset = unit * starts_idx[dim];
|
|
|
|
if (num_blocks == 0 || dst_block_size == 0) {
|
|
return true;
|
|
}
|
|
|
|
size_t src_block_size_bytes = itemsize * src_block_size;
|
|
size_t dst_block_size_bytes = itemsize * dst_block_size;
|
|
|
|
char* src_offset_bytes = src_bytes;
|
|
char* dst_offset_bytes = dst_bytes + itemsize * dst_offset;
|
|
// Zero out gradient blob before copy since we copy in fewer items than
|
|
// there is space for
|
|
math::Set<char, Context>(dst_nbytes, 0, dst_bytes, context);
|
|
|
|
// If output tensor is empty, just return zeroed gradient tensor
|
|
if (!src_bytes) {
|
|
return true;
|
|
}
|
|
|
|
for (const auto i : c10::irange(num_blocks)) {
|
|
char* local_src_offset_bytes =
|
|
src_offset_bytes + i * src_block_size_bytes;
|
|
char* local_dst_offset_bytes =
|
|
dst_offset_bytes + i * dst_block_size_bytes;
|
|
TORCH_DCHECK_LE(
|
|
local_src_offset_bytes + src_block_size_bytes,
|
|
src_bytes + src_nbytes);
|
|
TORCH_DCHECK_LE(
|
|
local_dst_offset_bytes + src_block_size_bytes,
|
|
dst_bytes + dst_nbytes);
|
|
context->CopyItemsSameDevice(
|
|
go->dtype(),
|
|
src_block_size,
|
|
(void*)local_src_offset_bytes,
|
|
(void*)local_dst_offset_bytes);
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <class Context>
|
|
class SliceOp : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
|
explicit SliceOp(Args&&... args)
|
|
: Operator<Context>(std::forward<Args>(args)...),
|
|
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
|
|
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
|
|
statically_inited_(false) {}
|
|
|
|
bool RunOnDevice() override {
|
|
if (InputSize() > 1) {
|
|
return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
|
|
} else {
|
|
return DoRunWithType<int64_t>();
|
|
}
|
|
}
|
|
|
|
template <typename SIndex>
|
|
bool DoRunWithType() {
|
|
if (InputSize() > 1) {
|
|
ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU), Input(1));
|
|
ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU), Input(2));
|
|
} else {
|
|
if (!statically_inited_) {
|
|
CAFFE_ENFORCE(HasArgument("starts"));
|
|
CAFFE_ENFORCE(HasArgument("ends"));
|
|
CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
|
|
|
|
ReinitializeTensor(&starts_host_, {static_cast<int64_t>(starts_.size())}, at::dtype<SIndex>().device(CPU));
|
|
ReinitializeTensor(&ends_host_, {static_cast<int64_t>(ends_.size())}, at::dtype<SIndex>().device(CPU));
|
|
|
|
memcpy(
|
|
starts_host_.template mutable_data<SIndex>(),
|
|
starts_.data(),
|
|
sizeof(SIndex) * starts_.size());
|
|
memcpy(
|
|
ends_host_.template mutable_data<SIndex>(),
|
|
ends_.data(),
|
|
sizeof(SIndex) * ends_.size());
|
|
statically_inited_ = true;
|
|
}
|
|
}
|
|
|
|
const auto& data = Input(0);
|
|
auto output = Output(0);
|
|
|
|
return SliceImpl<SIndex, Context>(
|
|
output, data, starts_host_, ends_host_, &context_);
|
|
}
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(SliceOp);
|
|
|
|
protected:
|
|
std::vector<int64_t> starts_;
|
|
std::vector<int64_t> ends_;
|
|
bool statically_inited_;
|
|
Tensor starts_host_;
|
|
Tensor ends_host_;
|
|
};
|
|
|
|
template <class Context>
|
|
class SliceGradientOp : public Operator<Context> {
|
|
public:
|
|
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
|
template <class... Args>
|
|
explicit SliceGradientOp(Args&&... args)
|
|
: Operator<Context>(std::forward<Args>(args)...),
|
|
starts_(this->template GetRepeatedArgument<int64_t>("starts")),
|
|
ends_(this->template GetRepeatedArgument<int64_t>("ends")),
|
|
statically_inited_(false) {}
|
|
|
|
C10_DISABLE_COPY_AND_ASSIGN(SliceGradientOp);
|
|
|
|
bool RunOnDevice() override {
|
|
if (InputSize() == 4) {
|
|
return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
|
|
} else {
|
|
return DoRunWithType<int64_t>();
|
|
}
|
|
}
|
|
|
|
template <typename SIndex>
|
|
bool DoRunWithType() {
|
|
auto* gdata = Output(0);
|
|
auto& data = Input(0);
|
|
|
|
if (InputSize() == 4) {
|
|
ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU), Input(1));
|
|
ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU), Input(2));
|
|
|
|
auto& go = Input(3);
|
|
|
|
return SliceImpl<SIndex, Context>(
|
|
nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
|
|
} else {
|
|
if (!statically_inited_) {
|
|
CAFFE_ENFORCE(HasArgument("starts"));
|
|
CAFFE_ENFORCE(HasArgument("ends"));
|
|
CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
|
|
|
|
ReinitializeTensor(
|
|
&starts_host_, {static_cast<int64_t>(starts_.size())}, at::dtype<SIndex>().device(CPU));
|
|
ReinitializeTensor(
|
|
&ends_host_, {static_cast<int64_t>(ends_.size())}, at::dtype<SIndex>().device(CPU));
|
|
|
|
memcpy(
|
|
starts_host_.template mutable_data<SIndex>(),
|
|
starts_.data(),
|
|
sizeof(SIndex) * starts_.size());
|
|
memcpy(
|
|
ends_host_.template mutable_data<SIndex>(),
|
|
ends_.data(),
|
|
sizeof(SIndex) * ends_.size());
|
|
|
|
statically_inited_ = true;
|
|
}
|
|
auto& go = Input(1);
|
|
|
|
return SliceImpl<SIndex, Context>(
|
|
nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
|
|
}
|
|
}
|
|
|
|
private:
|
|
|
|
std::vector<int64_t> starts_;
|
|
std::vector<int64_t> ends_;
|
|
bool statically_inited_;
|
|
Tensor starts_host_;
|
|
Tensor ends_host_;
|
|
};
|
|
} // namespace caffe2