mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Conversions between strided and jagged layouts for Nested Tensors (#115749)"
This reverts commit 9450e198aa.
Reverted https://github.com/pytorch/pytorch/pull/115749 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/115749#issuecomment-2197790226))
This commit is contained in:
parent
24f69eef6a
commit
fa6c0fe3e4
12 changed files with 22 additions and 280 deletions
|
|
@ -331,28 +331,6 @@ Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const
|
|||
}
|
||||
}
|
||||
|
||||
Tensor FunctionalInverses::_nested_strided_to_jagged_inverse(const at::Tensor & base, const at::Tensor & mutated_view, at::functionalization::InverseReturnMode inverse_return_mode) {
|
||||
// Mutated view is a jagged NT
|
||||
auto cpp_nt = at::_nested_jagged_to_strided(mutated_view);
|
||||
|
||||
if (inverse_return_mode != InverseReturnMode::NeverView) {
|
||||
return cpp_nt;
|
||||
} else {
|
||||
return cpp_nt.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor FunctionalInverses::_nested_jagged_to_strided_inverse(const at::Tensor & base, const at::Tensor & mutated_view, at::functionalization::InverseReturnMode inverse_return_mode) {
|
||||
// Mutated view is a strided NT
|
||||
auto python_nt = at::_nested_strided_to_jagged(mutated_view);
|
||||
|
||||
if (inverse_return_mode != InverseReturnMode::NeverView) {
|
||||
return python_nt;
|
||||
} else {
|
||||
return python_nt.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) {
|
||||
if (inverse_return_mode != InverseReturnMode::NeverView) {
|
||||
return at::squeeze(mutated_view, dim);
|
||||
|
|
|
|||
|
|
@ -6224,33 +6224,6 @@
|
|||
CompositeExplicitAutogradNonFunctional: _nested_get_values_copy
|
||||
autogen: _nested_get_values_copy.out
|
||||
|
||||
- func: _nested_strided_to_jagged(Tensor(a) self) -> Tensor(a)
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
dispatch:
|
||||
NestedTensorCPU, NestedTensorCUDA: _nested_strided_to_jagged
|
||||
|
||||
- func: _nested_strided_to_jagged_copy(Tensor self) -> Tensor
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
tags: view_copy
|
||||
dispatch:
|
||||
CompositeExplicitAutogradNonFunctional: _nested_strided_to_jagged_copy
|
||||
autogen: _nested_strided_to_jagged_copy.out
|
||||
|
||||
- func: _nested_jagged_to_strided(Tensor(a) self) -> Tensor(a)
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
dispatch: {}
|
||||
|
||||
- func: _nested_jagged_to_strided_copy(Tensor self) -> Tensor
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
tags: view_copy
|
||||
dispatch:
|
||||
CompositeExplicitAutogradNonFunctional: _nested_jagged_to_strided_copy
|
||||
autogen: _nested_jagged_to_strided_copy.out
|
||||
|
||||
- func: _nested_get_offsets(Tensor self) -> Tensor
|
||||
variants: function
|
||||
device_check: NoCheck
|
||||
|
|
|
|||
|
|
@ -90,11 +90,11 @@ Tensor _to_copy_nested(
|
|||
bool non_blocking,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
!layout.has_value() || self.layout() == layout.value() || layout.value() == Layout::Jagged,
|
||||
"to(options) doesn't generally support converting to a different layout, "
|
||||
"but for NT we support strided -> jagged conversion, you have ",
|
||||
!layout.has_value() || self.layout() == layout.value(),
|
||||
"to(options) doesn't support converting to a different layout, "
|
||||
"but got self.layout being ",
|
||||
self.layout(),
|
||||
" and options.layout is set as ",
|
||||
" and options.layout set as ",
|
||||
layout.value());
|
||||
auto options =
|
||||
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
|
||||
|
|
@ -112,13 +112,9 @@ Tensor _to_copy_nested(
|
|||
(options.layout() == c10::kStrided));
|
||||
|
||||
Tensor r;
|
||||
auto empty_op_layout = (layout.has_value() && layout.value() == Layout::Jagged) ? Layout::Strided : layout;
|
||||
r = at::empty_like(self, dtype, empty_op_layout, device, pin_out, memory_format);
|
||||
r = at::empty_like(self, dtype, layout, device, pin_out, memory_format);
|
||||
get_nested_tensor_impl(r)->get_buffer().copy_(
|
||||
get_nested_tensor_impl(self)->get_buffer(), non_blocking);
|
||||
if (layout.has_value() && self.layout() != layout.value() && layout.value() == Layout::Jagged) {
|
||||
return at::_nested_strided_to_jagged(r);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -867,72 +867,6 @@ std::tuple<Tensor, Tensor> _nested_compute_contiguous_strides_offsets(const Tens
|
|||
construct_offsets(nested_size));
|
||||
}
|
||||
|
||||
Tensor _nested_strided_to_jagged(const Tensor& self) {
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
|
||||
// All jagged NT can be converted into strided NTs, but the opposite is not True
|
||||
// Only strided NTs with a single jagged dimension might be converted into
|
||||
// jagged NTs, so first we check for that
|
||||
int ragged_dims_count = 0;
|
||||
int ragged_idx = -1;
|
||||
for (int64_t i = 0; i < self_ptr->dim(); ++i) {
|
||||
if (!self_ptr->opt_size(i).has_value()) {
|
||||
ragged_dims_count++;
|
||||
ragged_idx = i;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ragged_dims_count == 1, "Only strided NTs with 1 ragged dim can be converted to jagged NTs");
|
||||
|
||||
// Once that's checked, we convert the offsets + sizes in strided NT to
|
||||
// offsets + (optionally) lengths for the jagged NT
|
||||
auto ragged_offsets = self_ptr->get_storage_offsets();
|
||||
const int64_t* ragged_offsets_ptr = ragged_offsets.const_data_ptr<int64_t>();
|
||||
auto ragged_sizes = self_ptr->get_nested_sizes();
|
||||
const int64_t* ragged_sizes_ptr = ragged_sizes.const_data_ptr<int64_t>();
|
||||
int64_t post_ragged_stride = 1;
|
||||
for (int64_t i : c10::irange(ragged_idx, ragged_sizes.size(1))) {
|
||||
post_ragged_stride *= ragged_sizes_ptr[i];
|
||||
}
|
||||
auto ragged_offsets_sizes = ragged_offsets.sizes();
|
||||
auto metadata_tensor_options = self_ptr->get_buffer().options().dtype(kLong).device(at::kCPU);
|
||||
auto jagged_offsets = at::empty({ragged_offsets_sizes[0]+1}, metadata_tensor_options);
|
||||
int64_t* jagged_offsets_ptr = jagged_offsets.mutable_data_ptr<int64_t>();
|
||||
auto jagged_lengths = at::empty({ragged_offsets_sizes[0]}, metadata_tensor_options);
|
||||
int64_t* jagged_lengths_ptr = jagged_lengths.mutable_data_ptr<int64_t>();
|
||||
bool lengths_needed = false;
|
||||
int64_t ragged_sizes_stride_0 = ragged_sizes.stride(0);
|
||||
int64_t num_offsets = ragged_offsets.size(0);
|
||||
for (int64_t i : c10::irange(num_offsets)) {
|
||||
jagged_offsets_ptr[i] = int64_t(ragged_offsets_ptr[i] / post_ragged_stride);
|
||||
jagged_lengths_ptr[i] = int64_t(ragged_sizes_ptr[i * ragged_sizes_stride_0 + (ragged_idx-1)]);
|
||||
if (i > 0) {
|
||||
auto offsets_diff = jagged_offsets_ptr[i] - jagged_offsets_ptr[i-1];
|
||||
if (offsets_diff != jagged_lengths_ptr[i-1]) {
|
||||
lengths_needed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jagged_offsets_ptr[num_offsets] = jagged_offsets_ptr[num_offsets-1] + ragged_sizes_ptr[(num_offsets-1)*ragged_sizes_stride_0 + (ragged_idx-1)];
|
||||
|
||||
jagged_offsets = jagged_offsets.to(self_ptr->get_buffer().device());
|
||||
jagged_lengths = jagged_lengths.to(self_ptr->get_buffer().device());
|
||||
|
||||
c10::optional<at::Tensor> jagged_lengths_arg = lengths_needed ? c10::optional(jagged_lengths) : c10::nullopt;
|
||||
std::vector<int64_t> njt_sizes(self_ptr->dim()-1);
|
||||
int njt_sizes_it = 0;
|
||||
for (int64_t i = 0; i < self_ptr->dim(); ++i) {
|
||||
if (i != ragged_idx) {
|
||||
njt_sizes[njt_sizes_it] = self_ptr->size(i);
|
||||
++njt_sizes_it;
|
||||
}
|
||||
}
|
||||
njt_sizes[0] = -1;
|
||||
auto njt_buffer = self_ptr->get_buffer().view(c10::IntArrayRef(njt_sizes));
|
||||
Tensor dummy = at::_nested_get_jagged_dummy(self);
|
||||
return at::_nested_view_from_jagged(njt_buffer, jagged_offsets, dummy, jagged_lengths_arg, ragged_idx);
|
||||
}
|
||||
|
||||
// See Note [Special size rule for nested tensor]
|
||||
Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
|
|
@ -453,13 +453,7 @@ aten::_nested_get_ragged_idx
|
|||
aten::_nested_get_values
|
||||
aten::_nested_get_values_copy
|
||||
aten::_nested_get_values_copy.out
|
||||
aten::_nested_jagged_to_strided
|
||||
aten::_nested_jagged_to_strided_copy
|
||||
aten::_nested_jagged_to_strided_copy.out
|
||||
aten::_nested_select_backward
|
||||
aten::_nested_strided_to_jagged
|
||||
aten::_nested_strided_to_jagged_copy
|
||||
aten::_nested_strided_to_jagged_copy.out
|
||||
aten::_nested_sum_backward
|
||||
aten::_nested_tensor_from_mask
|
||||
aten::_nested_tensor_from_mask.out
|
||||
|
|
|
|||
|
|
@ -559,16 +559,9 @@ class TestNestedTensor(TestCase):
|
|||
nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
|
||||
self.assertEqual(result, nested_namespace_result)
|
||||
|
||||
@parametrize(
|
||||
"layout",
|
||||
[torch.strided, torch.jagged],
|
||||
name_fn=lambda l: f"_with_{layout_name(l)}_layout",
|
||||
)
|
||||
def test_to(self, layout):
|
||||
def test_to(self):
|
||||
ntensors = 4
|
||||
nt = random_nt_from_dims(
|
||||
(7, None, 10), torch.device("cpu"), torch.float32, layout=layout
|
||||
)
|
||||
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
||||
|
||||
def test_copy_behavior(t, non_blocking=False):
|
||||
self.assertIs(t, t.to(t, non_blocking=non_blocking))
|
||||
|
|
@ -643,15 +636,6 @@ class TestNestedTensor(TestCase):
|
|||
self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
|
||||
self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)
|
||||
|
||||
# Jagged <-> strided
|
||||
new_layout = torch.jagged if layout == torch.strided else torch.strided
|
||||
new_layout_nt = torch.ops.aten._to_copy(nt, layout=new_layout)
|
||||
self.assertIs(new_layout_nt.layout, new_layout)
|
||||
self.assertEqual(new_layout_nt.device, nt.device)
|
||||
self.assertEqual(new_layout_nt.size(2), nt.size(2))
|
||||
self.assertEqual(new_layout_nt.unbind(), nt.unbind())
|
||||
self.assertNotEqual(new_layout_nt.data_ptr(), nt.data_ptr())
|
||||
|
||||
def test_copy_(self):
|
||||
ntensors = 4
|
||||
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
||||
|
|
@ -5667,40 +5651,6 @@ class TestNestedTensorSubclass(TestCase):
|
|||
expected_grad.unbind()[1].add_(1.0)
|
||||
torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)
|
||||
|
||||
def test_layout_conversion(self, device):
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randn(2, 4, device=device),
|
||||
torch.randn(5, 4, device=device),
|
||||
torch.randn(3, 4, device=device),
|
||||
],
|
||||
layout=torch.jagged,
|
||||
)
|
||||
strided_nt = torch.ops.aten._nested_jagged_to_strided(nt)
|
||||
self.assertEqual(strided_nt.unbind(), nt.unbind())
|
||||
self.assertEqual(strided_nt.data_ptr(), nt.data_ptr())
|
||||
|
||||
jagged_nt = torch.ops.aten._nested_strided_to_jagged(strided_nt)
|
||||
self.assertEqual(jagged_nt.unbind(), nt.unbind())
|
||||
self.assertEqual(jagged_nt.data_ptr(), nt.data_ptr())
|
||||
|
||||
def test_layout_conversion_backward(self, device):
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randn(2, 4, device=device),
|
||||
torch.randn(5, 4, device=device),
|
||||
torch.randn(3, 4, device=device),
|
||||
],
|
||||
layout=torch.jagged,
|
||||
requires_grad=True,
|
||||
)
|
||||
strided_nt = torch.ops.aten._nested_jagged_to_strided(nt)
|
||||
jagged_nt = torch.ops.aten._nested_strided_to_jagged(strided_nt)
|
||||
|
||||
jagged_nt.backward(torch.ones_like(jagged_nt))
|
||||
expected_grad = torch.ones_like(nt)
|
||||
torch._dynamo.disable(self.assertEqual)(expected_grad, nt.grad)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestNestedTensor)
|
||||
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
|
||||
|
|
|
|||
|
|
@ -1949,18 +1949,6 @@
|
|||
AutogradNestedTensor:
|
||||
self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_storage_offsets())
|
||||
|
||||
- name: _nested_jagged_to_strided(Tensor(a) self) -> Tensor(a)
|
||||
self: "_nested_view_from_jagged(grad.values().view_as(at::_nested_get_values(self)),
|
||||
at::_nested_get_offsets(self),
|
||||
at::_nested_get_jagged_dummy(self),
|
||||
at::_nested_get_lengths(self),
|
||||
at::_nested_get_ragged_idx(self),
|
||||
at::_nested_get_min_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_min_seqlen(self)) : c10::nullopt,
|
||||
at::_nested_get_max_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_max_seqlen(self)) : c10::nullopt)"
|
||||
|
||||
- name: _nested_strided_to_jagged(Tensor(a) self) -> Tensor(a)
|
||||
self: at::_nested_jagged_to_strided(grad)
|
||||
|
||||
# Why is _values() not differentiable?
|
||||
# See NOTE [ Sparse: autograd and API ]
|
||||
- name: _values(Tensor(a) self) -> Tensor(a)
|
||||
|
|
|
|||
|
|
@ -62,8 +62,6 @@ VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [
|
|||
"_nested_get_values",
|
||||
"_nested_view_from_buffer",
|
||||
"_nested_view_from_jagged",
|
||||
"_nested_strided_to_jagged",
|
||||
"_nested_jagged_to_strided",
|
||||
]
|
||||
|
||||
VIEW_FUNCTIONS = {
|
||||
|
|
|
|||
|
|
@ -608,7 +608,6 @@ def index_put_impl(fake_mode, func, *args, **kwargs):
|
|||
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
|
||||
@register_op_impl(aten._nested_view_from_buffer.default)
|
||||
@register_op_impl(aten._nested_view_from_buffer_copy.default)
|
||||
@register_op_impl(aten._nested_strided_to_jagged.default)
|
||||
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
|
||||
raise UnsupportedOperatorException(
|
||||
"torch.compile does not support strided NestedTensor"
|
||||
|
|
|
|||
|
|
@ -389,13 +389,3 @@ Example::
|
|||
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths
|
||||
|
||||
return nested_view_from_values_offsets_lengths(values, offsets, lengths, ragged_idx=jagged_dim)
|
||||
|
||||
|
||||
# This library impl is here so pytorch picks it up when initializing, otherwise users had to import
|
||||
# torch.nested._internal.ops to get it, which is not ideal. Importing all of ops here results in a
|
||||
# fun circular dependency hell, so this is the next best thing
|
||||
@torch.library.impl("aten::_nested_get_jagged_dummy", ["default", "NestedTensorCPU", "NestedTensorCUDA"]) # type: ignore[has-type]
|
||||
def _aten_nested_get_jagged_dummy(x) -> Tensor:
|
||||
from torch.nested._internal.nested_tensor import _nt_view_dummy
|
||||
|
||||
return _nt_view_dummy()
|
||||
|
|
|
|||
|
|
@ -200,9 +200,6 @@ class NestedTensor(torch.Tensor):
|
|||
def _min_seqlen(self):
|
||||
return self._get_min_seqlen()
|
||||
|
||||
def data_ptr(self) -> int:
|
||||
return self._values.data_ptr()
|
||||
|
||||
def __repr__(self):
|
||||
# We should implement this in torch/_tensor_str.py instead
|
||||
grad_fn_str = (
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import operator
|
|||
import torch
|
||||
from torch.nested._internal.sdpa import jagged_scaled_dot_product_attention
|
||||
|
||||
from .nested_tensor import _tensor_symint_registry, NestedTensor
|
||||
from .nested_tensor import NestedTensor
|
||||
from typing import * # noqa: F403
|
||||
import torch.nn.functional as F
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
|
|
@ -439,85 +439,25 @@ def linear_backward_default(func, *args, **kwargs):
|
|||
return (ds, dw, db)
|
||||
|
||||
|
||||
@register_jagged_func(
|
||||
torch.ops.aten._to_copy.default,
|
||||
"self: jt_all, dtype: any?, layout: any?, device: any?, pin_memory: any?, non_blocking: any?, memory_format: any?",
|
||||
)
|
||||
def _to_copy_default(func, *args, **kwargs):
|
||||
@register_jagged_func(torch.ops.aten._to_copy.default, "self: jt_all")
|
||||
def to_copy_default(func, *args, **kwargs):
|
||||
from .nested_tensor import _tensor_symint_registry
|
||||
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
inp: NestedTensor = new_kwargs.pop("input")
|
||||
new_layout = new_kwargs.pop("layout")
|
||||
if new_layout is None:
|
||||
new_layout = inp.layout
|
||||
|
||||
if new_layout not in [torch.strided, torch.jagged]:
|
||||
raise ValueError("Nested Tensors can only have jagged and strided layouts")
|
||||
inp = new_kwargs.pop("input")
|
||||
# don't change layout
|
||||
new_kwargs.pop("layout")
|
||||
|
||||
new_values = func(inp._values, **new_kwargs)
|
||||
|
||||
# Copy to a new Python subclass NestedTensor
|
||||
new_offsets = inp._offsets.to(device=new_values.device)
|
||||
_tensor_symint_registry[new_offsets] = _tensor_symint_registry[inp._offsets]
|
||||
inp_kwargs = extract_kwargs(inp)
|
||||
inp_kwargs["offsets"] = new_offsets
|
||||
|
||||
new_njt = NestedTensor(new_values, **inp_kwargs)
|
||||
|
||||
if new_layout == torch.jagged:
|
||||
return new_njt
|
||||
|
||||
return torch._nested_jagged_to_strided(new_njt)
|
||||
|
||||
|
||||
@register_jagged_func(torch.ops.aten._nested_jagged_to_strided.default, "self: jt_all")
|
||||
def _nested_jagged_to_strided(func, *args, **kwargs):
|
||||
_, new_kwargs = normalize_function(
|
||||
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
|
||||
inp: NestedTensor = new_kwargs.pop("input")
|
||||
|
||||
# TODO: Figure out a better way to accomplish this?
|
||||
if torch._subclasses.fake_tensor.is_fake(inp):
|
||||
# NB: NST is not supported in PT2. Calling this op with garbage will hit the
|
||||
# fake tensor unsupported impl and graph break.
|
||||
return torch._nested_view_from_buffer(
|
||||
inp._values.view(-1),
|
||||
nested_size=inp._values,
|
||||
nested_strides=inp._values,
|
||||
offsets=inp._values,
|
||||
)
|
||||
|
||||
# Create a new C++ NT from the Python NestedTensor
|
||||
# Start by creating metadata needed by C++ NT
|
||||
ragged_source = inp.lengths() if inp.lengths() is not None else inp.offsets().diff()
|
||||
|
||||
nested_sizes = torch.empty(
|
||||
(inp.offsets().shape[0] - 1, inp._values.dim()), dtype=torch.int64
|
||||
)
|
||||
non_ragged_dims = list(range(inp._values.dim()))
|
||||
non_ragged_dims = (
|
||||
non_ragged_dims[: inp._ragged_idx - 1] + non_ragged_dims[inp._ragged_idx :]
|
||||
)
|
||||
nested_sizes[:, non_ragged_dims] = torch.tensor(
|
||||
inp._size[1 : inp._ragged_idx] + inp._size[inp._ragged_idx + 1 :]
|
||||
)
|
||||
nested_sizes[:, inp._ragged_idx - 1] = ragged_source
|
||||
nested_strides = torch.empty_like(nested_sizes)
|
||||
nested_strides[:, :] = torch.tensor(inp._strides[1:])
|
||||
nested_offsets = inp.offsets() * functools.reduce(
|
||||
lambda a, b: a * b, inp._values.shape[1:]
|
||||
)
|
||||
nested_offsets = nested_offsets[:-1]
|
||||
return torch._nested_view_from_buffer(
|
||||
inp._values.view(-1),
|
||||
nested_size=nested_sizes.cpu(),
|
||||
nested_strides=nested_strides.cpu(),
|
||||
offsets=nested_offsets.cpu(),
|
||||
)
|
||||
return NestedTensor(new_values, **inp_kwargs)
|
||||
|
||||
|
||||
register_jagged_func(
|
||||
|
|
@ -525,7 +465,6 @@ register_jagged_func(
|
|||
torch.ops.aten.empty_like.default,
|
||||
torch.ops.aten.ones_like.default,
|
||||
torch.ops.aten.zeros_like.default,
|
||||
torch.ops.aten.empty_like.default,
|
||||
torch.ops.aten.randn_like.default,
|
||||
torch.ops.aten.detach.default,
|
||||
],
|
||||
|
|
@ -1246,3 +1185,9 @@ def _nested_get_jagged_dummy(func, *args, **kwargs):
|
|||
from torch.nested._internal.nested_tensor import _nt_view_dummy
|
||||
|
||||
return _nt_view_dummy()
|
||||
|
||||
|
||||
with torch.library._scoped_library("aten", "IMPL") as aten:
|
||||
aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CPU")
|
||||
aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "CUDA")
|
||||
aten.impl("_nested_get_jagged_dummy", _nested_get_jagged_dummy, "Meta")
|
||||
|
|
|
|||
Loading…
Reference in a new issue