mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix warnings in tensor_flatten.cpp (#55956)
Summary: Switch to use `TensorOptions` instead of deprecated `.type()` to fix compiler warnings as part of #55952 ([D27830504](https://our.intern.facebook.com/intern/diff/27830504/)) Pull Request resolved: https://github.com/pytorch/pytorch/pull/55956 Pulled By: driazati Reviewed By: pritamdamania87 Differential Revision: D27830504 fbshipit-source-id: f705818ddb7d8b17c0f5383f22dc431203a194d9
This commit is contained in:
parent
3d904b56ec
commit
7fff71eb9a
3 changed files with 39 additions and 19 deletions
|
|
@@ -15,6 +15,7 @@
|
|||
#include <torch/csrc/autograd/variable.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
|
|
@@ -26,15 +27,18 @@ using namespace torch::autograd;
|
|||
// of a single type only. Adding this logic directly in the loop makes it a bit
|
||||
// ugly, so here's a helper for it.
|
||||
struct unique_type_checker {
|
||||
void show(const at::DeprecatedTypeProperties& t) {
|
||||
if (!unique)
|
||||
void show(size_t type_id) {
|
||||
if (!unique) {
|
||||
return;
|
||||
if (!type)
|
||||
type = &t;
|
||||
unique = (type == &t);
|
||||
}
|
||||
if (!type_id_) {
|
||||
type_id_ = type_id;
|
||||
}
|
||||
|
||||
unique = type_id_.value() == type_id;
|
||||
}
|
||||
|
||||
const at::DeprecatedTypeProperties* type = nullptr;
|
||||
optional<size_t> type_id_;
|
||||
bool unique = true;
|
||||
};
|
||||
|
||||
|
|
@@ -173,10 +177,10 @@ tensor_list2d broadcast_coalesced(
|
|||
unique_type_checker type_checker;
|
||||
at::cuda::CUDAGuard device_guard(devices[0]);
|
||||
for (auto& chunk : utils::take_tensors(tensors, buffer_size)) {
|
||||
auto& type = chunk.type();
|
||||
type_checker.show(type);
|
||||
auto type_id = chunk.type_id();
|
||||
type_checker.show(type_id);
|
||||
std::vector<at::Tensor> results;
|
||||
if (chunk.type().is_sparse()) {
|
||||
if (chunk.options().is_sparse()) {
|
||||
auto flat_tuple = utils::flatten_sparse_tensors(chunk.tensors);
|
||||
auto broadcast_indices = broadcast(flat_tuple.first, devices);
|
||||
auto broadcast_values = broadcast(flat_tuple.second, devices);
|
||||
|
|
|
|||
|
|
@@ -7,6 +7,7 @@ namespace torch { namespace utils {
|
|||
|
||||
using namespace at;
|
||||
|
||||
|
||||
std::vector<TensorGroup> take_tensors(
|
||||
TensorList tensors,
|
||||
size_t size_limit,
|
||||
|
|
@@ -18,7 +19,7 @@ std::vector<TensorGroup> take_tensors(
|
|||
size_t cur_group_size = 0;
|
||||
|
||||
for (const auto & tensor : tensors) {
|
||||
size_t tensor_size;
|
||||
size_t tensor_size = 0;
|
||||
if (tensor.is_sparse()) {
|
||||
const auto& indices = tensor._indices();
|
||||
const auto& values = tensor._values();
|
||||
|
|
@@ -28,7 +29,7 @@ std::vector<TensorGroup> take_tensors(
|
|||
tensor_size = tensor.numel() * tensor.element_size();
|
||||
}
|
||||
|
||||
auto& type_group = groups[tensor.type().id()];
|
||||
auto& type_group = groups[type_id(tensor)];
|
||||
type_group.tensors.push_back(tensor);
|
||||
|
||||
if (fine_grained) {
|
||||
|
|
@@ -64,17 +65,17 @@ std::vector<TensorGroup> take_tensors(
|
|||
|
||||
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
|
||||
AT_ASSERT(tensors.size() == order.size());
|
||||
std::unordered_map<at::DeprecatedTypeProperties*, std::vector<size_t>> type_indices;
|
||||
std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
|
||||
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
|
||||
type_indices[&tensors[i].type()].push_back(i);
|
||||
type_id_to_indices[type_id(tensors[i])].push_back(i);
|
||||
|
||||
std::unordered_map<at::DeprecatedTypeProperties*, size_t> type_used;
|
||||
std::unordered_map<size_t, size_t> type_id_to_type_used;
|
||||
std::vector<Tensor> ordered_tensors;
|
||||
ordered_tensors.reserve(tensors.size());
|
||||
for (auto & tmpl_tensor : order) {
|
||||
auto * type = &tmpl_tensor.type();
|
||||
auto & indices = type_indices[type];
|
||||
auto & used = type_used[type];
|
||||
size_t tmpl_type_id = type_id(tmpl_tensor);
|
||||
auto & indices = type_id_to_indices[tmpl_type_id];
|
||||
auto & used = type_id_to_type_used[tmpl_type_id];
|
||||
ordered_tensors.push_back(tensors[indices[used++]]);
|
||||
}
|
||||
std::swap(tensors, ordered_tensors);
|
||||
|
|
|
|||
|
|
@@ -4,9 +4,19 @@
|
|||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <utility>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
|
||||
namespace torch { namespace utils {
|
||||
|
||||
/// Generate an ID for a combination of tensor backend + scalar type to be used
|
||||
/// when ordering tensors ('like' tensors are grouped by pulling out their
|
||||
/// backend + scalar type, so this function combines that into a single number)
|
||||
inline size_t type_id(const at::Tensor& tensor) {
|
||||
return static_cast<size_t>(tensor.options().backend()) *
|
||||
static_cast<size_t>(at::ScalarType::NumOptions) +
|
||||
static_cast<size_t>(tensor.scalar_type());
|
||||
}
|
||||
|
||||
inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
|
||||
static auto flatten = [](const at::Tensor &t) { return t.contiguous().view({-1}); };
|
||||
if (tensors.size() == 1)
|
||||
|
|
@@ -39,9 +49,14 @@ struct TensorGroup {
|
|||
std::vector<at::Tensor> tensors;
|
||||
size_t size = 0;
|
||||
|
||||
at::DeprecatedTypeProperties& type() {
|
||||
size_t type_id() {
|
||||
AT_ASSERT(!tensors.empty());
|
||||
return tensors[0].type();
|
||||
return ::torch::utils::type_id(tensors[0]);
|
||||
}
|
||||
|
||||
const at::TensorOptions options() {
|
||||
AT_ASSERT(!tensors.empty());
|
||||
return tensors[0].options();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue