Fix warnings in tensor_flatten.cpp (#55956)

Summary:
Switch to use `TensorOptions` instead of deprecated `.type()` to fix compiler warnings as part of #55952
](https://our.intern.facebook.com/intern/diff/27830504/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55956

Pulled By: driazati

Reviewed By: pritamdamania87

Differential Revision: D27830504

fbshipit-source-id: f705818ddb7d8b17c0f5383f22dc431203a194d9
This commit is contained in:
davidriazati@fb.com 2021-04-20 17:16:59 -07:00 committed by Facebook GitHub Bot
parent 3d904b56ec
commit 7fff71eb9a
3 changed files with 39 additions and 19 deletions

View file

@ -15,6 +15,7 @@
#include <torch/csrc/autograd/variable.h>
#include <cstddef>
#include <optional>
#include <vector>
namespace torch {
@ -26,15 +27,18 @@ using namespace torch::autograd;
// of a single type only. Adding this logic directly in the loop makes it a bit
// ugly, so here's a helper for it.
struct unique_type_checker {
void show(const at::DeprecatedTypeProperties& t) {
if (!unique)
void show(size_t type_id) {
if (!unique) {
return;
if (!type)
type = &t;
unique = (type == &t);
}
if (!type_id_) {
type_id_ = type_id;
}
unique = type_id_.value() == type_id;
}
const at::DeprecatedTypeProperties* type = nullptr;
optional<size_t> type_id_;
bool unique = true;
};
@ -173,10 +177,10 @@ tensor_list2d broadcast_coalesced(
unique_type_checker type_checker;
at::cuda::CUDAGuard device_guard(devices[0]);
for (auto& chunk : utils::take_tensors(tensors, buffer_size)) {
auto& type = chunk.type();
type_checker.show(type);
auto type_id = chunk.type_id();
type_checker.show(type_id);
std::vector<at::Tensor> results;
if (chunk.type().is_sparse()) {
if (chunk.options().is_sparse()) {
auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
auto broadcast_indices = broadcast(flat_tuple.first, devices);
auto broadcast_values = broadcast(flat_tuple.second, devices);

View file

@ -7,6 +7,7 @@ namespace torch { namespace utils {
using namespace at;
std::vector<TensorGroup> take_tensors(
TensorList tensors,
size_t size_limit,
@ -18,7 +19,7 @@ std::vector<TensorGroup> take_tensors(
size_t cur_group_size = 0;
for (const auto & tensor : tensors) {
size_t tensor_size;
size_t tensor_size = 0;
if (tensor.is_sparse()) {
const auto& indices = tensor._indices();
const auto& values = tensor._values();
@ -28,7 +29,7 @@ std::vector<TensorGroup> take_tensors(
tensor_size = tensor.numel() * tensor.element_size();
}
auto& type_group = groups[tensor.type().id()];
auto& type_group = groups[type_id(tensor)];
type_group.tensors.push_back(tensor);
if (fine_grained) {
@ -64,17 +65,17 @@ std::vector<TensorGroup> take_tensors(
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
AT_ASSERT(tensors.size() == order.size());
std::unordered_map<at::DeprecatedTypeProperties*, std::vector<size_t>> type_indices;
std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
type_indices[&tensors[i].type()].push_back(i);
type_id_to_indices[type_id(tensors[i])].push_back(i);
std::unordered_map<at::DeprecatedTypeProperties*, size_t> type_used;
std::unordered_map<size_t, size_t> type_id_to_type_used;
std::vector<Tensor> ordered_tensors;
ordered_tensors.reserve(tensors.size());
for (auto & tmpl_tensor : order) {
auto * type = &tmpl_tensor.type();
auto & indices = type_indices[type];
auto & used = type_used[type];
size_t tmpl_type_id = type_id(tmpl_tensor);
auto & indices = type_id_to_indices[tmpl_type_id];
auto & used = type_id_to_type_used[tmpl_type_id];
ordered_tensors.push_back(tensors[indices[used++]]);
}
std::swap(tensors, ordered_tensors);

View file

@ -4,9 +4,19 @@
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/ATen.h>
#include <utility>
#include <c10/core/TensorOptions.h>
namespace torch { namespace utils {
/// Generate an ID for a combination of tensor backend + scalar type to be used
/// when ordering tensors ('like' tensors are grouped by pulling out their
/// backend + scalar type, so this function combines that into a single number)
inline size_t type_id(const at::Tensor& tensor) {
return static_cast<size_t>(tensor.options().backend()) *
static_cast<size_t>(at::ScalarType::NumOptions) +
static_cast<size_t>(tensor.scalar_type());
}
inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
static auto flatten = [](const at::Tensor &t) { return t.contiguous().view({-1}); };
if (tensors.size() == 1)
@ -39,9 +49,14 @@ struct TensorGroup {
std::vector<at::Tensor> tensors;
size_t size = 0;
at::DeprecatedTypeProperties& type() {
size_t type_id() {
AT_ASSERT(!tensors.empty());
return tensors[0].type();
return ::torch::utils::type_id(tensors[0]);
}
const at::TensorOptions options() {
AT_ASSERT(!tensors.empty());
return tensors[0].options();
}
};