diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc index 96c6bfa92c..6689cd10cc 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.cc @@ -39,12 +39,10 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream, graph_info_.user_output_names.emplace_back(node_arg->Name()); } - graph_info_.initializer_names_to_train = std::unordered_set( - config_.initializer_names_to_train.begin(), - config_.initializer_names_to_train.end()); - graph_info_.initializer_names = std::unordered_set( - config_.initializer_names.begin(), - config_.initializer_names.end()); + graph_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(), + config.initializer_names_to_train.end()); + graph_info_.initializer_names.assign(config.initializer_names.begin(), + config.initializer_names.end()); std::vector input_args; for (const auto& input_name : graph_info_.user_input_names) { @@ -320,7 +318,7 @@ void OrtModuleGraphBuilder::ReorderOutputs() { std::string initializer_gradient_name = GradientBuilderBase::GradientName(initializer_name); ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(), "Trainable initializer grad is not found on gradient graph."); - graph_info_.initializer_grad_names_to_train.emplace(initializer_gradient_name); + graph_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name); new_output_args.emplace_back(gradient_output_arg_map[initializer_gradient_name]); } diff --git a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h index 139ef468f9..9925be3a20 100644 --- a/orttraining/orttraining/core/framework/ortmodule_graph_builder.h +++ b/orttraining/orttraining/core/framework/ortmodule_graph_builder.h @@ -44,11 +44,11 @@ struct GraphInfo { // Map from user input names to corresponding user input grad names for those user inputs that require grad. std::unordered_map user_input_grad_names{}; // All initializers (trainable as well as non trainable). - std::unordered_set initializer_names{}; + std::vector initializer_names{}; // Trainable initializers. - std::unordered_set initializer_names_to_train{}; + std::vector initializer_names_to_train{}; // Trainable initializer grad names, ordered according to initializer_names_to_train. - std::unordered_set initializer_grad_names_to_train{}; + std::vector initializer_grad_names_to_train{}; // The user outputs. std::vector user_output_names{}; // Indices of output grads that are non-differentiable. diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 0202e46b1c..6624d46df3 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -52,6 +52,8 @@ class GraphExecutionManager(ABC): self._optimized_onnx_model = None self._graph_builder = None self._graph_info = None + self._graph_initializer_names = None + self._graph_initializer_names_to_train = None # TrainingAgent or InferenceAgent self._execution_agent = None @@ -156,6 +158,11 @@ class GraphExecutionManager(ABC): self._optimized_onnx_model = onnx.load_model_from_string(self._graph_builder.get_model()) self._graph_info = self._graph_builder.get_graph_info() + # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train + # a set (unordered_set in the backend) that does not require a copy on each reference. + self._graph_initializer_names = set(self._graph_info.initializer_names) + self._graph_initializer_names_to_train = set(self._graph_info.initializer_names_to_train) + def _get_session_config(self): """Creates and returns the session configuration to be used for the ExecutionAgent""" providers = None diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index 5359e05b6a..3ca6514248 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -90,7 +90,7 @@ class InferenceManager(GraphExecutionManager): self._device, *_io._combine_input_buffers_initializers( [param for name, param in self._flattened_module.named_parameters() - if name in self._graph_info.initializer_names], + if name in self._graph_initializer_names], self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(), diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 44adc5628f..848340b8fa 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -174,8 +174,8 @@ class TrainingManager(GraphExecutionManager): # Append gradients of initializer to results # Go over each initializer, check if it required grad and append to results accordingly initializer_index = num_user_input_grads - for initializer_name, _ in self._flattened_module.named_parameters(): - if initializer_name in self._graph_info.initializer_names_to_train: + for initializer_name in self._graph_info.initializer_names: + if initializer_name in self._graph_initializer_names_to_train: results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[initializer_index])) initializer_index += 1 else: @@ -188,7 +188,7 @@ class TrainingManager(GraphExecutionManager): _ORTModuleFunction.apply( *_io._combine_input_buffers_initializers( [param for name, param in self._flattened_module.named_parameters() - if name in self._graph_info.initializer_names], + if name in self._graph_initializer_names], self._graph_info.user_input_names, self._input_info, self._flattened_module.named_buffers(), @@ -236,13 +236,11 @@ class TrainingManager(GraphExecutionManager): initializer_names_to_train_set_user_model = {name for name, param in self._flattened_module.named_parameters() if param.requires_grad} - initializer_names_to_train_set_onnx_graph = self._graph_info.initializer_names_to_train \ - if self._graph_info else None # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad if input_info.require_grad_names != self._input_info.require_grad_names or \ - initializer_names_to_train_set_user_model != initializer_names_to_train_set_onnx_graph: + initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train: self._input_info = input_info self._initialize_graph_builder(training=True) return True diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 78c0876549..0418af4c14 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -197,6 +197,59 @@ class NeuralNetPartialNoGradModel(torch.nn.Module): out = self.fc2(out) return out +class UnusedEndParameterNet(torch.nn.Module): + def __init__(self, input_size, hidden_size1, hidden_size2, num_classes): + super(UnusedEndParameterNet, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size1) + self.relu = torch.nn.ReLU() + # fc2 is an unused initializer (which is in the end of initializer list) + # which will be dropped after export + self.fc2 = torch.nn.Linear(hidden_size1, hidden_size2) + self.register_buffer("buffer", torch.ones(hidden_size1)) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = out + self.buffer + return out + +class UnusedBeginParameterNet(torch.nn.Module): + def __init__(self, input_size, hidden_size1, hidden_size2, num_classes): + super(UnusedBeginParameterNet, self).__init__() + + # fc1 is an unused initializer (which is in the begining of initializer list) + # which will be dropped after export + self.fc1 = torch.nn.Linear(input_size, hidden_size1) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(input_size, hidden_size2) + self.register_buffer("buffer", torch.ones(hidden_size2)) + + def forward(self, input1): + out = self.fc2(input1) + out = self.relu(out) + out = out + self.buffer + return out + +class UnusedMiddleParameterNet(torch.nn.Module): + def __init__(self, input_size, hidden_size1, hidden_size2, num_classes): + super(UnusedMiddleParameterNet, self).__init__() + + self.fc1 = torch.nn.Linear(input_size, hidden_size1) + self.relu = torch.nn.ReLU() + # fc2 is an unused initializer (which is in the middle of initializer list) + # which will be dropped after export + self.fc2 = torch.nn.Linear(hidden_size1, hidden_size2) + self.fc3 = torch.nn.Linear(hidden_size1, num_classes) + self.register_buffer("buffer", torch.ones(num_classes)) + + def forward(self, input1): + out = self.fc1(input1) + out = self.relu(out) + out = self.fc3(out) + out = out + self.buffer + return out + # TODO: This is a workaround for the problem that pytest is still cleaning up the previous test # while the next task already start. @pytest.fixture(autouse=True) @@ -2393,27 +2446,15 @@ def test_model_with_registered_buffer_and_dropped_parameters(): # Ensure that no exceptions are raised out = model(bool_argument, x) -def test_unused_parameters(): - class UnusedParameterNet(torch.nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super(UnusedParameterNet, self).__init__() - - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.relu = torch.nn.ReLU() - # fc2 is an unused initializer which will be dropped after export - self.fc2 = torch.nn.Linear(hidden_size, num_classes) - self.register_buffer("buffer", torch.ones(hidden_size)) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = out + self.buffer - return out - +@pytest.mark.parametrize("model, none_pt_params", + [(UnusedBeginParameterNet(784, 500, 400, 10), ['fc1.weight', 'fc1.bias']), + (UnusedMiddleParameterNet(784, 500, 400, 10), ['fc2.weight', 'fc2.bias']), + (UnusedEndParameterNet(784, 500, 400, 10), ['fc2.weight', 'fc2.bias'])]) +def test_unused_parameters(model, none_pt_params): device = 'cuda' - N, D_in, H, D_out = 64, 784, 500, 10 - model = UnusedParameterNet(D_in, H, D_out).to(device) + N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 + model = model.to(device) ort_model = ORTModule(copy.deepcopy(model)) # Make sure model runs without any exception @@ -2429,7 +2470,7 @@ def test_unused_parameters(): loss_ort.backward() _test_helpers.assert_values_are_close(out_ort, out_pt) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, - none_pt_params=['fc2.weight', 'fc2.bias']) + none_pt_params=none_pt_params) # Also try in eval mode model.eval()