diff --git a/test/test_foreach.py b/test/test_foreach.py
index 784cb09d94c..91eb860fe38 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -959,59 +959,6 @@ class TestForeach(TestCase):
             sample.args = new_args
             _test(func, sample)
 
-    @ops(
-        foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
-        dtypes=OpDTypes.supported,
-        allowed_dtypes=(torch.float64, torch.complex128),
-    )
-    def test_outplace_forward_mode_AD(self, device, dtype, op):
-        if not op.supports_forward_ad:
-            self.skipTest("forward AD not supported")
-
-        # note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
-        if dtype == torch.float64 and op.name in (
-            "_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
-            "_foreach_log", "_foreach_pow", "_foreach_sqrt",
-        ):
-            value_range = {"low": 0.5, "high": 1.0}
-        else:
-            value_range = {}
-        for sample in op.sample_inputs(
-            device, dtype, requires_grad=True, num_input_tenosrs=[5], same_size=True, **value_range,
-        ):
-            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
-            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
-                continue
-
-            def func(*tensorlist):
-                kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
-                return op.method_variant(tensorlist, *sample.args, **kwargs)
-
-            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, False)
-            if not working_sample:
-                if not err_msg_pattern:
-                    # lhs of float64 and rhs of complex.
-                    continue
-                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
-                    gradcheck(
-                        func,
-                        sample.input,
-                        raise_exception=True,
-                        check_forward_ad=True,
-                        check_batched_forward_grad=False,
-                        check_backward_ad=False,
-                        check_batched_grad=False,
-                    )
-            else:
-                gradcheck(
-                    func,
-                    sample.input,
-                    raise_exception=True,
-                    check_forward_ad=True,
-                    check_backward_ad=False,
-                    check_batched_grad=False,
-                )
-
     @ops(
         foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_lerp_op_db,
         dtypes=OpDTypes.supported,
@@ -1021,6 +968,36 @@ class TestForeach(TestCase):
         if not op.supports_forward_ad:
             self.skipTest("forward AD not supported")
 
+        # note(crcrpar): The combinations below are failing in their forward path,
+        # which is before forward-mode AD happens.
+        # This function gates the combinations where
+        # - subtraction with a Scalar/ScalarList of boolean value
+        # - combinations where the in-place op in question tries to write out a complex result
+        #   into float storage (= `self`)
+        def check_sample_eligibility(op, sample, dtype):
+            if (
+                op.name == "_foreach_sub"
+                and (
+                    (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
+                    or isinstance(sample.args[0], bool)
+                )
+            ):
+                return False, _BOOL_SUB_ERR_MSG
+            rhs_arg_has_complex_number = sample.args and ((
+                isinstance(sample.args[0], list)
+                and any(isinstance(a, complex) for a in sample.args[0])
+            ) or (
+                isinstance(sample.args[0], complex)
+            ))
+            if dtype == torch.float64 and rhs_arg_has_complex_number:
+                if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
+                    return False, "result type ComplexDouble can't be cast to the desired output type Double"
+                if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
+                    return False, "clamp is not supported for complex types"
+                if op.name == "_foreach_pow":
+                    return False, "Found dtype Double but expected ComplexDouble"
+
+            return True, ""
+
         for sample in op.sample_inputs(
             device, dtype, requires_grad=True, num_input_tensors=[5], same_size=True,
         ):
@@ -1031,7 +1008,7 @@ class TestForeach(TestCase):
                 op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
                 return tensorlist
 
-            working_sample, err_msg_pattern = check_forward_mode_AD_sample(op, sample, dtype, True)
+            working_sample, err_msg_pattern = check_sample_eligibility(op, sample, dtype)
             if not working_sample:
                 with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                     gradcheck(
@@ -1083,38 +1060,6 @@ class TestForeach(TestCase):
         self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
 
-
-# TODO(crcrpar): Hide this inside torch/testing/_internal.
-# would end up adding another layer to `foreach_inputs_sample_func.__call__`
-# so that we can use this function as something like the first argument of `filter` function.
-# Even after moving this function to testing, I personally think it'd be better to check the error message.
-def check_forward_mode_AD_sample(op, sample, dtype, is_inplace):
-    if (
-        op.name == "_foreach_sub"
-        and (
-            (isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
-            or isinstance(sample.args[0], bool)
-        )
-    ):
-        return False, _BOOL_SUB_ERR_MSG
-    rhs_arg_has_complex_number = sample.args and ((
-        isinstance(sample.args[0], list)
-        and any(isinstance(a, complex) for a in sample.args[0])
-    ) or (
-        isinstance(sample.args[0], complex)
-    ))
-    if rhs_arg_has_complex_number and dtype == torch.float64:
-        if op.name in ("_foreach_clamp_max", "_foreach_clamp_min"):
-            return False, "clamp is not supported for complex types"
-        if not is_inplace:
-            return False, ""
-        else:
-            if op.name == "_foreach_pow":
-                return False, "Found dtype Double but expected ComplexDouble"
-            if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
-                return False, "result type ComplexDouble can't be cast to the desired output type Double"
-    return True, ""
-
-
 instantiate_device_type_tests(TestForeach, globals())
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 4eccc8fb002..ab7152fe718 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -3040,16 +3040,13 @@
 - name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
   self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type())
   other: div_tensor_other_backward(grads[i], self[i], other[i])
-  result: (self_t - other_t * result[i]) / other_p
 
 - name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
   self: pow_backward_self(grads[i], self[i], exponent[i])
   exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i])
-  result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj()
 
 - name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
   self: pow_backward(grads[i], self[i], exponent[i])
-  result: pow_backward(self_t.conj(), self_p, exponent[i]).conj()
 
 - name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
   exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i])
@@ -3058,22 +3055,12 @@
 # of `maximum` and `minimum` don't have the overload def with Scalar as their second argument.
 - name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
   self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0)
-  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar)
 
 - name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
   self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0)
-  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
 
 - name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
   self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0)
-  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar)
 
 - name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
   self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
-  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
-
-# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
-# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
-- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[]
-  self: norm_backward(grads[i], self[i], ord, result[i])
-  result: norm_jvp(self_p, self_t, ord, result[i])
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index f1a85310b99..5fb93e12d67 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -1769,7 +1769,7 @@ def emit_body(
     def emit_any_has_forward_grad() -> List[str]:
         content: List[str] = []
-        if not is_foreach:
+        if not is_inplace_foreach:
             for derivative in fw_derivatives:
                 requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
                 if info and info.output_differentiability_conditions:
@@ -1783,17 +1783,11 @@ def emit_body(
                 bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
                 cur_derivative_conditions = [
                     FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
-                        req_inp=(
-                            inp.name
-                            if not inplace
-                            else refargname2inplace_foreacharg[inp.name].name
-                        )
+                        req_inp=refargname2inplace_foreacharg[inp.name].name
                         + (
                             "[i]"
                             if is_tensor_list_type(
-                                inp.type
-                                if not inplace
-                                else refargname2inplace_foreacharg[inp.name].type
+                                refargname2inplace_foreacharg[inp.name].type
                             )
                             else ""
                         ),
@@ -1835,10 +1829,8 @@ def emit_body(
         unpacked_arguments = ""
         for inp in differentiable_inputs:
             inp_name = inp.name
-            is_input_tensorlist = is_foreach and is_tensor_list_type(
-                inp.type
-                if not inplace
-                else refargname2inplace_foreacharg[inp.name].type
+            is_input_tensorlist = is_inplace_foreach and is_tensor_list_type(
+                refargname2inplace_foreacharg[inp.name].type
             )
             input_suffix = "[i]" if is_input_tensorlist else ""
             if is_inplace_foreach:
@@ -1895,14 +1887,14 @@ def emit_body(
             # Is there a way to get from BaseType to BaseCType
             if len(derivative.var_types) == 1:
                 opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
-                if not is_foreach:
+                if not is_inplace_foreach:
                     fw_grad_setters.append(
                         FW_DERIVATIVE_SETTER_TENSOR.substitute(
                             out_arg=res[0], is_inplace=is_inplace_str
                         )
                     )
                 else:
-                    assert res[0] == ("result" if not inplace else "self")
+                    assert res[0] == "self"
                     fw_grad_setters.append(
                         FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
                             out_arg=res[0], is_inplace=is_inplace_str
@@ -1927,29 +1919,18 @@ def emit_body(
                 assert (
                     len(derivative.var_types) == 1
                 ), "Expected number of outputs to be 1 if function returns ListType"
-                if not is_foreach:
-                    opt_res_grad_type = OptionalCType(
-                        VectorCType(BaseCType(tensorT))
-                    ).cpp_type()
-                    fw_grad_setters.append(
-                        FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
-                            out_arg=res[0], is_inplace=is_inplace_str
-                        )
-                    )
-                else:
-                    # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
-                    # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
-                    # can reach here.
-                    opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
-                    fw_grad_setters.append(
-                        FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
-                            out_arg=res[0], is_inplace=is_inplace_str
-                        )
+                opt_res_grad_type = OptionalCType(
+                    VectorCType(BaseCType(tensorT))
+                ).cpp_type()
+                fw_grad_setters.append(
+                    FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
+                        out_arg=res[0], is_inplace=is_inplace_str
                     )
+                )
             else:
                 raise RuntimeError("Unsupported output type for forward derivative")
-        if not is_foreach:
+        if not is_inplace_foreach:
             fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = c10::nullopt;"
             # View ops create fw_grad that already is a view of the base's fw_grad so just use that
             content.append(
@@ -1967,30 +1948,20 @@ def emit_body(
                     f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
                     "(self.size(), c10::nullopt);"
                 )
-                foreach_forward_grad_formula = derivative.formula
-                _foreach_arg: Union[Argument, DifferentiableInput]
-                if inplace:
-                    for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
-                        # note(crcrpar): Massage only Scalar and ArrayRef here.
-                        if not (
-                            is_tensor_type(_foreach_arg.type)
-                            or is_tensor_list_type(_foreach_arg.type)
-                        ):
-                            pattern = _foreach_arg.name
-                            if isinstance(_foreach_arg.type, ListType):
-                                pattern += "[i]"
-                            foreach_forward_grad_formula = (
-                                foreach_forward_grad_formula.replace(
-                                    _ref_arg.name, pattern
-                                )
-                            )
-                else:
-                    if (
-                        "result" in foreach_forward_grad_formula
-                        and "result[i]" not in foreach_forward_grad_formula
+                inplace_foreach_forward_grad_formula = derivative.formula
+                for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
+                    # note(crcrpar): Massage only Scalar and ArrayRef here.
+                    if not (
+                        is_tensor_type(_foreach_arg.type)
+                        or is_tensor_list_type(_foreach_arg.type)
                     ):
-                        foreach_forward_grad_formula = (
-                            foreach_forward_grad_formula.replace("result", "result[i]")
+                        pattern = _foreach_arg.name
+                        if isinstance(_foreach_arg.type, ListType):
+                            pattern += "[i]"
+                        inplace_foreach_forward_grad_formula = (
+                            inplace_foreach_forward_grad_formula.replace(
+                                _ref_arg.name, pattern
+                            )
                         )
 
                 content.append(
@@ -2001,7 +1972,7 @@ def emit_body(
                         get_any_has_forward_grad_name(derivative.var_names) + "[i]"
                         for derivative in fw_derivatives
                     ),
-                    formula=foreach_forward_grad_formula,
+                    formula=inplace_foreach_forward_grad_formula,
                     unpacked_arguments=unpacked_arguments,
                 )
             )
@@ -2063,11 +2034,7 @@ def emit_body(
         else:
             any_has_fw_grad = " || ".join(
                 [
-                    (
-                        FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
-                        if is_tensor_list_type(inp.type)
-                        else FW_DERIVATIVE_CHECK_TEMPLATE
-                    ).substitute(req_inp=inp.name)
+                    FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
                     for inp in differentiable_inputs
                     if inp.name in derivative.required_inputs_fw_grad
                 ]
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 18ebbc73df6..5bba7a3baf7 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -272,13 +272,9 @@ def postprocess_forward_derivatives(
     args_with_derivatives: Sequence[Binding],
 ) -> List[ForwardDerivative]:
     def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
-        is_foreach = f.func.name.name.base.startswith("_foreach_")
         required_inputs = set()
         for arg in args_with_derivatives:
-            if (
-                arg.type in ("at::TensorList", "const at::ITensorListRef &")
-                and not is_foreach
-            ):
+            if arg.type in ("at::TensorList", "const at::ITensorListRef &"):
                 # The functions taking TensorList handle everything internally
                 continue
             arg_name = arg.name
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index fa8c54488c9..ca22a0b6528 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8779,15 +8779,6 @@ foreach_unary_op_db: List[OpInfo] = [
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
         supports_autograd=True,
         supports_forward_ad=True,
-        skips=(
-            # note(crcrpar): excluding cdouble from dtypes above might be better.
-            # Guard for `error: "In-place abs is not supported for complex tensors."`
-            DecorateInfo(
-                unittest.skip("_foreach_zero is not implemented"),
-                'TestForeach',
-                'test_outplace_forward_mode_AD',
-            ),
-        ),
     ),
 ]
diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py
index 4ffb75820bf..a9cf148b77b 100644
--- a/torchgen/api/autograd.py
+++ b/torchgen/api/autograd.py
@@ -462,82 +462,13 @@ def gen_foreach_derivativeinfo(
         for arg in foreach_function.func.arguments.flat_non_out
         if arg.name in all_var_names
     ]
-
-    forward_derivatives: List[ForwardDerivative] = []
-    fw_derivative: ForwardDerivative
-    for fw_derivative in ref_diff_info.forward_derivatives:
-        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
-        var_types: List[Type] = list(fw_derivative.var_types)
-        required_inputs_fw_grad: List[str] = []
-        required_inputs_primal: List[str] = []
-        if fw_derivative.required_inputs_fw_grad is not None:
-            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
-        if fw_derivative.required_inputs_primal:
-            required_inputs_primal = list(fw_derivative.required_inputs_primal)
-        modified_formula = fw_derivative.formula
-
-        # Foreach's result is TensorList
-        if "result" in modified_formula:
-            modified_formula = fw_derivative.formula.replace("result", "result[i]")
-
-        for foreach_arg, ref_arg in zip(
-            foreach_function.func.arguments.flat_non_out,
-            ref_diff_info.func.func.arguments.flat_non_out,
-        ):
-            # Modify reference forward formula
-            if (
-                isinstance(foreach_arg.type, ListType)
-                and not foreach_arg.type.is_tensor_like()
-            ):
-                # Assuming ScalarList
-                modified_formula = modified_formula.replace(
-                    ref_arg.name, foreach_arg.name + "[i]"
-                )
-            elif foreach_arg.type.is_tensor_like():
-                # Assuming TensorList / Tensor
-                assert isinstance(foreach_arg.type, ListType)
-                for suffix in ("_p", "_t"):
-                    curr_expr = ref_arg.name + suffix
-                    if curr_expr in modified_formula:
-                        new_expr = foreach_arg.name + suffix
-                        modified_formula = modified_formula.replace(curr_expr, new_expr)
-            else:
-                # Assuming Scalar
-                if foreach_arg.name != ref_arg.name:
-                    modified_formula = modified_formula.replace(
-                        ref_arg.name, foreach_arg.name
-                    )
-
-            # note(crcrpar): there should exist a cooler way...
-            for i, name in enumerate(var_names):
-                if name == ref_arg.name:
-                    var_names[i] = foreach_arg.name
-                    var_types[i] = foreach_arg.type
-            for i, name in enumerate(required_inputs_fw_grad):
-                if name == ref_arg.name:
-                    required_inputs_fw_grad[i] = foreach_arg.name
-            for i, name in enumerate(required_inputs_primal):
-                if name == ref_arg.name:
-                    required_inputs_primal[i] = foreach_arg.name
-        forward_derivatives.append(
-            ForwardDerivative(
-                formula=modified_formula,
-                var_names=tuple(var_names),
-                var_types=tuple(var_types),
-                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
-                required_inputs_primal=tuple(required_inputs_primal),
-                required_original_self_value=fw_derivative.required_original_self_value,
-                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
-            )
-        )
-
     return (
         DifferentiabilityInfo(
             name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
-            forward_derivatives=forward_derivatives,
+            forward_derivatives=[],
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
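For reference, the two failure modes that the retained check_sample_eligibility helper gates in test_inplace_forward_mode_AD come from the plain forward computation, before any forward-mode AD machinery runs. A minimal sketch of both, assuming CPU float64 inputs; the error strings are the ones the test matches via assertRaisesRegex and may vary slightly across builds:

    # Illustrative only; not part of the patch above. Assumes float64 CPU tensors.
    import torch

    tensors = [torch.rand(3, dtype=torch.float64) for _ in range(5)]

    # 1) `_foreach_sub` with a boolean Scalar/ScalarList is rejected outright
    #    (matched against _BOOL_SUB_ERR_MSG in the test).
    try:
        torch._foreach_sub(tensors, True)
    except RuntimeError as e:
        print(e)  # Subtraction, the `-` operator, with a bool tensor is not supported...

    # 2) An in-place foreach op cannot write a complex result into float64 storage (`self`).
    try:
        torch._foreach_add_(tensors, [complex(0, 1)] * len(tensors))
    except RuntimeError as e:
        print(e)  # result type ComplexDouble can't be cast to the desired output type Double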