#pragma once #include #include #include #include namespace torch::aot_inductor { template void convert_output_to_handle( const ArrayRefTensor& output, AtenTensorHandle& handle) { handle = output.expensiveCopyToTensor(); } template void convert_outputs_to_handles_helper( const std::tuple...>& outputs, AtenTensorHandle* output_handles, std::index_sequence) { (convert_output_to_handle(std::get(outputs), output_handles[Is]), ...); } template void convert_outputs_to_handles( const std::tuple...>& outputs, AtenTensorHandle* output_handles) { convert_outputs_to_handles_helper( outputs, output_handles, std::make_index_sequence()); } template void convert_handle_to_arrayref_tensor( AtenTensorHandle handle, ArrayRefTensor& input) { void* data_ptr; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr)); int64_t dim; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim)); int64_t numel; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel)); int64_t* sizes; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes)); int64_t* strides; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides)); int32_t dtype; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype)); int32_t device_type; AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type)); int32_t device_index; AOTI_TORCH_ERROR_CODE_CHECK( aoti_torch_get_device_index(handle, &device_index)); input = ArrayRefTensor( MiniArrayRef(reinterpret_cast(data_ptr), numel), MiniArrayRef(sizes, dim), MiniArrayRef(strides, dim), device_type, device_index); } template void convert_handles_to_inputs_helper( AtenTensorHandle* input_handles, std::tuple...>& inputs, std::index_sequence) { (convert_handle_to_arrayref_tensor(input_handles[Is], std::get(inputs)), ...); } template void convert_handles_to_inputs( AtenTensorHandle* input_handles, std::tuple...>& inputs) { convert_handles_to_inputs_helper( input_handles, inputs, std::make_index_sequence()); } template void assert_numel(const ArrayRefTensor& tensor, uint64_t numel) { if (tensor.numel() != numel) { std::stringstream err; err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel(); throw std::runtime_error(err.str()); } } } // namespace torch::aot_inductor