pytorch/tools/test/test_codegen.py
Huy Do 347b036350 Apply ufmt linter to all py files under tools (#81285)
With ufmt in place https://github.com/pytorch/pytorch/pull/81157, we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on.

This batch (as copied from the current BLACK linter config):
* `tools/**/*.py`

Upcoming batches:
* `torchgen/**/*.py`
* `torch/package/**/*.py`
* `torch/onnx/**/*.py`
* `torch/_refs/**/*.py`
* `torch/_prims/**/*.py`
* `torch/_meta_registrations.py`
* `torch/_decomp/**/*.py`
* `test/onnx/**/*.py`

Once they are all formatted, BLACK linter will be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285
Approved by: https://github.com/suo
2022-07-13 07:59:22 +00:00

211 lines
8.1 KiB
Python

import dataclasses
import typing
import unittest
import torchgen.model
from tools.autograd import gen_autograd_functions, load_derivatives
from torchgen.gen import get_native_function_schema_registrations
from torchgen.selective_build.selector import SelectiveBuilder
class TestCreateDerivative(unittest.TestCase):
    """Tests for load_derivatives: how derivative formulas reference the
    gradients of a function's outputs, either by name (``grad_x``) or by
    index (``grads[0]``), and that mixing the two styles is rejected."""

    def test_named_grads(self) -> None:
        """A formula that names its gradients records exactly those names."""
        func_schema = torchgen.model.FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        func = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=func_schema)
        result = load_derivatives.create_derivative(
            func,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        self.assertSetEqual(result.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        """Outputs of non-differentiable type (bool) get no named gradient."""
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        func_schema = torchgen.model.FunctionSchema.parse(specification)
        func = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=func_schema)
        info = load_derivatives.create_differentiability_info(
            defn={
                "name": specification,
                "a": "grads[0]",
                "b": "grads[2]",
            },
            functions_by_signature={func_schema.signature(): [func]},
            functions_by_schema={specification: func},
            op_counter=typing.Counter[str](),
        )
        # grad_y is not present because y is a
        # bool and thus not differentiable.
        self.assertSequenceEqual(
            info.available_named_gradients,
            ["grad_x", "grad_z"],
        )

    def test_indexed_grads(self) -> None:
        """A formula using only indexed gradients records no names."""
        func_schema = torchgen.model.FunctionSchema.parse(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )
        func = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=func_schema)
        result = load_derivatives.create_derivative(
            func,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        self.assertSetEqual(result.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        """Mixing named and indexed gradient references is an error."""
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        func_schema = torchgen.model.FunctionSchema.parse(specification)
        func = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=func_schema)
        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            load_derivatives.create_differentiability_info(
                defn={
                    "name": specification,
                    # Uh-oh, the derivatives reference gradients by
                    # name and by index.
                    "a": "grad_x",
                    "b": "grads[1]",
                },
                functions_by_signature={func_schema.signature(): [func]},
                functions_by_schema={specification: func},
                op_counter=typing.Counter[str](),
            )
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for gen_autograd_functions.process_function: how gradient
    names map to grads[...] indices in the generated C++ when some
    outputs are not differentiable."""

    def test_non_differentiable_output_invalid_type(self) -> None:
        """An output of non-differentiable type (bool) is skipped when
        assigning grads[...] indices to the remaining outputs."""
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
        differentiability_info = load_derivatives.create_differentiability_info(
            defn={
                "name": specification,
                "a": "grad_x",
                "b": "grad_z",
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable. Use unittest assertions rather than
        # bare `assert` so the checks survive `python -O` and produce
        # informative failure messages.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        """An output explicitly marked non-differentiable via
        output_differentiability is likewise skipped in index assignment."""
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = torchgen.model.FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)
        differentiability_info = load_derivatives.create_differentiability_info(
            defn={
                "name": specification,
                "a": "grad_x",
                "b": "grad_z",
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info, gen_autograd_functions.FUNCTION_DEFINITION
        )
        # grad_z should map to grads[1], not grads[2] because output 1
        # (y) is not differentiable.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)
class TestGenSchemaRegistration(unittest.TestCase):
    """Tests for get_native_function_schema_registrations: how schema
    registration code is split between the default (aten) namespace and
    custom operator namespaces."""

    def setUp(self) -> None:
        # No-op selector: includes every operator (no selective-build filtering).
        self.selector = SelectiveBuilder.get_nop_selector()
        # A function registered under a custom (non-default) namespace.
        # from_yaml's second return value is unused here.
        self.custom_native_function, _ = torchgen.model.NativeFunction.from_yaml(
            {"func": "custom::func() -> bool"},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )

    def test_default_namespace_schema_registration_code_valid(self) -> None:
        # Default-namespace functions produce bare m.def(...) lines
        # (first element of the returned pair).
        native_functions = [DEFAULT_NATIVE_FUNCTION]
        registrations, _ = get_native_function_schema_registrations(
            native_functions=native_functions,
            schema_selector=self.selector,
        )
        self.assertEqual(registrations, ['m.def("func() -> bool", {});\n'])

    def test_custom_namespace_schema_registration_code_valid(self) -> None:
        # Custom-namespace functions produce a TORCH_LIBRARY block
        # (second element of the returned pair).
        _, registrations = get_native_function_schema_registrations(
            native_functions=[self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(
            registrations,
            """
TORCH_LIBRARY(custom, m) {
m.def("func() -> bool", {});
};""",
        )

    def test_mixed_namespace_schema_registration_code_valid(self) -> None:
        # A mix of default- and custom-namespace functions is split into
        # both outputs.
        (
            aten_registrations,
            custom_registrations,
        ) = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(aten_registrations, ['m.def("func() -> bool", {});\n'])
        self.assertEqual(
            custom_registrations,
            """
TORCH_LIBRARY(custom, m) {
m.def("func() -> bool", {});
};""",
        )

    def test_3_namespaces_schema_registration_code_invalid(self) -> None:
        # More than one custom namespace is rejected with an AssertionError.
        custom2_native_function, _ = torchgen.model.NativeFunction.from_yaml(
            {"func": "custom2::func() -> bool"},
            loc=torchgen.model.Location(__file__, 1),
            valid_tags=set(),
        )
        with self.assertRaises(AssertionError):
            get_native_function_schema_registrations(
                native_functions=[
                    DEFAULT_NATIVE_FUNCTION,
                    self.custom_native_function,
                    custom2_native_function,
                ],
                schema_selector=self.selector,
            )
# Represents the most basic NativeFunction. Use dataclasses.replace()
# to edit for use. Defined at module bottom so the test classes above can
# reference it at call time; from_yaml's second return value is unused here.
DEFAULT_NATIVE_FUNCTION, _ = torchgen.model.NativeFunction.from_yaml(
    {"func": "func() -> bool"},
    loc=torchgen.model.Location(__file__, 1),
    valid_tags=set(),
)
if __name__ == "__main__":
    # Allow running this test file directly: `python test_codegen.py`.
    unittest.main()