diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.cc b/orttraining/orttraining/core/framework/gradient_graph_builder.cc index c9cda92972..d66591318d 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.cc +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.cc @@ -18,6 +18,44 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace training { +namespace { + +const std::unordered_set<int32_t> GRAD_ALLOWED_TYPES{ + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, + ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, +}; + +std::tuple<bool, bool, int32_t> IsAllowedForGradient(const Graph* graph, const NodeArg* node_arg) { + // This is a temporary workaround for Send nodes, who output a bool. If we don't allow building gradient for + // it, the PipelineTransformer tests will fail. We should revisit the PipelineTransformer and fix it in a + // better way. (While it might not be a priority for us now.) 
+ bool skip_elem_type_check = false; + const Node* node = graph->GetProducerNode(node_arg->Name()); + if (node && (node->OpType() == "Send")) { + skip_elem_type_check = true; + } + + bool is_tensor_type = false; + bool is_allowed_type_for_grad = false; + const auto* type_proto = node_arg->TypeAsProto(); + int32_t type = -1; + if (nullptr != type_proto && type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { + is_tensor_type = true; + type = type_proto->tensor_type().elem_type(); + + if (skip_elem_type_check) { + is_allowed_type_for_grad = true; + } else { + is_allowed_type_for_grad = GRAD_ALLOWED_TYPES.find(type) != GRAD_ALLOWED_TYPES.end(); + } + } + + return {is_tensor_type, is_allowed_type_for_grad, type}; +} +} // namespace + using namespace common; GradientGraphBuilder::GradientGraphBuilder(Graph* graph, @@ -57,10 +95,13 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph, const Node* node = graph_->GetProducerNode(name); if (node) { - if (forward_reachable_nodes.find(node) == forward_reachable_nodes.end()) { + const auto rets = IsAllowedForGradient(graph_, node_arg); + bool is_allowed_type_for_grad = std::get<1>(rets); + if (forward_reachable_nodes.find(node) == forward_reachable_nodes.end() || !is_allowed_type_for_grad) { non_differentiable_y_node_arg_names_.insert(name); LOGS(logger_, INFO) << "The model weights and inputs are non-differentiable from " << name << ". 
" - << "ORT will assume no gradient will be provided for " << name << "."; + << "ORT will assume no gradient will be provided for " << name + << ", is_allowed_type_for_grad: " << is_allowed_type_for_grad; } else { y_node_args_.insert(node_arg); y_nodes_.insert(node); @@ -169,17 +210,16 @@ NodeSet GradientGraphBuilder::ReverseBFSWithStopGradient(const NodeSet& nodes) c continue; } const NodeArg* node_arg = n->InputDefs()[edge_it->GetDstArgIndex()]; - const auto* type_proto = node_arg->TypeAsProto(); - if (nullptr != type_proto && type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kTensorType) { - const int32_t type = type_proto->tensor_type().elem_type(); - if (GRAD_ALLOWED_TYPES.find(type) == GRAD_ALLOWED_TYPES.end()) { + const auto [is_tensor_type, is_allowed_type_for_grad, type] = IsAllowedForGradient(graph_, node_arg); + if (is_tensor_type) { + if (!is_allowed_type_for_grad) { LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name() << "because element type is: " << type; + << " of node: " << n->Name() << " because element type is: " << type; continue; } } else { LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() - << " of node: " << n->Name() << "because it is not a Tensor type"; + << " of node: " << n->Name() << " because it is not a Tensor type"; continue; } @@ -286,6 +326,21 @@ Status GradientGraphBuilder::Build(const std::unordered_set* p_init } const NodeArg* node_arg = node->OutputDefs()[edge_it->GetSrcArgIndex()]; + + // Make sure node_arg as input of next_node, has the data type allowed to compute gradient. 
+ const auto [is_tensor_type, is_allowed_type_for_grad, type] = IsAllowedForGradient(graph_, node_arg); + if (is_tensor_type) { + if (!is_allowed_type_for_grad) { + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << next_node.Name() << " because element type is: " << type; + continue; + } + } else { + LOGS(logger_, VERBOSE) << "Skip building gradient for input_" << edge_it->GetDstArgIndex() + << " of node: " << next_node.Name() << " because it is not a Tensor type"; + continue; + } + std::string grad_node_arg_name = GradientBuilderBase::GradientName(node_arg->Name()); pending_[grad_node_arg_name] += 1; diff --git a/orttraining/orttraining/core/framework/gradient_graph_builder.h b/orttraining/orttraining/core/framework/gradient_graph_builder.h index 1552ad8dc5..8068d4825c 100644 --- a/orttraining/orttraining/core/framework/gradient_graph_builder.h +++ b/orttraining/orttraining/core/framework/gradient_graph_builder.h @@ -150,19 +150,13 @@ class GradientGraphBuilder { // The 1st and 3rd inputs are not differentiable. std::unordered_map<std::string, std::unordered_set<size_t>> python_op_input_require_grad_info_; - const std::unordered_set<int32_t> GRAD_ALLOWED_TYPES{ - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, - ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, - }; const std::unordered_set<size_t>* GetStopGradientEdges(const Node& node) const; /** Performs a BFS on the graph with STOP_GRADIENT_EDGES constrain It will skip traversing over the edges defined in STOP_GRADIENT_EDGES map. 
The resulting node set contains all the nodes that are differentiable wrt the x_node_args - @param Starting nodes arg name for BFS + @param x_node_arg_names Starting nodes arg name for BFS @returns All the nodes visited during BFS */ NodeSet BFSWithStopGradient(const std::unordered_set<std::string>& x_node_arg_names) const; @@ -171,14 +165,13 @@ Performs a ReverseBFS on the graph with STOP_GRADIENT_EDGES constrain It will skip traversing over the edges defined in STOP_GRADIENT_EDGES map. The resulting node set contains all the nodes that are differentiable wrt the input nodes - @param Starting nodes for ReverseBFS + @param nodes Starting nodes for ReverseBFS @returns All the nodes visited during ReverseBFS */ NodeSet ReverseBFSWithStopGradient(const NodeSet& nodes) const; /** - Check if 'x_node_args_' are reachable from 'y_node_args_' for computing the partial derivative - @param reachable_nodes All the nodes reachable from the 'y_node_args_' + Check if 'x_node_args_' are reachable from 'y_node_args_' for computing the partial derivative. 
@returns OK if all 'x_node_args_' are reachable, else an ONNXRUNTIME INVALID_ARGUMENT status */ Status CheckNodeArgsReachable() const; diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index bfc6e4a5bb..c5948e563f 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -309,7 +309,9 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() { if (std::find(non_differentiable_indices.begin(), non_differentiable_indices.end(), i) == non_differentiable_indices.end()) { - yield_output_node_args.emplace_back(gradient_graph.GetNodeArg(grad_name)); + NodeArg* grad_node_arg = gradient_graph.GetNodeArg(grad_name); + ORT_ENFORCE(grad_node_arg != nullptr, "Differentiable param grad node arg should exist."); + yield_output_node_args.emplace_back(grad_node_arg); graph_info_.module_output_gradient_name.emplace_back(grad_name); } } @@ -335,9 +337,10 @@ void OrtModuleGraphBuilder::HandleOutputsAndGrads() { attributes.insert({full_shape_outputs_name, full_shape_outputs}); - // Handle potential duplciated output_gradient names + // Handle potential duplicated output_gradient names std::unordered_map<std::string, std::vector<size_t>> name_to_idx; for (size_t i = 0; i < yield_output_node_args.size(); ++i) { + ORT_ENFORCE(yield_output_node_args[i] != nullptr); const std::string& name = yield_output_node_args[i]->Name(); auto it = name_to_idx.find(name); if (it == name_to_idx.end()) {