diff --git a/aten/src/ATen/native/CPUFallback.cpp b/aten/src/ATen/native/CPUFallback.cpp index b8d9d3b9347..27208ab6f6d 100644 --- a/aten/src/ATen/native/CPUFallback.cpp +++ b/aten/src/ATen/native/CPUFallback.cpp @@ -108,6 +108,22 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorList().vec()))); tensorlist_cpu_args.push_back(cpu_ivalue); (*stack)[arguments_begin + idx] = std::move(cpu_ivalue); + tensorlist_args.push_back(ivalue.toTensorList()); + } else if (ivalue.isOptionalTensorList()) { + auto opt_tensors = ivalue.toOptionalTensorList().vec(); + std::vector<at::Tensor> need_convert_tensors; + std::vector<size_t> need_convert_tensors_index; + for (auto i : c10::irange(opt_tensors.size())) { + if (!opt_tensors[i].has_value() || !opt_tensors[i]->defined()) continue; + need_convert_tensors.push_back(opt_tensors[i].value()); + need_convert_tensors_index.push_back(i); + } + auto cpu_tensors = to_cpu(need_convert_tensors); + for (const auto i : c10::irange(need_convert_tensors_index.size())) { + auto idx = need_convert_tensors_index[i]; + opt_tensors[idx] = cpu_tensors[i]; + } + (*stack)[arguments_begin + idx] = c10::IValue(opt_tensors); } } // XLA requires all of the tensor arguments to be gathered up and converted to CPU together. 
diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index cfe49e02a48..5cf6ea1df90 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -48,8 +48,6 @@ void abs_kernel(::at::TensorIteratorBase& iter) { abs_counter += 1; } -} // namespace - void quantize_tensor_per_tensor_affine_privateuse1( const at::Tensor& rtensor, at::Tensor& qtensor, @@ -58,6 +56,8 @@ void quantize_tensor_per_tensor_affine_privateuse1( // do nothing } +} // namespace + namespace at::native { REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel); @@ -366,6 +366,7 @@ void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); + m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); } // This basic implementation doesn't bother dealing with different device indices diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 6e264b760ba..3a99423d709 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -452,17 +452,21 @@ class TestCppExtensionOpenRgistration(common.TestCase): def test_open_device_tensor_type_fallback(): torch.utils.rename_privateuse1_backend('foo') # create tensors located in custom device - x = torch.Tensor([1, 2, 3]).to('foo') + x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to('foo') y = torch.Tensor([1, 0, 2]).to('foo') # create result tensor located in cpu - z_cpu = torch.Tensor([0, 2, 1]) + z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]]) # Check that our device is correct. 
device = self.module.custom_device() self.assertTrue(x.device == device) self.assertFalse(x.is_cpu) # call sub op, which will fallback to cpu z = torch.sub(x, y) - + self.assertEqual(z_cpu, z) + # call index op, which will fallback to cpu + z_cpu = torch.Tensor([3, 1]) + y = torch.Tensor([1, 0]).long().to('foo') + z = x[y, y] self.assertEqual(z_cpu, z) def test_open_device_tensorlist_type_fallback():