pytorch/torch/_numpy/_unary_ufuncs_impl.py
lezcano a9dca53438 NumPy support in torch.compile (#106211)
RFC: https://github.com/pytorch/rfcs/pull/54
The first commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/

We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core.

In the next commits, I do the following, in this order:
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend over backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (I think they were not testing anything) and are not passing now.

Missing from this PR (but not blocking):
- Have a flag that deactivates tracing of NumPy functions and simply graph-breaks on them instead. There used to be one, but it stopped working after the merge, so I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.

All the tests in `tests/torch_np` take about 75s to run.

This was work by @ev-br, @rgommers, @honno and me. I did not create this PR via ghstack (which would have been convenient) because this is a collaboration, and ghstack doesn't allow for shared contributions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
2023-08-11 00:39:32 +00:00


"""Export torch work functions for unary ufuncs, rename/tweak to match numpy.
This listing is further exported to public symbols in the `_numpy/_ufuncs.py` module.
"""

import torch
from torch import (  # noqa: F401
    absolute as fabs,  # noqa: F401
    arccos,  # noqa: F401
    arccosh,  # noqa: F401
    arcsin,  # noqa: F401
    arcsinh,  # noqa: F401
    arctan,  # noqa: F401
    arctanh,  # noqa: F401
    bitwise_not,  # noqa: F401
    bitwise_not as invert,  # noqa: F401
    ceil,  # noqa: F401
    conj_physical as conjugate,  # noqa: F401
    cos,  # noqa: F401
    cosh,  # noqa: F401
    deg2rad,  # noqa: F401
    deg2rad as radians,  # noqa: F401
    exp,  # noqa: F401
    exp2,  # noqa: F401
    expm1,  # noqa: F401
    floor,  # noqa: F401
    isfinite,  # noqa: F401
    isinf,  # noqa: F401
    isnan,  # noqa: F401
    log,  # noqa: F401
    log10,  # noqa: F401
    log1p,  # noqa: F401
    log2,  # noqa: F401
    logical_not,  # noqa: F401
    negative,  # noqa: F401
    rad2deg,  # noqa: F401
    rad2deg as degrees,  # noqa: F401
    reciprocal,  # noqa: F401
    trunc as fix,  # noqa: F401
    round as rint,  # noqa: F401
    sign,  # noqa: F401
    signbit,  # noqa: F401
    sin,  # noqa: F401
    sinh,  # noqa: F401
    sqrt,  # noqa: F401
    square,  # noqa: F401
    tan,  # noqa: F401
    tanh,  # noqa: F401
    trunc,  # noqa: F401
)


# special cases: torch does not export these names
def cbrt(x):
    return torch.pow(x, 1 / 3)


def positive(x):
    return +x


def absolute(x):
    # work around torch.absolute not impl for bools
    if x.dtype == torch.bool:
        return x
    return torch.absolute(x)


# TODO set __name__ and __qualname__
abs = absolute
conj = conjugate
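The renamed imports and the special cases at the end of the module can be exercised directly. The script below is a minimal sketch that assumes a PyTorch build providing the functions used in the listing; it re-declares `absolute` and `cbrt` with the same bodies as the module instead of importing the private `torch._numpy` package. `cbrt_signed` is a hypothetical helper and not part of the module: `torch.pow(x, 1 / 3)` returns NaN for negative real inputs, whereas NumPy's `cbrt` returns the real cube root (for example, -2.0 for -8.0). The rounding lines contrast `torch.round`, which rounds half to even like `numpy.rint`, with `torch.trunc`, which rounds toward zero like `numpy.fix`.

```python
import torch
from torch import bitwise_not as invert
from torch import deg2rad as radians
from torch import round as rint


# Re-declared with the module's bodies so this sketch does not import the
# private torch._numpy package.
def absolute(x):
    # Bool tensors are returned unchanged; other dtypes go to torch.absolute.
    if x.dtype == torch.bool:
        return x
    return torch.absolute(x)


def cbrt(x):
    # Same formula as the module: real pow with a fractional exponent.
    return torch.pow(x, 1 / 3)


def cbrt_signed(x):
    # Hypothetical helper, not part of the module: take the root of |x|
    # and restore the sign, which yields the real cube root for x < 0.
    return torch.sign(x) * torch.abs(x).pow(1 / 3)


# 1. Renamed exports are the same torch callables under NumPy-style names.
same_invert = invert is torch.bitwise_not
same_radians = radians is torch.deg2rad

# 2. The bool workaround hands back the very same tensor object.
flags = torch.tensor([True, False])
flags_abs = absolute(flags)
magnitudes = absolute(torch.tensor([-1.5, 2.0]))  # 1.5 and 2.0

# 3. torch.round rounds half to even (like numpy.rint);
#    torch.trunc rounds toward zero (like numpy.fix).
vals = torch.tensor([2.5, 2.7, -2.7])
rounded = rint(vals)           # 2.0, 3.0, -3.0
truncated = torch.trunc(vals)  # 2.0, 2.0, -2.0

# 4. cbrt caveat: a negative base with exponent 1/3 gives NaN for real dtypes.
x = torch.tensor([8.0, -8.0])
plain = cbrt(x)          # second entry is NaN
signed = cbrt_signed(x)  # approximately 2.0 and -2.0
```

Because the aliases are plain re-exports, they keep torch's own names: `invert.__name__` is still `'bitwise_not'`. The `TODO` about `__name__` and `__qualname__` appears to concern this kind of mismatch, although the module does not say so explicitly.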