2024-06-29 04:48:07 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2021-12-17 22:06:40 +00:00
|
|
|
import contextlib
|
|
|
|
|
import functools
|
|
|
|
|
import hashlib
|
2021-10-28 17:43:11 +00:00
|
|
|
import os
|
2021-12-17 22:06:40 +00:00
|
|
|
import re
|
Pretty-print dataclasses (#76810)
Unfortunately the built-in pprint module support pretty-print of dataclasses only from python 3.10. The code that I wrote in method `__str__` of OpInfo should do the same job and should also work for any dataclass. For now I've put it there but we can create a function and put it somewhere where is accessible also for other dataclasses. Also the max width (80) is now hardcode but it would ideally be the parameter of the function.
when you call print on an OpInfo you get:
```
OpInfo(name = '__getitem__',
ref = None,
aliases = (),
variant_test_name = '',
op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
inplace_variant = None,
skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>,
reference_inputs_func = None,
error_inputs_func = None,
sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>,
sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>,
dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
supports_out = False,
supports_autograd = True,
supports_gradgrad = True,
supports_fwgrad_bwgrad = True,
supports_inplace_autograd = False,
supports_forward_ad = True,
gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>,
check_batched_grad = True,
check_batched_gradgrad = True,
check_batched_forward_grad = True,
check_inplace_batched_forward_grad = True,
gradcheck_nondet_tol = 0.0,
gradcheck_fast_mode = None,
aten_name = '__getitem__',
decomp_aten_name = None,
aten_backward_name = None,
assert_autodiffed = False,
autodiff_nonfusible_nodes = ['aten::__getitem__'],
autodiff_fusible_nodes = [],
supports_sparse = False,
supports_scripting = False,
supports_sparse_csr = False,
test_conjugated_samples = True,
test_neg_view = True,
assert_jit_shape_analysis = False,
supports_expanded_weight = False)
```
cc @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810
Approved by: https://github.com/ezyang
2022-05-16 14:20:41 +00:00
|
|
|
import sys
|
2022-07-16 03:52:25 +00:00
|
|
|
import textwrap
|
|
|
|
|
from dataclasses import fields, is_dataclass
|
2022-12-18 22:55:22 +00:00
|
|
|
from enum import auto, Enum
|
2024-07-05 21:47:12 +00:00
|
|
|
from pathlib import Path
|
2024-12-02 21:46:15 +00:00
|
|
|
from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
|
2023-10-13 21:19:50 +00:00
|
|
|
from typing_extensions import Self
|
|
|
|
|
|
Rename tools/codegen to torchgen (#76275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275
In preparation for addressing
https://github.com/pytorch/pytorch/issues/73212
Diff was generated with:
```
git mv tools/codegen torchgen
git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g'
sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt
```
and a manual edits to:
* tools/test/test_gen_backend_stubs.py
* torchgen/build.bzl
* torchgen/gen_backend_stubs.py
aka this diff:
```
diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 3dc26c6d2d..104054575e 100644
--- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run
from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401
path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py')
+gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py')
# gen_backend_stubs.py is an integration point that is called directly by external backends.
# The tests here are to confirm that badly formed inputs result in reasonable error messages.
diff --git a/torchgen/build.bzl b/torchgen/build.bzl
index ed04e35a43..d00078a3cf 100644
--- a/torchgen/build.bzl
+++ b/torchgen/build.bzl
@@ -1,6 +1,6 @@
def define_targets(rules):
rules.py_library(
- name = "codegen",
+ name = "torchgen",
srcs = rules.glob(["**/*.py"]),
deps = [
rules.requirement("PyYAML"),
@@ -11,6 +11,6 @@ def define_targets(rules):
rules.py_binary(
name = "gen",
- srcs = [":codegen"],
+ srcs = [":torchgen"],
visibility = ["//visibility:public"],
)
diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index c1a672a655..beee7a15e0 100644
--- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -474,7 +474,7 @@ def run(
) -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
- pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+ pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
def make_file_manager(install_dir: str) -> FileManager:
```
run_all_fbandroid_tests
Test Plan: sandcastle
Reviewed By: albanD, ngimel
Differential Revision: D35770317
fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf
(cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
2022-04-25 01:32:01 +00:00
|
|
|
from torchgen.code_template import CodeTemplate
|
2021-02-04 17:10:34 +00:00
|
|
|
|
2022-04-19 12:25:45 +00:00
|
|
|
|
2024-06-29 04:48:07 +00:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from argparse import Namespace
|
2024-12-02 21:46:15 +00:00
|
|
|
from collections.abc import Iterable, Iterator, Sequence
|
2024-06-29 04:48:07 +00:00
|
|
|
|
|
|
|
|
|
2024-07-05 21:47:12 +00:00
|
|
|
# Repository root: two levels up from this file (later used as
# REPO_ROOT / "torchgen" / "gen.py", so this file presumably lives in
# <repo>/torchgen/ -- confirm if the file is moved).
REPO_ROOT = Path(__file__).absolute().parent.parent
|
|
|
|
|
|
|
|
|
|
|
2021-02-04 17:10:34 +00:00
|
|
|
# Many of these functions share logic for defining both the definition
|
|
|
|
|
# and declaration (for example, the function signature is the same), so
|
|
|
|
|
# we organize them into one function that takes a Target to say which
|
|
|
|
|
# code we want.
|
2021-02-04 17:10:34 +00:00
|
|
|
#
|
|
|
|
|
# This is an OPEN enum (we may add more cases to it in the future), so be sure
|
2023-02-09 19:17:46 +00:00
|
|
|
# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
|
|
|
|
|
# what targets are valid for your use.
|
2022-12-18 22:55:22 +00:00
|
|
|
class Target(Enum):
    """Which piece of generated code a codegen function should emit.

    This is an OPEN enum (more cases may be added in the future), so
    callers should explicitly spell out the targets they accept with
    ``Literal[Target.XXX]`` or ``Literal[Target.XXX, Target.YYY]``.
    """

    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
|
|
|
|
|
|
2020-11-20 05:44:43 +00:00
|
|
|
|
|
|
|
|
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
# NOTE: this is a str.format() template -- the `{}` must be filled in
# with the identifier to search for before the pattern is used.
IDENT_REGEX = r"(^|\W){}($|\W)"
|
2020-11-20 05:44:43 +00:00
|
|
|
|
2023-03-15 02:46:45 +00:00
|
|
|
|
2020-11-20 05:44:43 +00:00
|
|
|
# TODO: Use a real parser here; this will get bamboozled
|
2024-06-29 04:48:07 +00:00
|
|
|
def split_name_params(schema: str) -> tuple[str, list[str]]:
    """Split ``"name.overload(p1, p2)"`` into ``("name", ["p1", "p2"])``.

    The optional ``.overload`` suffix is discarded.  Raises RuntimeError
    when the string does not look like a function schema.
    """
    match = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if not match:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name, _overload, param_str = match.groups()
    return base_name, param_str.split(", ")
|
|
|
|
|
|
2021-02-04 17:10:34 +00:00
|
|
|
|
2022-04-19 12:25:45 +00:00
|
|
|
# Type variables used by the generic iterator helpers below
# (mapMaybe, concatMap, FileManager.write_sharded, ...).
T = TypeVar("T")

S = TypeVar("S")
|
2021-02-04 17:10:34 +00:00
|
|
|
|
|
|
|
|
# These two functions purposely return generators in analogy to map()
|
|
|
|
|
# so that you don't mix up when you need to list() them
|
|
|
|
|
|
2023-03-15 02:46:45 +00:00
|
|
|
|
2021-02-04 17:10:34 +00:00
|
|
|
# Map over function that may return None; omit Nones from output sequence
|
2024-06-29 04:48:07 +00:00
|
|
|
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
    """Map ``func`` over ``xs``, omitting any None results.

    Purposely returns a lazy iterator (in analogy to map()) so callers
    must explicitly list() it when a sequence is needed.
    """
    return (result for result in map(func, xs) if result is not None)
|
|
|
|
|
|
2022-04-19 12:25:45 +00:00
|
|
|
|
2021-02-04 17:10:34 +00:00
|
|
|
# Map over function that returns sequences and cat them all together
|
|
|
|
|
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Map ``func`` over ``xs`` and concatenate the resulting sequences.

    Purposely a lazy generator (in analogy to map()); list() the result
    when a sequence is needed.
    """
    for seq in map(func, xs):
        yield from seq
|
2021-02-04 17:10:34 +00:00
|
|
|
|
2022-04-19 12:25:45 +00:00
|
|
|
|
2021-02-04 17:10:34 +00:00
|
|
|
# Conveniently add error context to exceptions raised. Lets us
|
|
|
|
|
# easily say that an error occurred while processing a specific
|
|
|
|
|
# context.
|
|
|
|
|
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Conveniently add error context to exceptions raised inside the block.

    ``msg_fn`` is only invoked when an exception actually propagates, so
    it may be expensive to compute.  The message is indented and appended
    to the exception's first argument, and the (mutated) exception is
    re-raised.
    """
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError
        msg = msg_fn()
        msg = textwrap.indent(msg, "  ")
        # Append the context to the existing first arg (if any) so the
        # traceback shows both the original message and the context.
        msg = f"{e.args[0]}\n{msg}" if e.args else msg
        e.args = (msg,) + e.args[1:]
        raise
|
2021-10-28 17:43:11 +00:00
|
|
|
|
2022-04-19 12:25:45 +00:00
|
|
|
|
2021-10-28 17:43:11 +00:00
|
|
|
# A little trick from https://github.com/python/mypy/issues/6366
|
|
|
|
|
# for getting mypy to do exhaustiveness checking
|
|
|
|
|
# TODO: put this somewhere else, maybe
|
|
|
|
|
def assert_never(x: NoReturn) -> NoReturn:
    """Exhaustiveness-checking helper (python/mypy#6366).

    Statically, mypy reports an error unless ``x`` has been narrowed to
    an impossible type; at runtime this always raises.
    """
    message = f"Unhandled type: {type(x).__name__}"
    raise AssertionError(message)
|
2021-10-28 17:43:11 +00:00
|
|
|
|
2022-04-19 12:25:45 +00:00
|
|
|
|
2024-12-02 21:46:15 +00:00
|
|
|
@functools.cache
def _read_template(template_fn: str) -> CodeTemplate:
    # Cached so each template file is read and parsed at most once per
    # process, no matter how many output files it is rendered into.
    return CodeTemplate.from_file(template_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# String hash that's stable across different executions, unlike builtin hash
|
|
|
|
|
def string_stable_hash(s: str) -> int:
|
2022-04-19 12:25:45 +00:00
|
|
|
sha1 = hashlib.sha1(s.encode("latin1")).digest()
|
|
|
|
|
return int.from_bytes(sha1, byteorder="little")
|
|
|
|
|
|
2021-10-28 17:43:11 +00:00
|
|
|
|
|
|
|
|
# A small abstraction for writing out generated files and keeping track
|
|
|
|
|
# of what files have been written (so you can write out a list of output
|
|
|
|
|
# files)
|
|
|
|
|
class FileManager:
    """A small abstraction for writing out generated files, keeping track
    of every file written so a manifest of outputs can be produced
    afterwards (see write_outputs).

    In dry_run mode, write()/write_sharded() record filenames but do not
    touch the disk.
    """

    install_dir: str  # directory all output files are written under
    template_dir: str  # directory template files are read from
    dry_run: bool  # if True, record filenames without writing files
    filenames: set[str]  # every install path recorded so far

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write contents to filename, skipping the write (and therefore
        preserving the file's mtime) when the contents already match."""
        old_contents: str | None
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
    ) -> str:
        """Render template_fn with the environment produced by env_callable.

        A str environment is returned verbatim; a dict environment is fed
        through CodeTemplate.substitute, with a deterministic
        "generated_comment" filled in when the caller didn't supply one.
        """
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            if "generated_comment" not in env:
                generator_default = REPO_ROOT / "torchgen" / "gen.py"
                try:
                    generator = Path(
                        sys.modules["__main__"].__file__ or generator_default
                    ).absolute()
                except (KeyError, AttributeError):
                    generator = generator_default.absolute()

                try:
                    # Prefer a repo-relative path so the comment is stable
                    # across build machines.
                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
                except ValueError:
                    # Generator lives outside the repo; fall back to its name.
                    generator_path = generator.name

                env = {
                    **env,  # copy the original dict instead of mutating it
                    "generated_comment": (
                        "@" + f"generated by {generator_path} from {template_fn}"
                    ),
                }
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Render template_fn and write it to install_dir/filename,
        recording the install path in self.filenames."""
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Write filename, using the template of the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
    ) -> None:
        """Like write_sharded_with_template, using the template named
        the same as the output file."""
        self.write_sharded_with_template(
            filename,
            filename,
            items,
            key_fn=key_fn,
            env_callable=env_callable,
            num_shards=num_shards,
            base_env=base_env,
            sharded_keys=sharded_keys,
        )

    def write_sharded_with_template(
        self,
        filename: str,
        template_fn: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
    ) -> None:
        """Distribute items across num_shards output files, plus one
        "Everything" file that aggregates all of them.

        Each item is assigned to a shard by a stable hash of key_fn(item),
        so shard assignment doesn't change between runs.  env_callable may
        only emit keys declared in sharded_keys; those keys accumulate as
        lists in each shard's environment.
        """
        everything: dict[str, Any] = {"shard_id": "Everything"}
        shards: list[dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        # Ensure every sharded key starts as an independent list per shard.
        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
            # Accumulate one item's environment into a shard environment.
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}",
                template_fn,
                # Bind shard eagerly: a bare `lambda: shard` would late-bind
                # to the loop variable if the callable ever escaped this
                # iteration.
                lambda shard=shard: shard,
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)

    def template_dir_for_comments(self) -> str:
        """
        This needs to be deterministic. The template dir is an absolute path
        that varies across builds. So, just use the path relative to this file,
        which will point to the codegen source but will be stable.
        """
        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
|
|
|
|
|
|
2022-03-01 22:54:42 +00:00
|
|
|
|
|
|
|
|
# Helper function to generate file manager
|
2022-04-19 12:25:45 +00:00
|
|
|
def make_file_manager(
    options: Namespace, install_dir: str | None = None
) -> FileManager:
    """Build a FileManager configured from parsed command-line options.

    An explicit ``install_dir`` argument, when given, takes precedence
    over ``options.install_dir``.
    """
    return FileManager(
        install_dir=install_dir or options.install_dir,
        template_dir=os.path.join(options.source_path, "templates"),
        dry_run=options.dry_run,
    )
|
Pretty-print dataclasses (#76810)
Unfortunately the built-in pprint module support pretty-print of dataclasses only from python 3.10. The code that I wrote in method `__str__` of OpInfo should do the same job and should also work for any dataclass. For now I've put it there but we can create a function and put it somewhere where is accessible also for other dataclasses. Also the max width (80) is now hardcode but it would ideally be the parameter of the function.
when you call print on an OpInfo you get:
```
OpInfo(name = '__getitem__',
ref = None,
aliases = (),
variant_test_name = '',
op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
inplace_variant = None,
skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>,
reference_inputs_func = None,
error_inputs_func = None,
sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>,
sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>,
dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
supports_out = False,
supports_autograd = True,
supports_gradgrad = True,
supports_fwgrad_bwgrad = True,
supports_inplace_autograd = False,
supports_forward_ad = True,
gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>,
check_batched_grad = True,
check_batched_gradgrad = True,
check_batched_forward_grad = True,
check_inplace_batched_forward_grad = True,
gradcheck_nondet_tol = 0.0,
gradcheck_fast_mode = None,
aten_name = '__getitem__',
decomp_aten_name = None,
aten_backward_name = None,
assert_autodiffed = False,
autodiff_nonfusible_nodes = ['aten::__getitem__'],
autodiff_fusible_nodes = [],
supports_sparse = False,
supports_scripting = False,
supports_sparse_csr = False,
test_conjugated_samples = True,
test_neg_view = True,
assert_jit_shape_analysis = False,
supports_expanded_weight = False)
```
cc @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810
Approved by: https://github.com/ezyang
2022-05-16 14:20:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Helper function to create a pretty representation for dataclasses
|
|
|
|
|
def dataclass_repr(
|
|
|
|
|
obj: Any,
|
|
|
|
|
indent: int = 0,
|
|
|
|
|
width: int = 80,
|
|
|
|
|
) -> str:
|
|
|
|
|
# built-in pprint module support dataclasses from python 3.10
|
|
|
|
|
if sys.version_info >= (3, 10):
|
|
|
|
|
from pprint import pformat
|
|
|
|
|
|
|
|
|
|
return pformat(obj, indent, width)
|
|
|
|
|
|
|
|
|
|
return _pformat(obj, indent=indent, width=width)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _pformat(
|
|
|
|
|
obj: Any,
|
|
|
|
|
indent: int,
|
|
|
|
|
width: int,
|
|
|
|
|
curr_indent: int = 0,
|
|
|
|
|
) -> str:
|
|
|
|
|
assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
|
|
|
|
|
|
|
|
|
|
class_name = obj.__class__.__name__
|
|
|
|
|
# update current indentation level with class name
|
|
|
|
|
curr_indent += len(class_name) + 1
|
|
|
|
|
|
|
|
|
|
fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
|
|
|
|
|
|
|
|
|
|
fields_str = []
|
|
|
|
|
for name, attr in fields_list:
|
|
|
|
|
# update the current indent level with the field name
|
|
|
|
|
# dict, list, set and tuple also add indent as done in pprint
|
|
|
|
|
_curr_indent = curr_indent + len(name) + 1
|
|
|
|
|
if is_dataclass(attr):
|
|
|
|
|
str_repr = _pformat(attr, indent, width, _curr_indent)
|
|
|
|
|
elif isinstance(attr, dict):
|
|
|
|
|
str_repr = _format_dict(attr, indent, width, _curr_indent)
|
|
|
|
|
elif isinstance(attr, (list, set, tuple)):
|
|
|
|
|
str_repr = _format_list(attr, indent, width, _curr_indent)
|
|
|
|
|
else:
|
|
|
|
|
str_repr = repr(attr)
|
|
|
|
|
|
|
|
|
|
fields_str.append(f"{name}={str_repr}")
|
|
|
|
|
|
|
|
|
|
indent_str = curr_indent * " "
|
|
|
|
|
body = f",\n{indent_str}".join(fields_str)
|
|
|
|
|
return f"{class_name}({body})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format_dict(
    attr: dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a dict for _pformat, recursing into dataclass values."""
    # extra 3 columns — presumably the opening brace plus ": " — keeps
    # wrapped entries aligned; TODO confirm against _format's output
    curr_indent += indent + 3
    entries = []
    for key, value in attr.items():
        key_repr = repr(key)
        if is_dataclass(value):
            value_repr = _pformat(value, indent, width, curr_indent + len(key_repr))
        else:
            value_repr = repr(value)
        entries.append(f"{key_repr}: {value_repr}")
    return _format(entries, indent, width, curr_indent, "{", "}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list/set/tuple for _pformat, recursing into dataclass items."""
    curr_indent += indent + 1
    rendered = []
    for element in attr:
        if is_dataclass(element):
            rendered.append(_pformat(element, indent, width, curr_indent))
        else:
            rendered.append(repr(element))
    if isinstance(attr, list):
        start, end = "[", "]"
    else:
        # NOTE(review): sets take this branch too, so they render with
        # parentheses like tuples — verify whether braces were intended.
        start, end = "(", ")"
    return _format(rendered, indent, width, curr_indent, start, end)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format(
|
2024-06-29 04:48:07 +00:00
|
|
|
fields_str: list[str],
|
Pretty-print dataclasses (#76810)
Unfortunately the built-in pprint module support pretty-print of dataclasses only from python 3.10. The code that I wrote in method `__str__` of OpInfo should do the same job and should also work for any dataclass. For now I've put it there but we can create a function and put it somewhere where is accessible also for other dataclasses. Also the max width (80) is now hardcode but it would ideally be the parameter of the function.
when you call print on an OpInfo you get:
```
OpInfo(name = '__getitem__',
ref = None,
aliases = (),
variant_test_name = '',
op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
inplace_variant = None,
skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>,
reference_inputs_func = None,
error_inputs_func = None,
sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>,
sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>,
dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
supports_out = False,
supports_autograd = True,
supports_gradgrad = True,
supports_fwgrad_bwgrad = True,
supports_inplace_autograd = False,
supports_forward_ad = True,
gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>,
check_batched_grad = True,
check_batched_gradgrad = True,
check_batched_forward_grad = True,
check_inplace_batched_forward_grad = True,
gradcheck_nondet_tol = 0.0,
gradcheck_fast_mode = None,
aten_name = '__getitem__',
decomp_aten_name = None,
aten_backward_name = None,
assert_autodiffed = False,
autodiff_nonfusible_nodes = ['aten::__getitem__'],
autodiff_fusible_nodes = [],
supports_sparse = False,
supports_scripting = False,
supports_sparse_csr = False,
test_conjugated_samples = True,
test_neg_view = True,
assert_jit_shape_analysis = False,
supports_expanded_weight = False)
```
cc @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810
Approved by: https://github.com/ezyang
2022-05-16 14:20:41 +00:00
|
|
|
indent: int,
|
|
|
|
|
width: int,
|
|
|
|
|
curr_indent: int,
|
|
|
|
|
start: str,
|
|
|
|
|
end: str,
|
|
|
|
|
) -> str:
|
|
|
|
|
delimiter, curr_indent_str = "", ""
|
|
|
|
|
# if it exceed the max width then we place one element per line
|
|
|
|
|
if len(repr(fields_str)) >= width:
|
|
|
|
|
delimiter = "\n"
|
|
|
|
|
curr_indent_str = " " * curr_indent
|
|
|
|
|
|
|
|
|
|
indent_str = " " * indent
|
|
|
|
|
body = f", {delimiter}{curr_indent_str}".join(fields_str)
|
|
|
|
|
return f"{start}{indent_str}{body}{end}"
|
2022-07-08 21:56:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class NamespaceHelper:
    """Builds the opening and closing brace lines for a (possibly nested)
    C++ namespace.

    For ``namespace_str`` "torch::lazy" the prologue is::

        namespace torch {
        namespace lazy {

    and the epilogue is::

        } // namespace lazy
        } // namespace torch
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
        # namespace_str may be colon-joined, e.g. "torch::lazy"
        parts = namespace_str.split("::")
        assert (
            len(parts) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join(f"namespace {ns} {{" for ns in parts)
        # closing braces come out innermost-first
        self.epilogue_ = "\n".join(f"}} // namespace {ns}" for ns in reversed(parts))
        self.namespaces_ = parts
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
        """Build a helper from a fully-qualified entity such as
        "torch::lazy::add": everything before the final "::" becomes the
        namespace, the last segment becomes the entity name."""
        namespace_str, _, entity_name = namespaced_entity.rpartition("::")
        return NamespaceHelper(
            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
        )

    @property
    def prologue(self) -> str:
        # namespace-opening lines, outermost first
        return self.prologue_

    @property
    def epilogue(self) -> str:
        # namespace-closing lines, innermost first
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        return self.entity_name_

    def get_cpp_namespace(self, default: str = "") -> str:
        """Return the "::"-joined namespace (no leading "::"), or *default*
        when the namespace string is empty."""
        return self.cpp_namespace_ if self.cpp_namespace_ else default
|
2022-08-01 15:16:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class OrderedSet(Generic[T]):
    """A set that preserves insertion order, backed by a dict whose values
    are always None (dicts iterate in insertion order)."""

    # backing store: keys are the members, values are always None
    storage: dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
        self.storage = {} if iterable is None else dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        # iterating a dict yields its keys in insertion order
        return iter(self.storage)

    def update(self, items: OrderedSet[T]) -> None:
        """In-place union: append the members of *items* not already present."""
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
        """Return a shallow copy preserving order."""
        duplicate: OrderedSet[T] = OrderedSet()
        duplicate.storage = self.storage.copy()
        return duplicate

    @staticmethod
    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
        """Union of one or more sets; order follows first occurrence."""
        merged = args[0].copy()
        for extra in args[1:]:
            merged.update(extra)
        return merged

    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
        return OrderedSet.union(self, other)

    def __ior__(self, other: OrderedSet[T]) -> Self:
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, OrderedSet):
            # fall back to plain-set comparison against arbitrary objects
            return set(self.storage.keys()) == other
        return self.storage == other.storage
|