mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[WIP] Validating input_col for certain datapipes (#80267)
Follow up from #79344. Currently WIP due to multiple test failures. Waiting for #80140 to land. Pull Request resolved: https://github.com/pytorch/pytorch/pull/80267 — Approved by: https://github.com/ejguan
This commit is contained in:
parent
30a5583d75
commit
5c49c7bbba
4 changed files with 193 additions and 4 deletions
|
|
@ -1218,6 +1218,24 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
def fn_nn(d0, d1):
|
||||
return -d0, -d1, d0 + d1
|
||||
|
||||
def fn_n1_def(d0, d1=1):
|
||||
return d0 + d1
|
||||
|
||||
def fn_n1_kwargs(d0, d1, **kwargs):
|
||||
return d0 + d1
|
||||
|
||||
def fn_n1_pos(d0, d1, *args):
|
||||
return d0 + d1
|
||||
|
||||
def fn_n1_sep_pos(d0, *args, d1):
|
||||
return d0 + d1
|
||||
|
||||
def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
|
||||
return d0 + d1
|
||||
|
||||
p_fn_n1 = partial(fn_n1, d1=1)
|
||||
p_fn_cmplx = partial(fn_cmplx, d2=2)
|
||||
|
||||
def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
|
||||
for constr in (list, tuple):
|
||||
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
|
||||
|
|
@ -1231,6 +1249,12 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
# Reset
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
_helper(lambda data: data, fn_n1_def, 0, 1)
|
||||
_helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
|
||||
_helper(lambda data: data, p_fn_n1, 0, 1)
|
||||
_helper(lambda data: data, p_fn_cmplx, 0, 1)
|
||||
_helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
|
||||
_helper(lambda data: (data[0] + data[1], ), fn_n1_pos, [0, 1, 2])
|
||||
|
||||
# Replacing with one input column and default output column
|
||||
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
|
||||
|
|
@ -1238,7 +1262,20 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
# The index of input column is out of range
|
||||
_helper(None, fn_1n, 3, error=IndexError)
|
||||
# Unmatched input columns with fn arguments
|
||||
_helper(None, fn_n1, 1, error=TypeError)
|
||||
_helper(None, fn_n1, 1, error=ValueError)
|
||||
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
|
||||
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
|
||||
_helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
|
||||
_helper(None, fn_cmplx, 0, 1, ValueError)
|
||||
_helper(None, fn_n1_pos, 1, error=ValueError)
|
||||
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
|
||||
_helper(None, p_fn_n1, [0, 1], error=ValueError)
|
||||
_helper(None, fn_1n, [1, 2], error=ValueError)
|
||||
# _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
|
||||
_helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
|
||||
# Fn has keyword-only arguments
|
||||
_helper(None, fn_n1_kwargs, 1, error=ValueError)
|
||||
_helper(None, fn_cmplx, [0, 1], 2, ValueError)
|
||||
|
||||
# Replacing with multiple input columns and default output column (the left-most input column)
|
||||
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
|
||||
|
|
@ -1278,6 +1315,28 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
def fn_nn(d0, d1):
|
||||
return -d0, -d1, d0 + d1
|
||||
|
||||
def fn_n1_def(d0, d1=1):
|
||||
return d0 + d1
|
||||
|
||||
p_fn_n1 = partial(fn_n1, d1=1)
|
||||
|
||||
def fn_n1_pos(d0, d1, *args):
|
||||
return d0 + d1
|
||||
|
||||
def fn_n1_kwargs(d0, d1, **kwargs):
|
||||
return d0 + d1
|
||||
|
||||
def fn_kwonly(*, d0, d1):
|
||||
return d0 + d1
|
||||
|
||||
def fn_has_nondefault_kwonly(d0, *, d1):
|
||||
return d0 + d1
|
||||
|
||||
def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
|
||||
return d0 + d1
|
||||
|
||||
p_fn_cmplx = partial(fn_cmplx, d2=2)
|
||||
|
||||
# Prevent modification in-place to support resetting
|
||||
def _dict_update(data, newdata, remove_idx=None):
|
||||
_data = dict(data)
|
||||
|
|
@ -1304,13 +1363,33 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
# Reset
|
||||
self.assertEqual(list(res_dp), list(ref_dp))
|
||||
|
||||
_helper(lambda data: data, fn_n1_def, 'x', 'y')
|
||||
_helper(lambda data: data, p_fn_n1, 'x', 'y')
|
||||
_helper(lambda data: data, p_fn_cmplx, 'x', 'y')
|
||||
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
|
||||
p_fn_cmplx, ["x", "y", "z"], "z")
|
||||
|
||||
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ['x', 'y'], 'z')
|
||||
|
||||
_helper(None, fn_n1_pos, 'x', error=ValueError)
|
||||
_helper(None, fn_n1_kwargs, 'x', error=ValueError)
|
||||
# non-default kw-only args
|
||||
_helper(None, fn_kwonly, ['x', 'y'], error=ValueError)
|
||||
_helper(None, fn_has_nondefault_kwonly, ['x', 'y'], error=ValueError)
|
||||
_helper(None, fn_cmplx, ['x', 'y'], error=ValueError)
|
||||
|
||||
|
||||
# Replacing with one input column and default output column
|
||||
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
|
||||
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
|
||||
# The key of input column is not in dict
|
||||
_helper(None, fn_1n, "a", error=KeyError)
|
||||
# Unmatched input columns with fn arguments
|
||||
_helper(None, fn_n1, "y", error=TypeError)
|
||||
_helper(None, fn_n1, "y", error=ValueError)
|
||||
_helper(None, fn_1n, ["x", "y"], error=ValueError)
|
||||
_helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
|
||||
_helper(None, p_fn_n1, ["x", "y"], error=ValueError)
|
||||
_helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
|
||||
# Replacing with multiple input columns and default output column (the left-most input column)
|
||||
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
|
||||
_helper(lambda data: _dict_update(
|
||||
|
|
@ -1508,6 +1587,32 @@ class TestFunctionalIterDataPipe(TestCase):
|
|||
input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2])
|
||||
self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)])
|
||||
|
||||
# invalid input col
|
||||
with self.assertRaises(ValueError):
|
||||
tuple_input_ds.filter(_mul_filter_fn, input_col=0)
|
||||
|
||||
p_mul_filter_fn = partial(_mul_filter_fn, b=1)
|
||||
out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0)
|
||||
self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
|
||||
|
||||
def _mul_filter_fn_with_defaults(a, b=1):
|
||||
return a + b < 10
|
||||
|
||||
out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0)
|
||||
self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])
|
||||
|
||||
def _mul_filter_fn_with_kw_only(*, a, b):
|
||||
return a + b < 10
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0)
|
||||
|
||||
def _mul_filter_fn_with_kw_only_1_default(*, a, b=1):
|
||||
return a + b < 10
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0)
|
||||
|
||||
# __len__ Test: DataPipe has no valid len
|
||||
with self.assertRaisesRegex(TypeError, r"has no len"):
|
||||
len(filter_dp)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ from torch.utils.data.datapipes._decorator import functional_datapipe
|
|||
from torch.utils.data._utils.collate import default_collate
|
||||
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
|
||||
from torch.utils.data.datapipes.datapipe import IterDataPipe
|
||||
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
|
||||
from torch.utils.data.datapipes.utils.common import (_check_unpickable_fn,
|
||||
validate_input_col)
|
||||
|
||||
__all__ = [
|
||||
"CollatorIterDataPipe",
|
||||
|
|
@ -80,6 +81,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
|
|||
raise ValueError("`output_col` must be a single-element list or tuple")
|
||||
output_col = output_col[0]
|
||||
self.output_col = output_col
|
||||
validate_input_col(fn, input_col)
|
||||
|
||||
def _apply_fn(self, data):
|
||||
if self.input_col is None and self.output_col is None:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from torch.utils.data.datapipes.utils.common import (
|
|||
_check_unpickable_fn,
|
||||
_deprecation_warning,
|
||||
StreamWrapper,
|
||||
validate_input_col
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -69,6 +70,7 @@ class FilterIterDataPipe(IterDataPipe[T_co]):
|
|||
self.drop_empty_batches = drop_empty_batches
|
||||
|
||||
self.input_col = input_col
|
||||
validate_input_col(filter_fn, input_col)
|
||||
|
||||
def _apply_filter_fn(self, data) -> bool:
|
||||
if self.input_col is None:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|||
from torch.utils.data._utils.serialization import DILL_AVAILABLE
|
||||
|
||||
__all__ = [
|
||||
"validate_input_col",
|
||||
"StreamWrapper",
|
||||
"get_file_binaries_from_pathnames",
|
||||
"get_file_pathnames_from_root",
|
||||
|
|
@ -20,6 +21,86 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
    """
    Checks that function used in a callable datapipe works with the input column

    This simply ensures that the number of positional arguments matches the size
    of the input column. The function must not contain any non-default
    keyword-only arguments.

    Examples:
        >>> def f(a, b, *, c=1):
        >>>     return a + b + c
        >>> def f_def(a, b=1, *, c=1):
        >>>     return a + b + c
        >>> validate_input_col(f, [1, 2])      # OK, returns None
        >>> validate_input_col(f_def, 1)       # OK, `b` has a default
        >>> validate_input_col(f_def, [1, 2])  # OK
        >>> validate_input_col(f, 1)           # raises ValueError, `b` is unfilled

    Notes:
        If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments,
        for example, f(a, *args), the validator will accept any size of input column
        greater than or equal to the number of positional arguments.
        (in this case, 1).

        If the signature of ``fn`` cannot be inspected at all (``inspect.signature``
        raises ``ValueError``, which happens for some builtins and C-implemented
        callables), no validation is performed and the function returns silently.

    Args:
        fn: The function to check.
        input_col: The input column to check. A list or tuple contributes one
            argument per element; any other value (including ``None``) is
            counted as a single argument.

    Returns:
        None. The function only signals incompatibility by raising.

    Raises:
        ValueError: If the function is not compatible with the input column.
    """
    try:
        sig = inspect.signature(fn)
    except ValueError:
        # No introspectable signature, so there is nothing to compare against.
        # Letting this ValueError escape would be indistinguishable from a
        # genuine "incompatible function" error, so skip validation instead.
        return

    input_col_size = len(input_col) if isinstance(input_col, (list, tuple)) else 1
    fn_name = str(fn)

    pos = []
    var_positional = False
    non_default_kw_only = []

    for p in sig.parameters.values():
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
            pos.append(p)
        elif p.kind is inspect.Parameter.VAR_POSITIONAL:
            var_positional = True
        elif p.kind is inspect.Parameter.KEYWORD_ONLY and p.default is p.empty:
            non_default_kw_only.append(p)
        # Keyword-only parameters with defaults and **kwargs need no column value.

    # Column values are passed positionally, so a required keyword-only
    # parameter could never be filled.
    if non_default_kw_only:
        raise ValueError(
            f"The function {fn_name} takes {len(non_default_kw_only)} "
            f"non-default keyword-only parameters, which is not allowed."
        )

    if len(sig.parameters) < input_col_size:
        # Fewer parameters of any kind than columns: only *args can absorb the rest.
        if not var_positional:
            raise ValueError(
                f"The function {fn_name} takes {len(sig.parameters)} "
                f"parameters, but {input_col_size} are required."
            )
    elif len(pos) > input_col_size:
        # Extra positional parameters are fine only if each one has a default.
        if any(p.default is p.empty for p in pos[input_col_size:]):
            raise ValueError(
                f"The function {fn_name} takes {len(pos)} "
                f"positional parameters, but {input_col_size} are required."
            )
    elif len(pos) < input_col_size and not var_positional:
        # More columns than positional parameters and no *args to take the surplus.
        raise ValueError(
            f"The function {fn_name} takes {len(pos)} "
            f"positional parameters, but {input_col_size} are required."
        )
|
||||
|
||||
|
||||
def _is_local_fn(fn):
|
||||
# Functions or Methods
|
||||
if hasattr(fn, "__code__"):
|
||||
|
|
@ -33,7 +114,6 @@ def _is_local_fn(fn):
|
|||
return "<locals>" in fn_type.__qualname__
|
||||
return False
|
||||
|
||||
|
||||
def _check_unpickable_fn(fn: Callable):
|
||||
"""
|
||||
Checks function is pickable or not. If it is a lambda or local function, a UserWarning
|
||||
|
|
|
|||
Loading…
Reference in a new issue