diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index ae18052a3..4ca69ad3f 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -94,8 +94,8 @@ class HfArgumentParser(ArgumentParser): field.type = prim_type if isinstance(field.type, type) and issubclass(field.type, Enum): - kwargs["choices"] = list(field.type) - kwargs["type"] = field.type + kwargs["choices"] = [x.value for x in field.type] + kwargs["type"] = type(kwargs["choices"][0]) if field.default is not dataclasses.MISSING: kwargs["default"] = field.default elif field.type is bool or field.type is Optional[bool]: @@ -198,7 +198,7 @@ class HfArgumentParser(ArgumentParser): data = json.loads(Path(json_file).read_text()) outputs = [] for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype)} + keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in data.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) @@ -211,7 +211,7 @@ class HfArgumentParser(ArgumentParser): """ outputs = [] for dtype in self.dataclass_types: - keys = {f.name for f in dataclasses.fields(dtype)} + keys = {f.name for f in dataclasses.fields(dtype) if f.init} inputs = {k: v for k, v in args.items() if k in keys} obj = dtype(**inputs) outputs.append(obj) diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py index db937de93..22493a23b 100644 --- a/tests/test_hf_argparser.py +++ b/tests/test_hf_argparser.py @@ -55,7 +55,10 @@ class BasicEnum(Enum): @dataclass class EnumExample: - foo: BasicEnum = BasicEnum.toto + foo: BasicEnum = "toto" + + def __post_init__(self): + self.foo = BasicEnum(self.foo) @dataclass @@ -133,14 +136,18 @@ class HfArgumentParserTest(unittest.TestCase): parser = HfArgumentParser(EnumExample) expected = argparse.ArgumentParser() - expected.add_argument("--foo", default=BasicEnum.toto, choices=list(BasicEnum), type=BasicEnum) + expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str) self.argparsersEqual(parser, expected) args = parser.parse_args([]) - self.assertEqual(args.foo, BasicEnum.toto) + self.assertEqual(args.foo, "toto") + enum_ex = parser.parse_args_into_dataclasses([])[0] + self.assertEqual(enum_ex.foo, BasicEnum.toto) args = parser.parse_args(["--foo", "titi"]) - self.assertEqual(args.foo, BasicEnum.titi) + self.assertEqual(args.foo, "titi") + enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0] + self.assertEqual(enum_ex.foo, BasicEnum.titi) def test_with_list(self): parser = HfArgumentParser(ListExample)