diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index a7228861fa7..807d47c26a4 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -782,7 +782,6 @@ def validate_function_matches_schema( compare(kwargonly, schema.arguments.flat_kwarg_only) - def infer_schema(prototype_function: typing.Callable) -> str: sig = inspect.signature(prototype_function) diff --git a/torch/_library/abstract_impl.py b/torch/_library/abstract_impl.py index 19ee1d295d3..e09d3eace9b 100644 --- a/torch/_library/abstract_impl.py +++ b/torch/_library/abstract_impl.py @@ -137,9 +137,7 @@ class AbstractImplCtx: that depends on the data of the input Tensors. Args: - min (int): A statically known inclusive lower bound for this symint. - min must be at least 2 due to implementation details of - torch.compile. Default: 2. + min (int): A statically known inclusive lower bound for this symint. Default: 0 max (Optional[int]): A statically known inclusive upper bound for this symint. Default: None @@ -202,5 +200,7 @@ class AbstractImplCtx: ) result = self._shape_env.create_unbacked_symint() - torch.fx.experimental.symbolic_shapes.constrain_range(result, min=0, max=max) + torch.fx.experimental.symbolic_shapes._constrain_range_for_size( + result, min=min, max=max + ) return result