diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 66165b9c4dd..730e664a9e2 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -1218,6 +1218,24 @@ class TestFunctionalIterDataPipe(TestCase): def fn_nn(d0, d1): return -d0, -d1, d0 + d1 + def fn_n1_def(d0, d1=1): + return d0 + d1 + + def fn_n1_kwargs(d0, d1, **kwargs): + return d0 + d1 + + def fn_n1_pos(d0, d1, *args): + return d0 + d1 + + def fn_n1_sep_pos(d0, *args, d1): + return d0 + d1 + + def fn_cmplx(d0, d1=1, *args, d2, **kwargs): + return d0 + d1 + + p_fn_n1 = partial(fn_n1, d1=1) + p_fn_cmplx = partial(fn_cmplx, d2=2) + def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): for constr in (list, tuple): datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) @@ -1231,6 +1249,12 @@ class TestFunctionalIterDataPipe(TestCase): self.assertEqual(list(res_dp), list(ref_dp)) # Reset self.assertEqual(list(res_dp), list(ref_dp)) + _helper(lambda data: data, fn_n1_def, 0, 1) + _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2) + _helper(lambda data: data, p_fn_n1, 0, 1) + _helper(lambda data: data, p_fn_cmplx, 0, 1) + _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2) + _helper(lambda data: (data[0] + data[1], ), fn_n1_pos, [0, 1, 2]) # Replacing with one input column and default output column _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) @@ -1238,7 +1262,20 @@ class TestFunctionalIterDataPipe(TestCase): # The index of input column is out of range _helper(None, fn_1n, 3, error=IndexError) # Unmatched input columns with fn arguments - _helper(None, fn_n1, 1, error=TypeError) + _helper(None, fn_n1, 1, error=ValueError) + _helper(None, fn_n1, [0, 1, 2], error=ValueError) + _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError) + _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError) + _helper(None, fn_cmplx, 0, 1, ValueError) + _helper(None, 
fn_n1_pos, 1, error=ValueError) + _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError) + _helper(None, p_fn_n1, [0, 1], error=ValueError) + _helper(None, fn_1n, [1, 2], error=ValueError) + # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError) + _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError) + # Fn has keyword-only arguments + _helper(None, fn_n1_kwargs, 1, error=ValueError) + _helper(None, fn_cmplx, [0, 1], 2, ValueError) # Replacing with multiple input columns and default output column (the left-most input column) _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) @@ -1278,6 +1315,28 @@ class TestFunctionalIterDataPipe(TestCase): def fn_nn(d0, d1): return -d0, -d1, d0 + d1 + def fn_n1_def(d0, d1=1): + return d0 + d1 + + p_fn_n1 = partial(fn_n1, d1=1) + + def fn_n1_pos(d0, d1, *args): + return d0 + d1 + + def fn_n1_kwargs(d0, d1, **kwargs): + return d0 + d1 + + def fn_kwonly(*, d0, d1): + return d0 + d1 + + def fn_has_nondefault_kwonly(d0, *, d1): + return d0 + d1 + + def fn_cmplx(d0, d1=1, *args, d2, **kwargs): + return d0 + d1 + + p_fn_cmplx = partial(fn_cmplx, d2=2) + # Prevent modification in-place to support resetting def _dict_update(data, newdata, remove_idx=None): _data = dict(data) @@ -1304,13 +1363,33 @@ class TestFunctionalIterDataPipe(TestCase): # Reset self.assertEqual(list(res_dp), list(ref_dp)) + _helper(lambda data: data, fn_n1_def, 'x', 'y') + _helper(lambda data: data, p_fn_n1, 'x', 'y') + _helper(lambda data: data, p_fn_cmplx, 'x', 'y') + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), + p_fn_cmplx, ["x", "y", "z"], "z") + + _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ['x', 'y'], 'z') + + _helper(None, fn_n1_pos, 'x', error=ValueError) + _helper(None, fn_n1_kwargs, 'x', error=ValueError) + # non-default kw-only args + _helper(None, fn_kwonly, ['x', 'y'], error=ValueError) + _helper(None, fn_has_nondefault_kwonly, ['x', 'y'], error=ValueError) + 
_helper(None, fn_cmplx, ['x', 'y'], error=ValueError) + + # Replacing with one input column and default output column _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y") # The key of input column is not in dict _helper(None, fn_1n, "a", error=KeyError) # Unmatched input columns with fn arguments - _helper(None, fn_n1, "y", error=TypeError) + _helper(None, fn_n1, "y", error=ValueError) + _helper(None, fn_1n, ["x", "y"], error=ValueError) + _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError) + _helper(None, p_fn_n1, ["x", "y"], error=ValueError) + _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError) # Replacing with multiple input columns and default output column (the left-most input column) _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"]) _helper(lambda data: _dict_update( @@ -1508,6 +1587,32 @@ class TestFunctionalIterDataPipe(TestCase): input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2]) self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)]) + # invalid input col + with self.assertRaises(ValueError): + tuple_input_ds.filter(_mul_filter_fn, input_col=0) + + p_mul_filter_fn = partial(_mul_filter_fn, b=1) + out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0) + self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)]) + + def _mul_filter_fn_with_defaults(a, b=1): + return a + b < 10 + + out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0) + self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)]) + + def _mul_filter_fn_with_kw_only(*, a, b): + return a + b < 10 + + with self.assertRaises(ValueError): + tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0) + + def _mul_filter_fn_with_kw_only_1_default(*, a, b=1): + return a + b < 10 + + with self.assertRaises(ValueError): + 
tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0) + # __len__ Test: DataPipe has no valid len with self.assertRaisesRegex(TypeError, r"has no len"): len(filter_dp) diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 653a25bb72b..30b04885787 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -7,7 +7,8 @@ from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data._utils.collate import default_collate from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper from torch.utils.data.datapipes.datapipe import IterDataPipe -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn +from torch.utils.data.datapipes.utils.common import (_check_unpickable_fn, + validate_input_col) __all__ = [ "CollatorIterDataPipe", @@ -80,6 +81,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]): raise ValueError("`output_col` must be a single-element list or tuple") output_col = output_col[0] self.output_col = output_col + validate_input_col(fn, input_col) def _apply_fn(self, data): if self.input_col is None and self.output_col is None: diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py index a31f2f933f3..2ba91b36fff 100644 --- a/torch/utils/data/datapipes/iter/selecting.py +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -7,6 +7,7 @@ from torch.utils.data.datapipes.utils.common import ( _check_unpickable_fn, _deprecation_warning, StreamWrapper, + validate_input_col ) @@ -69,6 +70,7 @@ class FilterIterDataPipe(IterDataPipe[T_co]): self.drop_empty_batches = drop_empty_batches self.input_col = input_col + validate_input_col(filter_fn, input_col) def _apply_filter_fn(self, data) -> bool: if self.input_col is None: diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 
def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]) -> None:
    """
    Checks that function used in a callable datapipe works with the input column

    This simply ensures that the number of positional arguments matches the size
    of the input column. The function must not contain any non-default
    keyword-only arguments.

    Examples:
        >>> def f(a, b, *, c=1):
        ...     return a + b + c
        >>> def f_def(a, b=1, *, c=1):
        ...     return a + b + c
        >>> validate_input_col(f, [1, 2])
        >>> validate_input_col(f_def, 1)
        >>> validate_input_col(f_def, [1, 2])

    Notes:
        If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments,
        for example, f(a, *args), the validator will accept any size of input column
        greater than or equal to the number of positional arguments.
        (in this case, 1).

        If the signature of ``fn`` cannot be introspected (``inspect.signature``
        raises ``ValueError``, as it does for some built-in callables), the
        check is skipped and the function returns without raising.

    Args:
        fn: The function to check.
        input_col: The input column to check. A list or tuple counts as
            ``len(input_col)`` columns; any other value (including ``None``)
            counts as a single column.

    Returns:
        None. The function only validates; success is signalled by not raising.

    Raises:
        ValueError: If the function is not compatible with the input column.
    """
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # No introspectable signature, so nothing can be validated here. Do not
        # let this surface as a misleading "incompatible input column" error.
        return

    if isinstance(input_col, (list, tuple)):
        input_col_size = len(input_col)
    else:
        input_col_size = 1

    fn_name = str(fn)
    params = list(sig.parameters.values())

    # Parameters that can be filled positionally from the selected columns.
    pos = [
        p for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    var_positional = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params)
    # Keyword-only parameters without a default can never be supplied, because
    # the columns are passed positionally.
    non_default_kw_only = [
        p for p in params
        if p.kind is inspect.Parameter.KEYWORD_ONLY and p.default is p.empty
    ]

    if len(non_default_kw_only) > 0:
        raise ValueError(
            f"The function {fn_name} takes {len(non_default_kw_only)} "
            f"non-default keyword-only parameters, which is not allowed."
        )

    if len(params) < input_col_size:
        # More columns than declared parameters: only acceptable with *args.
        if not var_positional:
            raise ValueError(
                f"The function {fn_name} takes {len(params)} "
                f"parameters, but {input_col_size} are required."
            )
    elif len(pos) > input_col_size:
        # Surplus positional parameters are fine only if all of them have defaults.
        if any(p.default is p.empty for p in pos[input_col_size:]):
            raise ValueError(
                f"The function {fn_name} takes {len(pos)} "
                f"positional parameters, but {input_col_size} are required."
            )
    elif len(pos) < input_col_size:
        # Surplus columns must be absorbed by *args.
        if not var_positional:
            raise ValueError(
                f"The function {fn_name} takes {len(pos)} "
                f"positional parameters, but {input_col_size} are required."
            )