diff --git a/test/edge/custom_ops.yaml b/test/edge/custom_ops.yaml index b85fd12bd32..2ff2db88f97 100644 --- a/test/edge/custom_ops.yaml +++ b/test/edge/custom_ops.yaml @@ -1,3 +1,4 @@ - func: custom::add_3.out(Tensor a, Tensor b, Tensor c, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - CPU: custom::add_3_out + kernels: + - arg_meta: null + kernel_name: custom::add_3_out diff --git a/torchgen/executorch/parse.py b/torchgen/executorch/parse.py index b7ae5b3b6df..f6f30b4554a 100644 --- a/torchgen/executorch/parse.py +++ b/torchgen/executorch/parse.py @@ -130,9 +130,7 @@ def parse_et_yaml( et_kernel = extract_kernel_fields(es) # Remove ET specific fields from entries for BC compatibility - for entry in es: - for field in ET_FIELDS: - entry.pop(field, None) + strip_et_fields(es) native_yaml = parse_native_yaml( path, @@ -142,3 +140,12 @@ def parse_et_yaml( loaded_yaml=es, ) return native_yaml.native_functions, et_kernel + + +def strip_et_fields(es: object) -> None: + """Given a loaded yaml representing a list of operators, + remove ET specific fields from every entries for BC compatibility + """ + for entry in es: # type: ignore[attr-defined] + for field in ET_FIELDS: + entry.pop(field, None)