Add torch.ops.aten.print (#120295)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120295
Approved by: https://github.com/zou3519
This commit is contained in:
angelayi 2024-02-21 16:53:43 -08:00 committed by PyTorch MergeBot
parent ef9b6d6816
commit f064dec7e0
4 changed files with 16 additions and 0 deletions

View file

@ -13,6 +13,7 @@
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <iostream>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -22,6 +23,7 @@
#include <ATen/ops/_aminmax_native.h>
#include <ATen/ops/_assert_async_native.h>
#include <ATen/ops/_functional_assert_async_native.h>
#include <ATen/ops/_print_native.h>
#include <ATen/ops/_assert_scalar_native.h>
#include <ATen/ops/_functional_assert_scalar_native.h>
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
@ -71,6 +73,7 @@
#include <ATen/ops/where_native.h>
#include <ATen/ops/zeros_like.h>
#include <iostream>
#include <utility>
#endif
@ -440,6 +443,9 @@ Tensor _functional_assert_async_msg_cpu(
return dep_token.clone();
}
void _print(c10::string_view s) {
std::cout << s << "\n";
}
// Sorting-based algorithm for isin(); used when the number of test elements is large.
static void isin_sorting(

View file

@ -189,6 +189,10 @@
- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()
- func: _print(str s) -> ()
dispatch:
CompositeExplicitAutograd: _print
- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
dispatch:
CompositeExplicitAutograd: sym_constrain_range

View file

@ -597,6 +597,11 @@ def assert_async_meta(val, assert_msg):
return
@register_meta(aten._print.default)
def print_meta(s):
return
@register_meta(aten._make_dep_token.default)
def make_dep_token(
*,

View file

@ -80,6 +80,7 @@ FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
"_nested_tensor_storage_offsets", # returns a vector of ints
"_chunk_grad_outputs_efficient_attention", # returns a bool
"_fused_sdp_choice", # returns an int
"_print", # no return
]
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [