Enable optional tensorList fallback to cpu. (#119273)

Add an optional tensor-list (`c10::List<std::optional<at::Tensor>>`) fallback to CPU.
Add test cases; the previous PR is: https://github.com/pytorch/pytorch/pull/106449

@bdhirsh
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119273
Approved by: https://github.com/bdhirsh
This commit is contained in:
Shan19900305 2024-02-07 03:54:10 +00:00 committed by PyTorch MergeBot
parent 53ee47ca32
commit 6c3600d008
3 changed files with 26 additions and 5 deletions

View file

@ -108,6 +108,22 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorList().vec())));
tensorlist_cpu_args.push_back(cpu_ivalue);
(*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
tensorlist_args.push_back(ivalue.toTensorList());
} else if (ivalue.isOptionalTensorList()) {
auto opt_tensors = ivalue.toOptionalTensorList().vec();
std::vector<at::Tensor> need_convert_tensors;
std::vector<int> need_convert_tensors_index;
for (auto i : c10::irange(opt_tensors.size())) {
if (!opt_tensors[i].has_value() || !opt_tensors[i]->defined()) continue;
need_convert_tensors.push_back(opt_tensors[i].value());
need_convert_tensors_index.push_back(i);
}
auto cpu_tensors = to_cpu(need_convert_tensors);
for (const auto i : c10::irange(need_convert_tensors_index.size())) {
auto idx = need_convert_tensors_index[i];
opt_tensors[idx] = cpu_tensors[i];
}
(*stack)[arguments_begin + idx] = c10::IValue(opt_tensors);
}
}
// XLA requires all of the tensor arguments to be gathered up and converted to CPU together.

View file

@ -48,8 +48,6 @@ void abs_kernel(::at::TensorIteratorBase& iter) {
abs_counter += 1;
}
} // namespace
void quantize_tensor_per_tensor_affine_privateuse1(
const at::Tensor& rtensor,
at::Tensor& qtensor,
@ -58,6 +56,8 @@ void quantize_tensor_per_tensor_affine_privateuse1(
// do nothing
}
} // namespace
namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
@ -366,6 +366,7 @@ void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack
// Register boxed CPU fallbacks for selected aten ops on the PrivateUse1
// (custom out-of-tree) backend. Each listed op is routed through
// custom_cpu_fallback (defined above in this file), which presumably moves
// the arguments to CPU, runs the native CPU kernel, and copies results back
// to the custom device — confirm against custom_cpu_fallback's body.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
// Plain tensor fallback path (exercised by test_open_device_tensor_type_fallback).
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
// TensorList fallback path (exercised by test_open_device_tensorlist_type_fallback).
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
// Optional-TensorList fallback path: index.Tensor takes
// List<optional<Tensor>> indices, the case added by this PR.
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
// This basic implementation doesn't bother dealing with different device indices

View file

@ -452,17 +452,21 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# NOTE(review): this span is a rendered diff with the +/- markers and
# indentation stripped. Adjacent near-duplicate lines (the two `x = ...`
# assignments below, and the two `z_cpu = ...` assignments) are the
# removed/added pair of a diff hunk — only one of each exists in the real
# file. `self` is used without being a parameter because, in the actual
# test file, this function is nested inside a test method and closes over
# `self`. Do not treat this text as runnable; confirm against the source.
def test_open_device_tensor_type_fallback():
# Route tensor ops for the 'foo' backend through the PrivateUse1 fallback.
torch.utils.rename_privateuse1_backend('foo')
# create tensors located in custom device
x = torch.Tensor([1, 2, 3]).to('foo')
x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to('foo')
y = torch.Tensor([1, 0, 2]).to('foo')
# create result tensor located in cpu
z_cpu = torch.Tensor([0, 2, 1])
z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
# Check that our device is correct.
device = self.module.custom_device()
self.assertTrue(x.device == device)
self.assertFalse(x.is_cpu)
# call sub op, which will fallback to cpu
z = torch.sub(x, y)
self.assertEqual(z_cpu, z)
# call index op, which will fallback to cpu
# (index.Tensor takes optional tensor lists — the new fallback path
# registered via custom_cpu_fallback in the extension .cpp)
z_cpu = torch.Tensor([3, 1])
y = torch.Tensor([1, 0]).long().to('foo')
z = x[y, y]
self.assertEqual(z_cpu, z)
def test_open_device_tensorlist_type_fallback():