mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
This reverts commit ca3b2bfbe3.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84806
Approved by: https://github.com/Chillee
This commit is contained in:
parent
bccc26f365
commit
c5a8946e40
27 changed files with 470 additions and 288 deletions
|
|
@ -17,7 +17,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
|
|||
{
|
||||
TORCH_INTERNAL_ASSERT(value_.defined());
|
||||
set_storage_access_should_throw();
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
checkInvariants();
|
||||
|
||||
const auto public_dims = value_.dim() - bdims_.size();
|
||||
|
|
|
|||
|
|
@ -343,7 +343,7 @@ TensorBase empty_symint_meta(
|
|||
TORCH_CHECK(0, "other memory format not implemented yet");
|
||||
}
|
||||
|
||||
tensor.unsafeGetTensorImpl()->set_sym_sizes_and_strides(size, strides);
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, strides);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,6 +49,9 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
|
|||
// Instead, it's sufficient to remove the `Dense` dispatch key,
|
||||
// which prevents us from accidentally trying to directly run a CPU/CUDA kernel.
|
||||
key_set_ = key_set_.remove(c10::DispatchKey::Dense);
|
||||
// We override a bunch of _custom(), so make sure they get called
|
||||
// TODO: metadata copying may not actually be necessary then
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
|
||||
}
|
||||
|
||||
FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
|
||||
|
|
@ -343,9 +346,6 @@ int64_t FunctionalTensorWrapper::numel_custom() const {
|
|||
bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
|
||||
return value_.unsafeGetTensorImpl()->is_contiguous();
|
||||
}
|
||||
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes() const {
|
||||
return value_.unsafeGetTensorImpl()->sym_sizes();
|
||||
}
|
||||
c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
|
||||
return value_.unsafeGetTensorImpl()->sym_sizes();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -141,7 +141,6 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
|||
int64_t dim_custom() const override;
|
||||
int64_t numel_custom() const override;
|
||||
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
||||
c10::SymIntArrayRef sym_sizes() const override;
|
||||
c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ NestedTensorImpl::NestedTensorImpl(
|
|||
storage_device);
|
||||
validate_nested_tensor_metadata(nested_size_tensor_, nested_stride_tensor_, offsets_);
|
||||
refresh_dim();
|
||||
set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
|
||||
set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
|
||||
}
|
||||
|
||||
NestedTensorImpl::NestedTensorImpl(
|
||||
|
|
@ -203,7 +203,7 @@ NestedTensorImpl::NestedTensorImpl(
|
|||
TORCH_INTERNAL_ASSERT(base_tensor.is_nested());
|
||||
validate_nested_tensor_metadata(nested_size_tensor_, nested_stride_tensor_, offsets_);
|
||||
refresh_dim();
|
||||
set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
|
||||
set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
|
||||
}
|
||||
|
||||
void NestedTensorImpl::refresh_dim() {
|
||||
|
|
@ -256,9 +256,6 @@ c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
|
|||
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef NestedTensorImpl::sym_sizes() const {
|
||||
return sym_sizes_custom();
|
||||
}
|
||||
c10::SymIntArrayRef NestedTensorImpl::sym_strides_custom() const {
|
||||
TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -109,7 +109,6 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
|
|||
}
|
||||
IntArrayRef sizes_custom() const override;
|
||||
c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
c10::SymIntArrayRef sym_sizes() const override;
|
||||
IntArrayRef strides_custom() const override;
|
||||
c10::SymIntArrayRef sym_strides_custom() const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
|
|||
: TensorImpl(key_set, data_type, device),
|
||||
opaque_handle_(std::move(opaque_handle)) {
|
||||
set_storage_access_should_throw();
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
sizes_and_strides_.set_sizes(sizes);
|
||||
refresh_numel();
|
||||
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
|
|||
"to https://github.com/pytorch/pytorch/issues.");
|
||||
set_storage_access_should_throw();
|
||||
is_non_overlapping_and_dense_ = false;
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
// TODO: If this check ever shows up as a bottleneck, which is unlikely given that
|
||||
// comparing devices only involves comparing the type and index (two integers), we
|
||||
// can move this to a DEBUG only assert. Until then this confirms and maintains a
|
||||
|
|
@ -172,5 +172,8 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
|
|||
void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
|
||||
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
|
||||
}
|
||||
bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
|
||||
TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
|
||||
}
|
||||
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
|
|||
protected:
|
||||
IntArrayRef strides_custom() const override;
|
||||
SymIntArrayRef sym_strides_custom() const override;
|
||||
bool is_contiguous_custom(MemoryFormat) const override;
|
||||
|
||||
public:
|
||||
void set_size(int64_t dim, int64_t new_size) override;
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::Typ
|
|||
|
||||
is_non_overlapping_and_dense_ = false;
|
||||
set_storage_access_should_throw();
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
}
|
||||
|
||||
// Destructor doesn't call release_resources because it's
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, int64
|
|||
{
|
||||
TORCH_INTERNAL_ASSERT(value_.defined());
|
||||
set_storage_access_should_throw();
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
checkInvariants();
|
||||
refreshTensorMetadata();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -389,83 +389,93 @@ impl::PyInterpreter& TensorImpl::load_pyobj_interpreter() const {
|
|||
}
|
||||
|
||||
bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
|
||||
if (is_python_dispatch()) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
|
||||
// TODO: pass memory_format to is_contiguous call
|
||||
return load_pyobj_interpreter()->is_contiguous(this);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Tensors of type ",
|
||||
tensorimpl_type_name(),
|
||||
" do not have is_contiguous");
|
||||
return is_contiguous_default(memory_format);
|
||||
}
|
||||
|
||||
IntArrayRef TensorImpl::sizes_custom() const {
|
||||
if (is_python_dispatch()) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
return load_pyobj_interpreter()->sizes(this);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false, "Tensors of type ", tensorimpl_type_name(), " do not have sizes");
|
||||
return sizes_default();
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const {
|
||||
if (C10_UNLIKELY(is_python_dispatch())) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
return load_pyobj_interpreter()->sym_sizes(this);
|
||||
}
|
||||
return sym_sizes_default();
|
||||
}
|
||||
|
||||
c10::SymInt TensorImpl::sym_numel_custom() const {
|
||||
if (C10_UNLIKELY(is_python_dispatch())) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
return load_pyobj_interpreter()->sym_numel(this);
|
||||
}
|
||||
return sym_numel_default();
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef TensorImpl::sym_strides_custom() const {
|
||||
if (C10_UNLIKELY(is_python_dispatch())) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
|
||||
return load_pyobj_interpreter()->sym_strides(this);
|
||||
}
|
||||
return sym_strides_default();
|
||||
}
|
||||
|
||||
c10::Device TensorImpl::device_custom() const {
|
||||
if (is_python_dispatch()) {
|
||||
if (C10_UNLIKELY(python_custom_device_)) {
|
||||
return load_pyobj_interpreter()->device(this);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false, "Tensors of type ", tensorimpl_type_name(), " do not have device");
|
||||
return device_default();
|
||||
}
|
||||
|
||||
IntArrayRef TensorImpl::strides_custom() const {
|
||||
if (is_python_dispatch()) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
|
||||
return load_pyobj_interpreter()->strides(this);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Tensors of type ",
|
||||
tensorimpl_type_name(),
|
||||
" do not have strides");
|
||||
return strides_default();
|
||||
}
|
||||
|
||||
int64_t TensorImpl::dim_custom() const {
|
||||
if (is_python_dispatch()) {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
return load_pyobj_interpreter()->dim(this);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false, "Tensors of type ", tensorimpl_type_name(), " do not have dim");
|
||||
return dim_default();
|
||||
}
|
||||
|
||||
int64_t TensorImpl::numel_custom() const {
|
||||
TORCH_CHECK(
|
||||
false, "Tensors of type ", tensorimpl_type_name(), " do not have numel");
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
// TODO: fix this
|
||||
return load_pyobj_interpreter()->sym_numel(this).expect_int();
|
||||
}
|
||||
return numel_default();
|
||||
}
|
||||
|
||||
c10::Layout TensorImpl::layout_custom() const {
|
||||
if (is_python_dispatch()) {
|
||||
if (C10_UNLIKELY(python_custom_layout_)) {
|
||||
return load_pyobj_interpreter()->layout(this);
|
||||
}
|
||||
// TODO: fix this
|
||||
TORCH_CHECK(
|
||||
false, "Tensors of type ", tensorimpl_type_name(), " do not have layout");
|
||||
0, "Tensors of type ", tensorimpl_type_name(), " do not have layout")
|
||||
// return layout_default();
|
||||
}
|
||||
|
||||
int64_t TensorImpl::storage_offset_custom() const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
// TODO: fix this
|
||||
return load_pyobj_interpreter()->sym_storage_offset(this).expect_int();
|
||||
}
|
||||
return storage_offset_default();
|
||||
}
|
||||
|
||||
c10::SymInt TensorImpl::sym_storage_offset_custom() const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
return load_pyobj_interpreter()->sym_storage_offset(this);
|
||||
}
|
||||
return sym_storage_offset_default();
|
||||
}
|
||||
|
||||
static void deletePlacementDeleteContext(void* ptr) {
|
||||
|
|
@ -623,7 +633,15 @@ void TensorImpl::copy_generic_tensor_metadata(
|
|||
if (src_impl->extra_meta_ != nullptr) {
|
||||
dest_impl->extra_meta_ = src_impl->extra_meta_->clone();
|
||||
}
|
||||
dest_impl->sizes_strides_policy_ = src_impl->sizes_strides_policy_;
|
||||
|
||||
// NB: symbolic sizes and strides are copied, but custom policy is
|
||||
// NOT (you have no Python object to dispatch to!)
|
||||
// NB: subclass relevant policy doesn't have to be copied; the
|
||||
// constructor sets this up
|
||||
|
||||
dest_impl->refresh_sizes_strides_policy();
|
||||
dest_impl->refresh_layout_policy();
|
||||
dest_impl->refresh_device_policy();
|
||||
}
|
||||
|
||||
void TensorImpl::copy_tensor_metadata_except_version_counter(
|
||||
|
|
@ -867,22 +885,37 @@ void TensorImpl::ShareExternalPointer(
|
|||
}
|
||||
}
|
||||
|
||||
void TensorImpl::set_sym_sizes_and_strides(
|
||||
void TensorImpl::set_sizes_and_strides(
|
||||
c10::SymIntArrayRef sizes,
|
||||
c10::SymIntArrayRef strides) {
|
||||
c10::SymIntArrayRef strides,
|
||||
c10::optional<c10::SymInt> storage_offset) {
|
||||
auto int_sizes = asIntArrayRefSlowOpt(sizes);
|
||||
auto int_strides = asIntArrayRefSlowOpt(strides);
|
||||
if (int_sizes && int_strides &&
|
||||
(!storage_offset.has_value() || !storage_offset->is_symbolic()) &&
|
||||
!has_symbolic_sizes_strides_) {
|
||||
set_sizes_and_strides(*int_sizes, *int_strides);
|
||||
if (storage_offset.has_value())
|
||||
set_storage_offset(storage_offset->as_int_unchecked());
|
||||
return;
|
||||
}
|
||||
|
||||
has_symbolic_sizes_strides_ = true;
|
||||
sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
|
||||
refresh_sizes_strides_policy();
|
||||
if (!extra_meta_) {
|
||||
extra_meta_ = std::make_unique<ExtraMeta>();
|
||||
if (!storage_offset.has_value())
|
||||
extra_meta_->storage_offset_ = storage_offset_;
|
||||
}
|
||||
extra_meta_->sizes_ = sizes;
|
||||
extra_meta_->strides_ = strides;
|
||||
if (storage_offset.has_value())
|
||||
extra_meta_->storage_offset_ = std::move(*storage_offset);
|
||||
SymInt numel = 1;
|
||||
for (const auto& s : sizes) {
|
||||
numel *= s;
|
||||
}
|
||||
extra_meta_->numel_ = numel;
|
||||
// TODO: refresh the other entries
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ struct C10_API ExtraMeta {
|
|||
SymDimVector sizes_ = {0};
|
||||
SymDimVector strides_ = {1};
|
||||
SymInt numel_ = 1;
|
||||
SymInt storage_offset_ = 0; // TODO
|
||||
SymInt storage_offset_ = 0;
|
||||
// TODO:
|
||||
// SymBool is_contiguous_;
|
||||
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
|
||||
|
|
@ -573,41 +573,88 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
return key_set_;
|
||||
}
|
||||
|
||||
// NOTE: The general recipe for customizable methods is that the fastpath
|
||||
// function (e.g., sizes()) does an unlikely policy test, and if doesn't
|
||||
// trigger, it does the fast path implementation with no checks and going
|
||||
// directly to on-TensorImpl fields. In particular, you never need to
|
||||
// check ExtraMeta if the policy doesn't trigger, as non-trivial ExtraMeta
|
||||
// implies the policy will always match.
|
||||
//
|
||||
// The default implementations of methods are "safe": they do extra tests
|
||||
// to make sure the internal state is consistent no matter if you are
|
||||
// doing symbolic shapes or not. If you don't want the tests, directly
|
||||
// override the custom method (e.g., custom_sizes()) to do your preferred
|
||||
// behavior.
|
||||
|
||||
public:
|
||||
/**
|
||||
* Return a reference to the sizes of this tensor. This reference remains
|
||||
* valid as long as the tensor is live and not resized.
|
||||
*/
|
||||
IntArrayRef sizes() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return sizes_custom();
|
||||
}
|
||||
return sizes_default();
|
||||
return sizes_and_strides_.sizes_arrayref();
|
||||
}
|
||||
|
||||
// TODO: make it non-virtual after a change to XLA
|
||||
virtual c10::SymIntArrayRef sym_sizes() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
SymIntArrayRef sym_sizes() const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return sym_sizes_custom();
|
||||
}
|
||||
return sym_sizes_default();
|
||||
// Sizes guaranteed to be non-negative, so unchecked cast is OK
|
||||
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
|
||||
sizes_and_strides_.sizes_arrayref());
|
||||
}
|
||||
|
||||
virtual c10::SymIntArrayRef sym_sizes_custom() const;
|
||||
IntArrayRef sizes_default() const {
|
||||
// TODO: force backtrace to be printed on this error
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"Cannot call sizes() on tensor with symbolic sizes/strides");
|
||||
return sizes_and_strides_.sizes_arrayref();
|
||||
}
|
||||
|
||||
SymIntArrayRef sym_sizes_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->sizes_;
|
||||
} else {
|
||||
// Sizes guaranteed to be non-negative, so unchecked cast is OK
|
||||
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
|
||||
sizes_default());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of elements in a tensor.
|
||||
*
|
||||
* WARNING: Previously, if you were using the Caffe2 API, you could
|
||||
* test numel() == -1 to see if a tensor was uninitialized. This
|
||||
* is no longer true; numel always accurately reports the product
|
||||
* of sizes of a tensor.
|
||||
*/
|
||||
int64_t numel() const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return numel_custom();
|
||||
}
|
||||
return numel_;
|
||||
}
|
||||
|
||||
c10::SymInt sym_numel() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return sym_numel_custom();
|
||||
}
|
||||
return sym_numel_default();
|
||||
return c10::SymInt(SymInt::UNCHECKED, numel_);
|
||||
}
|
||||
|
||||
inline c10::SymInt sym_numel_default() const {
|
||||
int64_t numel_default() const {
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"Cannot call numel() on tensor with symbolic sizes/strides");
|
||||
return numel_;
|
||||
}
|
||||
|
||||
c10::SymInt sym_numel_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->numel_;
|
||||
} else {
|
||||
|
|
@ -615,31 +662,89 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
}
|
||||
|
||||
virtual c10::SymInt sym_numel_custom() const;
|
||||
/**
|
||||
* Return the number of dimensions of this tensor. Note that 0-dimension
|
||||
* represents a Tensor that is a Scalar, e.g., one that has a single element.
|
||||
*/
|
||||
int64_t dim() const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return dim_custom();
|
||||
}
|
||||
return sizes_and_strides_.size();
|
||||
}
|
||||
|
||||
int64_t dim_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->sizes_.size();
|
||||
} else {
|
||||
return sizes_and_strides_.size();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the offset in number of elements into the storage that this
|
||||
* tensor points to. Most tensors have storage_offset() == 0, but,
|
||||
* for example, an index into a tensor will have a non-zero storage_offset().
|
||||
*
|
||||
* WARNING: This is NOT computed in bytes.
|
||||
*/
|
||||
int64_t storage_offset() const {
|
||||
// TODO: maybe this should be toggled by strides
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return storage_offset_custom();
|
||||
}
|
||||
return storage_offset_;
|
||||
}
|
||||
|
||||
c10::SymInt sym_storage_offset() const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return sym_storage_offset_custom();
|
||||
}
|
||||
return c10::SymInt(SymInt::UNCHECKED, storage_offset_);
|
||||
}
|
||||
|
||||
int64_t storage_offset_default() const {
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"Cannot call storage_offset() on tensor with symbolic sizes/strides");
|
||||
return storage_offset_;
|
||||
}
|
||||
|
||||
c10::SymInt sym_storage_offset_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->storage_offset_;
|
||||
} else {
|
||||
return c10::SymInt(SymInt::UNCHECKED, storage_offset_);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a reference to the strides of this tensor. This reference remains
|
||||
* valid as long as the tensor is live and not restrided.
|
||||
*/
|
||||
IntArrayRef strides() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomStrides))) {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
return strides_custom();
|
||||
}
|
||||
return strides_default();
|
||||
return sizes_and_strides_.strides_arrayref();
|
||||
}
|
||||
|
||||
// TODO: make it non-virtual after a change to XLA
|
||||
virtual c10::SymIntArrayRef sym_strides() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomStrides))) {
|
||||
c10::SymIntArrayRef sym_strides() const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
return sym_strides_custom();
|
||||
}
|
||||
return sym_strides_default();
|
||||
// strides guaranteed to be non-negative, so unchecked cast is OK
|
||||
return c10::SymIntArrayRef::fromIntArrayRefUnchecked(strides_default());
|
||||
}
|
||||
inline c10::SymIntArrayRef sym_strides_default() const {
|
||||
|
||||
IntArrayRef strides_default() const {
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"Cannot call strides() on tensor with symbolic sizes/strides");
|
||||
return sizes_and_strides_.strides_arrayref();
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef sym_strides_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->strides_;
|
||||
} else {
|
||||
|
|
@ -648,8 +753,36 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
}
|
||||
|
||||
virtual c10::SymIntArrayRef sym_strides_custom() const;
|
||||
/**
|
||||
* Whether or not a tensor is laid out in contiguous memory.
|
||||
*
|
||||
* Tensors with non-trivial strides are not contiguous. See
|
||||
* compute_contiguous() for the exact definition of whether or not
|
||||
* a tensor is contiguous or not.
|
||||
*/
|
||||
bool is_contiguous(
|
||||
at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
return is_contiguous_custom(memory_format);
|
||||
}
|
||||
return is_contiguous_default(memory_format);
|
||||
}
|
||||
|
||||
// These are factored into separate functions in case subclasses
|
||||
// want to use them
|
||||
bool is_contiguous_default(at::MemoryFormat memory_format) const {
|
||||
// TODO: handle symbolic shapes correctly
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(compute_contiguous() == is_contiguous_);
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
return is_channels_last_contiguous_;
|
||||
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
|
||||
return is_channels_last_3d_contiguous_;
|
||||
}
|
||||
return is_contiguous_;
|
||||
}
|
||||
|
||||
// NB: these dim accessor functions don't have _default(), as you can use
|
||||
// sizes_default/strides_default
|
||||
/**
|
||||
* Return the size of a tensor at some dimension, wrapping the dimension if
|
||||
* necessary.
|
||||
|
|
@ -658,9 +791,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
* be faster
|
||||
*/
|
||||
int64_t size(int64_t d) const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return size_custom(d);
|
||||
}
|
||||
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
|
||||
|
|
@ -668,9 +799,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
c10::SymInt sym_size(int64_t d) const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
|
||||
return sym_size_custom(d);
|
||||
}
|
||||
d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
|
||||
|
|
@ -687,79 +816,49 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
*/
|
||||
int64_t stride(int64_t d) const {
|
||||
d = maybe_wrap_dim(d, dim(), false);
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomStrides))) {
|
||||
if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
|
||||
// TODO: provide stride_custom, symmetrically with size_custom.
|
||||
// There is presently no user for it; only NestedTensor is using
|
||||
// size_custom overrideability
|
||||
return strides_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
|
||||
}
|
||||
// Intentionally don't call default, which also handles symbolic
|
||||
return sizes_and_strides_.stride_at_unchecked(d);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the number of dimensions of this tensor. Note that 0-dimension
|
||||
* represents a Tensor that is a Scalar, e.g., one that has a single element.
|
||||
*/
|
||||
int64_t dim() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
return dim_custom();
|
||||
}
|
||||
return dim_default();
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of elements in a tensor.
|
||||
*
|
||||
* WARNING: Previously, if you were using the Caffe2 API, you could
|
||||
* test numel() == -1 to see if a tensor was uninitialized. This
|
||||
* is no longer true; numel always accurately reports the product
|
||||
* of sizes of a tensor.
|
||||
*/
|
||||
int64_t numel() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
return numel_custom();
|
||||
}
|
||||
return numel_default();
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether or not a tensor is laid out in contiguous memory.
|
||||
*
|
||||
* Tensors with non-trivial strides are not contiguous. See
|
||||
* compute_contiguous() for the exact definition of whether or not
|
||||
* a tensor is contiguous or not.
|
||||
*/
|
||||
bool is_contiguous(
|
||||
at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomStrides))) {
|
||||
return is_contiguous_custom(memory_format);
|
||||
}
|
||||
return is_contiguous_default(memory_format);
|
||||
}
|
||||
|
||||
inline IntArrayRef strides_default() const {
|
||||
return sizes_and_strides_.strides_arrayref();
|
||||
}
|
||||
|
||||
inline IntArrayRef sizes_default() const {
|
||||
return sizes_and_strides_.sizes_arrayref();
|
||||
}
|
||||
|
||||
inline c10::SymIntArrayRef sym_sizes_default() const {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
return extra_meta_->sizes_;
|
||||
} else {
|
||||
return c10::SymIntArrayRef::fromIntArrayRefKnownNonNegative(
|
||||
sizes_default());
|
||||
}
|
||||
}
|
||||
enum class SizesStridesPolicy : uint8_t {
|
||||
// Default behavior, e.g., dense tensor.
|
||||
//
|
||||
// Can override: nothing
|
||||
Default = 0,
|
||||
// Customizable strides behavior, e.g., sparse tensor,
|
||||
// mkldnn tensor.
|
||||
//
|
||||
// Can override: strides(), is_contiguous()
|
||||
CustomStrides = 1,
|
||||
// Customizable sizes behavior, e.g., nested tensor
|
||||
//
|
||||
// Can override: strides(), is_contiguous(), sizes(), dim(), numel()
|
||||
CustomSizes = 2
|
||||
};
|
||||
|
||||
protected:
|
||||
inline bool matches_policy(SizesStridesPolicy policy) const {
|
||||
return sizes_strides_policy_ >= static_cast<uint8_t>(policy);
|
||||
}
|
||||
|
||||
inline bool matches_custom(SizesStridesPolicy policy) const {
|
||||
return custom_sizes_strides_ >= static_cast<uint8_t>(policy);
|
||||
}
|
||||
|
||||
inline bool matches_python_custom(SizesStridesPolicy policy) const {
|
||||
auto r = python_custom_sizes_strides_ >= static_cast<uint8_t>(policy);
|
||||
if (r) {
|
||||
TORCH_INTERNAL_ASSERT(is_python_dispatch())
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
/**
|
||||
* Customization points for the functions above. sizes_strides_policy_
|
||||
* must be set to enable these.
|
||||
|
|
@ -768,7 +867,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
* for a tensor to have rank, but not well defined sizes.
|
||||
*/
|
||||
// sizes_strides_policy_ >= CustomStrides
|
||||
virtual IntArrayRef strides_custom() const;
|
||||
virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
|
||||
// sizes_strides_policy_ >= CustomSizes
|
||||
// Currently this method only exists to be overwritten by subclasses such as
|
||||
|
|
@ -790,38 +888,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
virtual IntArrayRef sizes_custom() const;
|
||||
virtual IntArrayRef strides_custom() const;
|
||||
virtual int64_t numel_custom() const;
|
||||
virtual int64_t storage_offset_custom() const;
|
||||
virtual int64_t dim_custom() const;
|
||||
virtual Device device_custom() const;
|
||||
virtual Layout layout_custom() const;
|
||||
|
||||
virtual int64_t dim_custom() const;
|
||||
virtual int64_t numel_custom() const;
|
||||
|
||||
// These are factored into separate functions in case subclasses
|
||||
// want to use them
|
||||
inline bool is_contiguous_default(at::MemoryFormat memory_format) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(compute_contiguous() == is_contiguous_);
|
||||
if (memory_format == at::MemoryFormat::ChannelsLast) {
|
||||
return is_channels_last_contiguous_;
|
||||
} else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
|
||||
return is_channels_last_3d_contiguous_;
|
||||
}
|
||||
return is_contiguous_;
|
||||
}
|
||||
inline int64_t dim_default() const {
|
||||
return sizes_and_strides_.size();
|
||||
}
|
||||
inline c10::Device device_default() const {
|
||||
TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device");
|
||||
// See NOTE [c10::optional operator usage in CUDA]
|
||||
return *device_opt_;
|
||||
}
|
||||
|
||||
inline int64_t numel_default() const {
|
||||
#ifdef DEBUG
|
||||
TORCH_INTERNAL_ASSERT(compute_numel() == numel_);
|
||||
#endif
|
||||
return numel_;
|
||||
}
|
||||
virtual c10::SymIntArrayRef sym_sizes_custom() const;
|
||||
virtual c10::SymIntArrayRef sym_strides_custom() const;
|
||||
virtual c10::SymInt sym_numel_custom() const;
|
||||
virtual c10::SymInt sym_storage_offset_custom() const;
|
||||
|
||||
public:
|
||||
/**
|
||||
|
|
@ -906,7 +983,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
bool is_meta() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance
|
||||
// reasons.
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_meta();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kMeta;
|
||||
|
|
@ -915,7 +992,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
bool is_cpu() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance
|
||||
// reasons.
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_cpu();
|
||||
}
|
||||
// Note: we cannot rely on dispatch keys to determine the device type
|
||||
|
|
@ -927,7 +1004,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
bool is_cuda() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance
|
||||
// reasons.
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_cuda();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kCUDA;
|
||||
|
|
@ -936,35 +1013,35 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
bool is_xpu() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance
|
||||
// reasons.
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_xpu();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kXPU;
|
||||
}
|
||||
|
||||
bool is_ipu() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_ipu();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kIPU;
|
||||
}
|
||||
|
||||
bool is_xla() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_xla();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kXLA;
|
||||
}
|
||||
|
||||
bool is_hpu() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_hpu();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kHPU;
|
||||
}
|
||||
|
||||
bool is_lazy() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_lazy();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kLazy;
|
||||
|
|
@ -973,7 +1050,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
bool is_hip() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance
|
||||
// reasons.
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_hip();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kHIP;
|
||||
|
|
@ -982,7 +1059,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
bool is_ve() const {
|
||||
// NB: This method is not virtual and avoid dispatches for performance
|
||||
// reasons.
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_ve();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kVE;
|
||||
|
|
@ -993,28 +1070,28 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
bool is_vulkan() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_vulkan();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kVulkan;
|
||||
}
|
||||
|
||||
bool is_metal() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_metal();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kMetal;
|
||||
}
|
||||
|
||||
bool is_mps() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_mps();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kMPS;
|
||||
}
|
||||
|
||||
bool is_ort() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().is_ort();
|
||||
}
|
||||
return device_opt_.has_value() && device_opt_->type() == kORT;
|
||||
|
|
@ -1046,21 +1123,29 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
int64_t get_device() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom().index();
|
||||
}
|
||||
return device_default().index();
|
||||
}
|
||||
|
||||
Device device() const {
|
||||
if (C10_UNLIKELY(custom_device_)) {
|
||||
if (C10_UNLIKELY(device_policy_)) {
|
||||
return device_custom();
|
||||
}
|
||||
return device_default();
|
||||
}
|
||||
|
||||
protected:
|
||||
c10::Device device_default() const {
|
||||
TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device");
|
||||
// See NOTE [c10::optional operator usage in CUDA]
|
||||
return *device_opt_;
|
||||
}
|
||||
|
||||
public:
|
||||
Layout layout() const {
|
||||
if (C10_UNLIKELY(custom_layout_)) {
|
||||
if (C10_UNLIKELY(layout_policy_)) {
|
||||
return layout_custom();
|
||||
}
|
||||
|
||||
|
|
@ -1385,17 +1470,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
return data_type_.itemsize();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the offset in number of elements into the storage that this
|
||||
* tensor points to. Most tensors have storage_offset() == 0, but,
|
||||
* for example, an index into a tensor will have a non-zero storage_offset().
|
||||
*
|
||||
* WARNING: This is NOT computed in bytes.
|
||||
*/
|
||||
TENSORIMPL_MAYBE_VIRTUAL int64_t storage_offset() const {
|
||||
return storage_offset_;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* Returns the human-readable name of the actual type of this object (e.g.,
|
||||
|
|
@ -1416,11 +1490,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
return numel() == 0;
|
||||
}
|
||||
|
||||
// if we are going to use sym sizes, we should be setting sym strides at the
|
||||
// same time, otherwise it's very easy to misuse this API
|
||||
void set_sym_sizes_and_strides(
|
||||
void set_sizes_and_strides(
|
||||
c10::SymIntArrayRef sizes,
|
||||
c10::SymIntArrayRef strides);
|
||||
c10::SymIntArrayRef strides,
|
||||
c10::optional<c10::SymInt> storage_offset = c10::nullopt);
|
||||
|
||||
/**
|
||||
* Change the size at some dimension. This DOES NOT update strides;
|
||||
|
|
@ -1436,8 +1509,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
"set_size ",
|
||||
err_msg_tensor_metadata_change_not_allowed);
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"set_size() called on tensor with symbolic shape")
|
||||
!matches_policy(SizesStridesPolicy::CustomSizes),
|
||||
"set_size() called on tensor with dynamic shapes or customized size behavior")
|
||||
sizes_and_strides_.size_at(dim) = new_size;
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
|
|
@ -1473,6 +1546,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
allow_tensor_metadata_change(),
|
||||
"set_storage_offset ",
|
||||
err_msg_tensor_metadata_change_not_allowed);
|
||||
// TODO: this should probably consult policy
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"set_storage_offset() called on tensor with symbolic shape")
|
||||
storage_offset_ = storage_offset;
|
||||
}
|
||||
|
||||
|
|
@ -1488,15 +1565,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
allow_tensor_metadata_change(),
|
||||
"set_sizes_contiguous ",
|
||||
err_msg_tensor_metadata_change_not_allowed);
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomStrides))) {
|
||||
TORCH_CHECK(false, "todo, I guess we want to throw here");
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!has_symbolic_sizes_strides_,
|
||||
"set_sizes_contiguous() called on tensor with symbolic shape")
|
||||
!matches_policy(SizesStridesPolicy::CustomStrides),
|
||||
"tried to directly modify sizes for customized tensor");
|
||||
sizes_and_strides_.set_sizes(new_size);
|
||||
|
||||
refresh_numel();
|
||||
|
|
@ -1510,7 +1581,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
* sizes/strides are in bounds for the storage that is allocated;
|
||||
* this is the responsibility of the caller
|
||||
*/
|
||||
void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride) {
|
||||
void set_sizes_and_strides(
|
||||
IntArrayRef new_size,
|
||||
IntArrayRef new_stride,
|
||||
c10::optional<int64_t> storage_offset = c10::nullopt) {
|
||||
TORCH_CHECK(
|
||||
allow_tensor_metadata_change(),
|
||||
"set_sizes_and_strides ",
|
||||
|
|
@ -1554,6 +1628,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
|
||||
refresh_numel();
|
||||
refresh_contiguous();
|
||||
|
||||
if (storage_offset.has_value()) {
|
||||
storage_offset_ = *storage_offset;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -2438,32 +2516,53 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
|
||||
public:
|
||||
enum class SizesStridesPolicy : uint8_t {
|
||||
// Default behavior, e.g., dense tensor.
|
||||
//
|
||||
// Can override: nothing
|
||||
Default = 0,
|
||||
// Customizable strides behavior, e.g., sparse tensor,
|
||||
// mkldnn tensor.
|
||||
//
|
||||
// Can override: strides(), is_contiguous()
|
||||
CustomStrides = 1,
|
||||
// Customizable sizes behavior, e.g., nested tensor
|
||||
//
|
||||
// Can override: strides(), is_contiguous(), sizes(), dim(), numel()
|
||||
CustomSizes = 2
|
||||
};
|
||||
void set_custom_sizes_strides(SizesStridesPolicy policy) {
|
||||
custom_sizes_strides_ = static_cast<uint8_t>(policy);
|
||||
refresh_sizes_strides_policy();
|
||||
}
|
||||
|
||||
void set_sizes_strides_policy(SizesStridesPolicy policy) {
|
||||
sizes_strides_policy_ = static_cast<uint8_t>(policy);
|
||||
void set_python_custom_sizes_strides(SizesStridesPolicy policy) {
|
||||
python_custom_sizes_strides_ = static_cast<uint8_t>(policy);
|
||||
refresh_sizes_strides_policy();
|
||||
}
|
||||
|
||||
void set_custom_device(bool custom_device) {
|
||||
custom_device_ = custom_device;
|
||||
refresh_device_policy();
|
||||
}
|
||||
|
||||
void set_custom_layout(bool custom_layout) {
|
||||
custom_layout_ = custom_layout;
|
||||
refresh_layout_policy();
|
||||
}
|
||||
|
||||
void set_python_custom_device(bool custom_device) {
|
||||
python_custom_device_ = custom_device;
|
||||
refresh_device_policy();
|
||||
}
|
||||
|
||||
void set_python_custom_layout(bool custom_layout) {
|
||||
python_custom_layout_ = custom_layout;
|
||||
refresh_layout_policy();
|
||||
}
|
||||
|
||||
protected:
|
||||
void refresh_sizes_strides_policy() {
|
||||
if (has_symbolic_sizes_strides_) {
|
||||
sizes_strides_policy_ =
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
|
||||
} else {
|
||||
sizes_strides_policy_ =
|
||||
std::max(custom_sizes_strides_, python_custom_sizes_strides_);
|
||||
}
|
||||
}
|
||||
|
||||
void refresh_device_policy() {
|
||||
device_policy_ = custom_device_ || python_custom_device_;
|
||||
}
|
||||
|
||||
void refresh_layout_policy() {
|
||||
layout_policy_ = custom_layout_ || python_custom_layout_;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
@ -2584,8 +2683,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
allow_tensor_metadata_change_ = true;
|
||||
reserved_ = false;
|
||||
sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::Default);
|
||||
custom_sizes_strides_ = static_cast<uint8_t>(SizesStridesPolicy::Default);
|
||||
python_custom_sizes_strides_ =
|
||||
static_cast<uint8_t>(SizesStridesPolicy::Default);
|
||||
python_custom_device_ = false;
|
||||
python_custom_layout_ = false;
|
||||
custom_device_ = false;
|
||||
custom_layout_ = false;
|
||||
device_policy_ = false;
|
||||
layout_policy_ = false;
|
||||
storage_access_should_throw_ = false;
|
||||
has_symbolic_sizes_strides_ = false;
|
||||
}
|
||||
|
|
@ -2648,17 +2754,37 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
|
||||
// Call _custom() virtual methods for
|
||||
// strides()/is_contiguous()/sizes()/dim()/numel()
|
||||
// This is a combination of sizes_strides_custom_dispatch_
|
||||
// and has_symbolic_sizes_strides_
|
||||
uint8_t sizes_strides_policy_ : 2;
|
||||
|
||||
// Whether or not sizes_and_strides_ contains a symbolic value.
|
||||
bool has_symbolic_sizes_strides_ : 1;
|
||||
|
||||
// Call _custom() virtual method for
|
||||
// strides()/is_contiguous()/sizes()/dim()/numel()
|
||||
uint8_t custom_sizes_strides_ : 2;
|
||||
|
||||
// Combo of custom_ and python_custom_
|
||||
bool device_policy_ : 1;
|
||||
bool layout_policy_ : 1;
|
||||
|
||||
// Call _custom() virtual method for device()
|
||||
bool custom_device_ : 1;
|
||||
|
||||
// Call _custom() virtual method for layout()
|
||||
bool custom_layout_ : 1;
|
||||
|
||||
// Call into Python for
|
||||
// strides()/is_contiguous()/sizes()/dim()/numel()
|
||||
uint8_t python_custom_sizes_strides_ : 2;
|
||||
|
||||
// Call into Python for device()
|
||||
bool python_custom_device_ : 1;
|
||||
|
||||
// Call into Python for layout()
|
||||
bool python_custom_layout_ : 1;
|
||||
|
||||
// The set of DispatchKeys which describe this tensor. NB: this
|
||||
// does NOT include Autograd (historically, it did, but
|
||||
// not anymore!)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,18 @@ UndefinedTensorImpl::UndefinedTensorImpl()
|
|||
set_storage_access_should_throw();
|
||||
// TODO: accessing the sizes on an undefined tensor is not meaningful
|
||||
// and should error too, but empirically it does not!
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
}
|
||||
|
||||
bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const {
|
||||
return is_contiguous_default(format);
|
||||
}
|
||||
IntArrayRef UndefinedTensorImpl::strides_custom() const {
|
||||
TORCH_CHECK(false, "strides() called on an undefined Tensor");
|
||||
}
|
||||
SymIntArrayRef UndefinedTensorImpl::sym_strides_custom() const {
|
||||
TORCH_CHECK(false, "sym_strides() called on an undefined Tensor");
|
||||
}
|
||||
|
||||
#ifdef DEBUG
|
||||
bool UndefinedTensorImpl::has_storage() const {
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
|
|||
|
||||
protected:
|
||||
bool is_contiguous_custom(MemoryFormat format) const override;
|
||||
IntArrayRef strides_custom() const override;
|
||||
SymIntArrayRef sym_strides_custom() const override;
|
||||
|
||||
private:
|
||||
UndefinedTensorImpl();
|
||||
|
|
|
|||
|
|
@ -54,6 +54,9 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
|||
c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override {
|
||||
PANIC(sym_strides);
|
||||
}
|
||||
c10::SymInt sym_storage_offset(const TensorImpl* self) const override {
|
||||
PANIC(sym_storage_offset);
|
||||
}
|
||||
|
||||
// Just swallow the event, don't do anything
|
||||
void trace_gpu_event_creation(uintptr_t event) const override {}
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ struct C10_API PyInterpreterVTable {
|
|||
virtual c10::Layout layout(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;
|
||||
|
||||
virtual void trace_gpu_event_creation(uintptr_t event) const = 0;
|
||||
virtual void trace_gpu_event_deletion(uintptr_t event) const = 0;
|
||||
|
|
|
|||
|
|
@ -288,7 +288,6 @@ def is_inplace(op, variant):
|
|||
vjp_fail = {
|
||||
xfail('tensor_split'), # data_ptr composite compliance
|
||||
xfail('nn.functional.ctc_loss'), # data_ptr composite compliance
|
||||
xfail('to_sparse'),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -299,6 +298,7 @@ class TestOperators(TestCase):
|
|||
xfail('chalf', '', device_type='cpu'), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
|
||||
skip('as_strided_scatter', ''), # silent incorrectness; seems flaky
|
||||
xfail('sparse.sampled_addmm', ''), # RuntimeError: Sparse CSR tensors do not have strides
|
||||
xfail('to_sparse', ''), # Could not run 'aten::sum.dim_IntList'
|
||||
}))
|
||||
@opsToleranceOverride('TestOperators', 'test_grad', (
|
||||
tol1('nn.functional.binary_cross_entropy_with_logits',
|
||||
|
|
@ -602,6 +602,8 @@ class TestOperators(TestCase):
|
|||
# got a batched tensor as input while the running_mean or running_var,
|
||||
# which will be updated in place, were not batched.
|
||||
xfail("nn.functional.batch_norm", 'without_cudnn'),
|
||||
# view doesn't work on sparse
|
||||
xfail("to_sparse"),
|
||||
}))
|
||||
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
|
||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
||||
|
|
@ -676,6 +678,7 @@ class TestOperators(TestCase):
|
|||
xfail('take'), # dynamic
|
||||
xfail('pca_lowrank', ''), # randomness
|
||||
xfail('svd_lowrank', ''), # randomness
|
||||
xfail('to_sparse', ''), # non-dense output
|
||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
|
@ -1032,6 +1035,7 @@ class TestOperators(TestCase):
|
|||
skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
|
||||
skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
|
||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
skip('to_sparse', ''), # non-dense output
|
||||
|
||||
# fallback path doesn't work
|
||||
# All of the following are bugs and need to be fixed
|
||||
|
|
@ -1126,6 +1130,7 @@ class TestOperators(TestCase):
|
|||
|
||||
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
|
||||
@skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
|
||||
xfail('to_sparse', ''), # NYI
|
||||
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
|
||||
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
|
||||
xfail('normal', ''),
|
||||
|
|
|
|||
|
|
@ -89,9 +89,6 @@ class FakeSymbolicTensor(torch.Tensor):
|
|||
dtype=dtype, layout=layout, requires_grad=requires_grad,
|
||||
device=device,
|
||||
)
|
||||
|
||||
r.sym_shape = sym_shape
|
||||
r.sym_stride = sym_stride
|
||||
return r
|
||||
|
||||
__torch_function__ = _disabled_torch_function_impl
|
||||
|
|
@ -104,22 +101,6 @@ class FakeSymbolicTensor(torch.Tensor):
|
|||
if func_overload in meta_funcs:
|
||||
return meta_funcs[func_overload](*args, **kwargs)
|
||||
|
||||
if func_overload == torch.ops.aten.sym_size.default:
|
||||
self = args[0]
|
||||
return self.sym_shape
|
||||
|
||||
if func_overload == torch.ops.aten.sym_stride.default:
|
||||
self = args[0]
|
||||
return self.sym_stride
|
||||
|
||||
# some calls can be redirected to `sym_size` rather than
|
||||
# `sym_sizes`. `sym_size` uses `dim` to canonicalize an index
|
||||
# so we need to implement both `sym_size` and `dim` for python
|
||||
# tensors
|
||||
if func_overload == torch.ops.aten.dim.default:
|
||||
self = args[0]
|
||||
return len(self.sym_shape)
|
||||
|
||||
if func_overload == torch.ops.aten.new_empty.default:
|
||||
self = args[0]
|
||||
shape = args[1]
|
||||
|
|
|
|||
|
|
@ -602,6 +602,17 @@ class FakeTensorOperatorInvariants(TestCase):
|
|||
has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
|
||||
)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_sparse_new(self):
|
||||
with FakeTensorMode():
|
||||
indices = torch.randn(1, 1, dtype=torch.int64)
|
||||
values = torch.randn(1)
|
||||
extra = (2,)
|
||||
sparse = torch.randn(1).to_sparse()
|
||||
# This used to segfault, now it does not, but it still raises an
|
||||
# error
|
||||
sparse2 = sparse.new(indices, values, extra)
|
||||
|
||||
def test_like_ops(self):
|
||||
for schema in self.get_all_aten_schemas():
|
||||
if "_like" == schema.name[-5:]:
|
||||
|
|
|
|||
|
|
@ -1104,9 +1104,6 @@ symbolic_tensor_failures = {
|
|||
xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
|
||||
xfail('nn.functional.conv1d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('nn.functional.conv2d', ''), # aten.convolution.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('nn.functional.conv_transpose1d', ''), # aten.convolution.default - couldn't find symbolic meta function/decompo...
|
||||
xfail('nn.functional.conv_transpose2d', ''), # aten.convolution.default - couldn't find symbolic meta function/decompo...
|
||||
xfail('nn.functional.conv_transpose3d', ''), # aten.convolution.default - couldn't find symbolic meta function/decompo...
|
||||
xfail('nn.functional.cosine_embedding_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
|
||||
xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
||||
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import unittest
|
|||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
|
||||
do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
|
||||
DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM
|
||||
DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
|
||||
from numbers import Number
|
||||
from typing import Dict, Any
|
||||
|
|
@ -909,6 +909,7 @@ class TestSparse(TestSparseBase):
|
|||
test_shape(10, 20, 0, 0)
|
||||
test_shape(10, 20, 0, 20)
|
||||
|
||||
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
def test_t_empty(self, device, dtype):
|
||||
def test_in_place(x):
|
||||
|
|
@ -3330,6 +3331,7 @@ class TestSparse(TestSparseBase):
|
|||
J[i] = g.to_dense() if g.is_sparse else g
|
||||
return J
|
||||
|
||||
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1166")
|
||||
def test_op(sparse_dims, nnz, with_size, coalesced):
|
||||
if isinstance(with_size, Number):
|
||||
with_size = [with_size] * sparse_dims
|
||||
|
|
|
|||
|
|
@ -920,7 +920,7 @@ class TestSparseCSR(TestCase):
|
|||
def test_csr_is_contiguous(self):
|
||||
a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Tensors of type SparseCsrTensorImpl do not have is_contiguous"):
|
||||
with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have is_contiguous"):
|
||||
a.is_contiguous()
|
||||
|
||||
def test_csr_double_to_sparse_csr(self):
|
||||
|
|
|
|||
|
|
@ -494,6 +494,8 @@ class FakeTensor(torch.Tensor):
|
|||
return f"FakeTensor({self_repr}, {self.fake_device})"
|
||||
|
||||
def new(self, *args, **kwargs):
|
||||
# TODO: This doesn't work with sparse self
|
||||
|
||||
# torch.Tensor.new does not go through the normal dispatcher pattern
|
||||
# so in order to use the same pattern as normal invocation of
|
||||
# returning meta device within the kernel we need to intercept
|
||||
|
|
@ -502,7 +504,7 @@ class FakeTensor(torch.Tensor):
|
|||
# when attempting to compute an output in meta, so
|
||||
# we compute the real tensor then convert to meta
|
||||
out_device = self.fake_device
|
||||
with no_dispatch():
|
||||
with no_dispatch(), in_kernel_invocation_manager(self.fake_mode):
|
||||
real_out = super().new(*args, **kwargs)
|
||||
|
||||
assert not isinstance(real_out, FakeTensor), real_out
|
||||
|
|
|
|||
|
|
@ -245,6 +245,7 @@ struct ConcretePyInterpreterVTable final
|
|||
c10::Layout layout(const TensorImpl* self) const override;
|
||||
c10::SymInt sym_numel(const TensorImpl* self) const override;
|
||||
c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override;
|
||||
c10::SymInt sym_storage_offset(const TensorImpl* self) const override;
|
||||
|
||||
void trace_gpu_event_creation(uintptr_t event) const override {
|
||||
concrete_trace_cuda<trace_cuda_event_creation_fn_name>(event);
|
||||
|
|
@ -715,14 +716,14 @@ static PyObject* THPVariable_make_subclass(
|
|||
data.set_requires_grad(r.toBool(2));
|
||||
const auto sizes_strides_policy = r.stringViewOptional(3);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
data.unsafeGetTensorImpl()->set_sizes_strides_policy(
|
||||
data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
|
||||
parseSizesStridesPolicyArgument(*sizes_strides_policy));
|
||||
}
|
||||
if (r.toBool(4)) {
|
||||
data.unsafeGetTensorImpl()->set_custom_device(true);
|
||||
data.unsafeGetTensorImpl()->set_python_custom_device(true);
|
||||
}
|
||||
if (r.toBool(5)) {
|
||||
data.unsafeGetTensorImpl()->set_custom_layout(true);
|
||||
data.unsafeGetTensorImpl()->set_python_custom_layout(true);
|
||||
}
|
||||
if (!r.isNone(6)) {
|
||||
data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
|
||||
|
|
@ -804,7 +805,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
|||
|
||||
const auto sizes_strides_policy = r.stringViewOptional(10);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_strides_policy(
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
|
||||
parseSizesStridesPolicyArgument(*sizes_strides_policy));
|
||||
}
|
||||
} else {
|
||||
|
|
@ -819,17 +820,12 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
|||
|
||||
auto sym_sizes = r.symintlist(1);
|
||||
auto sym_strides = r.symintlist(2);
|
||||
auto sym_storage_offset = r.toSymIntOptional(3);
|
||||
|
||||
TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
|
||||
|
||||
// TODO: this should probably be sym_sizes, sym_strides AND offset
|
||||
tensor_impl->set_sym_sizes_and_strides(sym_sizes, sym_strides);
|
||||
|
||||
// TODO: this may need to be symbolic as well
|
||||
auto storage_offset = r.toInt64Optional(3);
|
||||
if (storage_offset) {
|
||||
tensor_impl->set_storage_offset(*storage_offset);
|
||||
}
|
||||
tensor_impl->set_sizes_and_strides(
|
||||
sym_sizes, sym_strides, sym_storage_offset.value_or(0));
|
||||
|
||||
const auto sizes_strides_policy = r.stringViewOptional(10);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
|
|
@ -842,10 +838,10 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
|||
tensor.set_requires_grad(r.toBool(9));
|
||||
|
||||
if (r.toBool(11)) {
|
||||
tensor.unsafeGetTensorImpl()->set_custom_device(true);
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
|
||||
}
|
||||
if (r.toBool(12)) {
|
||||
tensor.unsafeGetTensorImpl()->set_custom_layout(true);
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
|
||||
}
|
||||
|
||||
return THPVariable_NewWithVar(
|
||||
|
|
@ -2542,6 +2538,29 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
|
|||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
}
|
||||
|
||||
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
auto out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"sym_storage_offset",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("sym_storage_offset")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten");
|
||||
|
||||
if (out == Py_None) {
|
||||
return self->sym_storage_offset_default();
|
||||
}
|
||||
return torch::is_symint_node(out)
|
||||
? out.cast<c10::SymIntNodeImpl*>()->toSymInt()
|
||||
: c10::SymInt{py::cast<int64_t>(out)};
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
|
|||
// This is a temporary fix for a PyTorch core issue,
|
||||
// according to https://github.com/pytorch/xla/pull/2682.
|
||||
is_non_overlapping_and_dense_ = false;
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomSizes);
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
|
||||
}
|
||||
|
||||
void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
|
||||
|
|
@ -160,10 +160,6 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
|
|||
return c10::SymIntArrayRef::fromIntArrayRef(sizes_custom());
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef LTCTensorImpl::sym_sizes() const {
|
||||
return sym_sizes_custom();
|
||||
}
|
||||
|
||||
void LTCTensorImpl::setup_size_properties() {
|
||||
size_t generation = tensor_->generation();
|
||||
if (generation != generation_) {
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
|
|||
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
||||
|
||||
virtual c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
virtual c10::SymIntArrayRef sym_sizes() const override;
|
||||
virtual c10::SymIntArrayRef sym_strides_custom() const override;
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
|
|
|
|||
Loading…
Reference in a new issue