[WIP] Validating input_col for certain datapipes (#80267)

Follow up from #79344.

Currently WIP due to multiple test failures.

Waiting for #80140 to land
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80267
Approved by: https://github.com/ejguan
This commit is contained in:
Robert 2022-08-24 17:34:28 +00:00 committed by PyTorch MergeBot
parent 30a5583d75
commit 5c49c7bbba
4 changed files with 193 additions and 4 deletions

View file

@ -1218,6 +1218,24 @@ class TestFunctionalIterDataPipe(TestCase):
def fn_nn(d0, d1):
return -d0, -d1, d0 + d1
def fn_n1_def(d0, d1=1):
return d0 + d1
def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1
def fn_n1_pos(d0, d1, *args):
return d0 + d1
def fn_n1_sep_pos(d0, *args, d1):
return d0 + d1
def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1
p_fn_n1 = partial(fn_n1, d1=1)
p_fn_cmplx = partial(fn_cmplx, d2=2)
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple):
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
@ -1231,6 +1249,12 @@ class TestFunctionalIterDataPipe(TestCase):
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
_helper(lambda data: data, fn_n1_def, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
_helper(lambda data: data, p_fn_n1, 0, 1)
_helper(lambda data: data, p_fn_cmplx, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
_helper(lambda data: (data[0] + data[1], ), fn_n1_pos, [0, 1, 2])
# Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
@ -1238,7 +1262,20 @@ class TestFunctionalIterDataPipe(TestCase):
# The index of input column is out of range
_helper(None, fn_1n, 3, error=IndexError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, 1, error=TypeError)
_helper(None, fn_n1, 1, error=ValueError)
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
_helper(None, fn_cmplx, 0, 1, ValueError)
_helper(None, fn_n1_pos, 1, error=ValueError)
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
_helper(None, p_fn_n1, [0, 1], error=ValueError)
_helper(None, fn_1n, [1, 2], error=ValueError)
# _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
_helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
# Fn has keyword-only arguments
_helper(None, fn_n1_kwargs, 1, error=ValueError)
_helper(None, fn_cmplx, [0, 1], 2, ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
@ -1278,6 +1315,28 @@ class TestFunctionalIterDataPipe(TestCase):
def fn_nn(d0, d1):
return -d0, -d1, d0 + d1
def fn_n1_def(d0, d1=1):
return d0 + d1
p_fn_n1 = partial(fn_n1, d1=1)
def fn_n1_pos(d0, d1, *args):
return d0 + d1
def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1
def fn_kwonly(*, d0, d1):
return d0 + d1
def fn_has_nondefault_kwonly(d0, *, d1):
return d0 + d1
def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1
p_fn_cmplx = partial(fn_cmplx, d2=2)
# Prevent modification in-place to support resetting
def _dict_update(data, newdata, remove_idx=None):
_data = dict(data)
@ -1304,13 +1363,33 @@ class TestFunctionalIterDataPipe(TestCase):
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
_helper(lambda data: data, fn_n1_def, 'x', 'y')
_helper(lambda data: data, p_fn_n1, 'x', 'y')
_helper(lambda data: data, p_fn_cmplx, 'x', 'y')
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
p_fn_cmplx, ["x", "y", "z"], "z")
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ['x', 'y'], 'z')
_helper(None, fn_n1_pos, 'x', error=ValueError)
_helper(None, fn_n1_kwargs, 'x', error=ValueError)
# non-default kw-only args
_helper(None, fn_kwonly, ['x', 'y'], error=ValueError)
_helper(None, fn_has_nondefault_kwonly, ['x', 'y'], error=ValueError)
_helper(None, fn_cmplx, ['x', 'y'], error=ValueError)
# Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
# The key of input column is not in dict
_helper(None, fn_1n, "a", error=KeyError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, "y", error=TypeError)
_helper(None, fn_n1, "y", error=ValueError)
_helper(None, fn_1n, ["x", "y"], error=ValueError)
_helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
_helper(None, p_fn_n1, ["x", "y"], error=ValueError)
_helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
_helper(lambda data: _dict_update(
@ -1508,6 +1587,32 @@ class TestFunctionalIterDataPipe(TestCase):
input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2])
self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)])
# invalid input col
with self.assertRaises(ValueError):
tuple_input_ds.filter(_mul_filter_fn, input_col=0)
p_mul_filter_fn = partial(_mul_filter_fn, b=1)
out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0)
self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
def _mul_filter_fn_with_defaults(a, b=1):
return a + b < 10
out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0)
self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
def _mul_filter_fn_with_kw_only(*, a, b):
return a + b < 10
with self.assertRaises(ValueError):
tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0)
def _mul_filter_fn_with_kw_only_1_default(*, a, b=1):
return a + b < 10
with self.assertRaises(ValueError):
tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0)
# __len__ Test: DataPipe has no valid len
with self.assertRaisesRegex(TypeError, r"has no len"):
len(filter_dp)

View file

@ -7,7 +7,8 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from torch.utils.data.datapipes.utils.common import (_check_unpickable_fn,
validate_input_col)
__all__ = [
"CollatorIterDataPipe",
@ -80,6 +81,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
raise ValueError("`output_col` must be a single-element list or tuple")
output_col = output_col[0]
self.output_col = output_col
validate_input_col(fn, input_col)
def _apply_fn(self, data):
if self.input_col is None and self.output_col is None:

View file

@ -7,6 +7,7 @@ from torch.utils.data.datapipes.utils.common import (
_check_unpickable_fn,
_deprecation_warning,
StreamWrapper,
validate_input_col
)
@ -69,6 +70,7 @@ class FilterIterDataPipe(IterDataPipe[T_co]):
self.drop_empty_batches = drop_empty_batches
self.input_col = input_col
validate_input_col(filter_fn, input_col)
def _apply_filter_fn(self, data) -> bool:
if self.input_col is None:

View file

@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from torch.utils.data._utils.serialization import DILL_AVAILABLE
__all__ = [
"validate_input_col",
"StreamWrapper",
"get_file_binaries_from_pathnames",
"get_file_pathnames_from_root",
@ -20,6 +21,86 @@ __all__ = [
]
def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
    """
    Check that the function used in a callable datapipe works with the input column.

    This simply ensures that the number of positional arguments matches the size
    of the input column. The function must not contain any non-default
    keyword-only arguments.

    Examples:
        >>> def f(a, b, *, c=1):
        ...     return a + b + c
        >>> def f_def(a, b=1, *, c=1):
        ...     return a + b + c
        >>> validate_input_col(f, [1, 2])
        >>> validate_input_col(f_def, 1)
        >>> validate_input_col(f_def, [1, 2])

    Notes:
        If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments,
        for example, f(a, *args), the validator will accept any size of input column
        greater than or equal to the number of positional arguments.
        (in this case, 1).

        If the signature of ``fn`` cannot be introspected (``inspect.signature``
        raises ``ValueError``, as it does for some callables implemented in C),
        validation is skipped and the function is accepted as-is.

    Args:
        fn: The function to check.
        input_col: The input column to check. A list or tuple counts as
            ``len(input_col)`` columns; any other value (including ``None``)
            counts as a single column.

    Returns:
        None. The function only signals incompatibility by raising.

    Raises:
        ValueError: If the function is not compatible with the input column.
    """
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # No signature available (e.g. a built-in written in C): there is
        # nothing to validate against, so do not reject a usable callable.
        return

    if isinstance(input_col, (list, tuple)):
        input_col_size = len(input_col)
    else:
        input_col_size = 1

    fn_name = str(fn)

    pos = []
    var_positional = False
    non_default_kw_only = []

    for p in sig.parameters.values():
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
            pos.append(p)
        elif p.kind is inspect.Parameter.VAR_POSITIONAL:
            var_positional = True
        elif p.kind is inspect.Parameter.KEYWORD_ONLY:
            # Keyword-only parameters with a default never need a column.
            if p.default is p.empty:
                non_default_kw_only.append(p)
        # VAR_KEYWORD (**kwargs) does not affect positional matching.

    if non_default_kw_only:
        raise ValueError(
            f"The function {fn_name} takes {len(non_default_kw_only)} "
            f"non-default keyword-only parameters, which is not allowed."
        )

    if len(sig.parameters) < input_col_size:
        # Fewer parameters of any kind than columns: only *args can absorb the rest.
        if not var_positional:
            raise ValueError(
                f"The function {fn_name} takes {len(sig.parameters)} "
                f"parameters, but {input_col_size} are required."
            )
    elif len(pos) > input_col_size:
        # Extra positional parameters are fine only if they all have defaults.
        if any(p.default is p.empty for p in pos[input_col_size:]):
            raise ValueError(
                f"The function {fn_name} takes {len(pos)} "
                f"positional parameters, but {input_col_size} are required."
            )
    elif len(pos) < input_col_size:
        # Surplus columns must be absorbed by *args.
        if not var_positional:
            raise ValueError(
                f"The function {fn_name} takes {len(pos)} "
                f"positional parameters, but {input_col_size} are required."
            )
def _is_local_fn(fn):
# Functions or Methods
if hasattr(fn, "__code__"):
@ -33,7 +114,6 @@ def _is_local_fn(fn):
return "<locals>" in fn_type.__qualname__
return False
def _check_unpickable_fn(fn: Callable):
"""
Checks function is pickable or not. If it is a lambda or local function, a UserWarning