diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 03a595f4423..4ee6bb25126 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -15,6 +15,7 @@ #include #include +#include #include namespace torch { @@ -26,15 +27,18 @@ using namespace torch::autograd; // of a single type only. Adding this logic directly in the loop makes it a bit // ugly, so here's a helper for it. struct unique_type_checker { - void show(const at::DeprecatedTypeProperties& t) { - if (!unique) + void show(size_t type_id) { + if (!unique) { return; - if (!type) - type = &t; - unique = (type == &t); + } + if (!type_id_) { + type_id_ = type_id; + } + + unique = type_id_.value() == type_id; } - const at::DeprecatedTypeProperties* type = nullptr; + optional type_id_; bool unique = true; }; @@ -173,10 +177,10 @@ tensor_list2d broadcast_coalesced( unique_type_checker type_checker; at::cuda::CUDAGuard device_guard(devices[0]); for (auto& chunk : utils::take_tensors(tensors, buffer_size)) { - auto& type = chunk.type(); - type_checker.show(type); + auto type_id = chunk.type_id(); + type_checker.show(type_id); std::vector results; - if (chunk.type().is_sparse()) { + if (chunk.options().is_sparse()) { auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors); auto broadcast_indices = broadcast(flat_tuple.first, devices); auto broadcast_values = broadcast(flat_tuple.second, devices); diff --git a/torch/csrc/utils/tensor_flatten.cpp b/torch/csrc/utils/tensor_flatten.cpp index 5d1ee20febf..1bbf90ac453 100644 --- a/torch/csrc/utils/tensor_flatten.cpp +++ b/torch/csrc/utils/tensor_flatten.cpp @@ -7,6 +7,7 @@ namespace torch { namespace utils { using namespace at; + std::vector take_tensors( TensorList tensors, size_t size_limit, @@ -18,7 +19,7 @@ std::vector take_tensors( size_t cur_group_size = 0; for (const auto & tensor : tensors) { - size_t tensor_size; + size_t tensor_size = 0; if (tensor.is_sparse()) { const auto& indices = tensor._indices(); const auto& values = tensor._values(); @@ -28,7 +29,7 @@ std::vector take_tensors( tensor_size = tensor.numel() * tensor.element_size(); } - auto& type_group = groups[tensor.type().id()]; + auto& type_group = groups[type_id(tensor)]; type_group.tensors.push_back(tensor); if (fine_grained) { @@ -64,17 +65,17 @@ std::vector take_tensors( void reorder_tensors_like(std::vector& tensors, TensorList order) { AT_ASSERT(tensors.size() == order.size()); - std::unordered_map> type_indices; + std::unordered_map> type_id_to_indices; for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) - type_indices[&tensors[i].type()].push_back(i); + type_id_to_indices[type_id(tensors[i])].push_back(i); - std::unordered_map type_used; + std::unordered_map type_id_to_type_used; std::vector ordered_tensors; ordered_tensors.reserve(tensors.size()); for (auto & tmpl_tensor : order) { - auto * type = &tmpl_tensor.type(); - auto & indices = type_indices[type]; - auto & used = type_used[type]; + size_t tmpl_type_id = type_id(tmpl_tensor); + auto & indices = type_id_to_indices[tmpl_type_id]; + auto & used = type_id_to_type_used[tmpl_type_id]; ordered_tensors.push_back(tensors[indices[used++]]); } std::swap(tensors, ordered_tensors); diff --git a/torch/csrc/utils/tensor_flatten.h b/torch/csrc/utils/tensor_flatten.h index cb54bbb53a7..8ee788b7fa3 100644 --- a/torch/csrc/utils/tensor_flatten.h +++ b/torch/csrc/utils/tensor_flatten.h @@ -4,9 +4,19 @@ #include #include #include +#include namespace torch { namespace utils { +/// Generate an ID for a combination of tensor backend + scalar type to be used +/// when ordering tensors ('like' tensors are grouped by pulling out their +/// backend + scalar type, so this function combines that into a single number) +inline size_t type_id(const at::Tensor& tensor) { + return static_cast(tensor.options().backend()) * + static_cast(at::ScalarType::NumOptions) + + static_cast(tensor.scalar_type()); +} + inline at::Tensor flatten_dense_tensors(at::TensorList tensors) { static auto flatten = [](const at::Tensor &t) { return t.contiguous().view({-1}); }; if (tensors.size() == 1) @@ -39,9 +49,14 @@ struct TensorGroup { std::vector tensors; size_t size = 0; - at::DeprecatedTypeProperties& type() { + size_t type_id() { AT_ASSERT(!tensors.empty()); - return tensors[0].type(); + return ::torch::utils::type_id(tensors[0]); + } + + const at::TensorOptions options() { + AT_ASSERT(!tensors.empty()); + return tensors[0].options(); } };