Support for unused model initializers (#7631)

* Support for unused model initializers

* Change graph_info.initializer* to sets
baijumeswani 2021-05-11 12:26:56 -07:00 committed by GitHub
parent 88d2fc8f1e
commit c5aeaa9419
8 changed files with 83 additions and 24 deletions
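
In rough terms, the change handles models like the hypothetical one below, where a parameter is declared but never used in `forward`. The PyTorch exporter drops the unused parameter from the ONNX graph, so the module's `named_parameters()` no longer matches the exported graph's inputs; before this commit ORTModule assumed the two were identical. A minimal sketch of the mismatch (illustrative only, not code from this commit):

```python
import onnx
import torch

class UnusedParamNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 4)  # declared but never used in forward

    def forward(self, x):
        return self.fc1(x)

model = UnusedParamNet()
# Export the way ORTModule does, keeping initializers as graph inputs.
torch.onnx.export(model, torch.randn(1, 4), "unused_param.onnx",
                  keep_initializers_as_inputs=True)

graph_input_names = {inp.name for inp in onnx.load("unused_param.onnx").graph.input}
param_names = {name for name, _ in model.named_parameters()}
# fc2.weight and fc2.bias exist on the module but were pruned from the graph.
print(param_names - graph_input_names)
```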


@@ -39,10 +39,12 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream,
     graph_info_.user_output_names.emplace_back(node_arg->Name());
   }
 
-  graph_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
-                                                config.initializer_names_to_train.end());
-  graph_info_.initializer_names.assign(config.initializer_names.begin(),
-                                       config.initializer_names.end());
+  graph_info_.initializer_names_to_train = std::unordered_set<std::string>(
+      config_.initializer_names_to_train.begin(),
+      config_.initializer_names_to_train.end());
+  graph_info_.initializer_names = std::unordered_set<std::string>(
+      config_.initializer_names.begin(),
+      config_.initializer_names.end());
 
   std::vector<const NodeArg*> input_args;
   for (const auto& input_name : graph_info_.user_input_names) {
@@ -50,7 +52,7 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream,
   }
 
   // Remove all the initializers from the graph and move them to graph inputs.
-  for (const auto& initializer_name : graph_info_.initializer_names) {
+  for (const auto& initializer_name : config_.initializer_names) {
     const NodeArg* node_arg = graph.GetNodeArg(initializer_name);
     ORT_ENFORCE(node_arg != nullptr);
     input_args.emplace_back(node_arg);
@@ -314,11 +316,11 @@ void OrtModuleGraphBuilder::ReorderOutputs() {
 
   // Add initializer gradients to graph outputs.
   graph_info_.initializer_grad_names_to_train.clear();
-  for (const auto& initializer_name : graph_info_.initializer_names_to_train) {
+  for (const auto& initializer_name : config_.initializer_names_to_train) {
     std::string initializer_gradient_name = GradientBuilderBase::GradientName(initializer_name);
     ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(),
                 "Trainable initializer grad is not found on gradient graph.");
-    graph_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
+    graph_info_.initializer_grad_names_to_train.emplace(initializer_gradient_name);
     new_output_args.emplace_back(gradient_output_arg_map[initializer_gradient_name]);
   }


@@ -44,11 +44,11 @@ struct GraphInfo {
   // Map from user input names to corresponding user input grad names for those user inputs that require grad.
   std::unordered_map<std::string, std::string> user_input_grad_names{};
   // All initializers (trainable as well as non trainable).
-  std::vector<std::string> initializer_names{};
+  std::unordered_set<std::string> initializer_names{};
   // Trainable initializers.
-  std::vector<std::string> initializer_names_to_train{};
+  std::unordered_set<std::string> initializer_names_to_train{};
   // Trainable initializer grad names, ordered according to initializer_names_to_train.
-  std::vector<std::string> initializer_grad_names_to_train{};
+  std::unordered_set<std::string> initializer_grad_names_to_train{};
   // The user outputs.
   std::vector<std::string> user_output_names{};
   // Indices of output grads that are non-differentiable.


@@ -282,10 +282,15 @@ class GraphExecutionManager(ABC):
     def _initialize_graph_builder(self, training):
        """Creates a new OrtModuleGraphBuilder, initializes it and saves it to self._graph_builder"""
 
+        # All initializer names along with user inputs are a part of the onnx graph inputs
+        # since the onnx model was exported with the flag keep_initializers_as_inputs=True
+        onnx_initializer_names = {p.name for p in self._onnx_model.graph.input}
+
         # TODO: PyTorch exporter bug: changes the initializer order in ONNX model
-        initializer_names = [name for name, _ in self._flattened_module.named_parameters()]
-        initializer_names_to_train = [name for name,
-                                      param in self._flattened_module.named_parameters() if param.requires_grad]
+        initializer_names = [name for name, _ in self._flattened_module.named_parameters()
+                             if name in onnx_initializer_names]
+        initializer_names_to_train = [name for name, param in self._flattened_module.named_parameters()
+                                      if param.requires_grad and name in onnx_initializer_names]
 
         # Build and optimize the full graph
         grad_builder_config = C.OrtModuleGraphBuilderConfiguration()
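
A standalone illustration of the filtering introduced above; the variables here are hypothetical stand-ins for `self._onnx_model.graph.input` and `self._flattened_module.named_parameters()`:

```python
# Graph inputs contain both user inputs and the initializers that survived
# export (keep_initializers_as_inputs=True); fc2.* was pruned as unused.
onnx_graph_input_names = ["input1", "fc1.weight", "fc1.bias"]
named_parameters = [("fc1.weight", True), ("fc1.bias", True),
                    ("fc2.weight", True), ("fc2.bias", True)]  # (name, requires_grad)

onnx_initializer_names = set(onnx_graph_input_names)

# Intersecting with the graph inputs drops parameters the exporter removed.
initializer_names = [name for name, _ in named_parameters
                     if name in onnx_initializer_names]
initializer_names_to_train = [name for name, requires_grad in named_parameters
                              if requires_grad and name in onnx_initializer_names]

print(initializer_names)           # ['fc1.weight', 'fc1.bias']
print(initializer_names_to_train)  # ['fc1.weight', 'fc1.bias']
```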


@@ -89,7 +89,8 @@ class InferenceManager(GraphExecutionManager):
                 self._optimized_onnx_model,
                 self._device,
                 *_io._combine_input_buffers_initializers(
-                    self._flattened_module.named_parameters(),
+                    [param for name, param in self._flattened_module.named_parameters()
+                     if name in self._graph_info.initializer_names],
                     self._graph_info.user_input_names,
                     self._input_info,
                     self._flattened_module.named_buffers(),


@@ -73,7 +73,7 @@ class _InputInfo(object):
                   if name in self.keyword_names}
         return args, kwargs
 
-def _combine_input_buffers_initializers(param_names, onnx_input_names, input_info, buffer_names, inputs, kwargs, device):
+def _combine_input_buffers_initializers(params, onnx_input_names, input_info, buffer_names, inputs, kwargs, device):
     '''Creates forward `*inputs` list from user input and PyTorch initializers
 
     ONNX Runtime forward requires an ordered list of:
@@ -120,8 +120,9 @@ def _combine_input_buffers_initializers(param_names, onnx_input_names, input_inf
         else:
             raise RuntimeError(f'Input is present in ONNX graph but not provided: {name}.')
 
     # Initializers
-    result.extend([param[1] for param in param_names])
+    # params is a list of all initializers known to the onnx graph
+    result.extend(params)
 
     return result
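
The callers now pass the initializer tensors themselves (pre-filtered against the graph's initializer names) instead of `named_parameters()` pairs. A simplified, self-contained sketch of that calling pattern, with hypothetical names rather than the actual ORTModule internals:

```python
import torch

class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)
        self.unused = torch.nn.Linear(4, 4)  # would be pruned at export time

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
# Pretend these came from the exported graph and the user's call.
graph_initializer_names = {"fc.weight", "fc.bias"}
user_inputs = [torch.randn(2, 4)]

# Only parameters that are also graph initializers are forwarded to ONNX Runtime;
# 'unused.*' never reaches the runtime because the exporter dropped it.
initializer_params = [param for name, param in model.named_parameters()
                      if name in graph_initializer_names]

ort_forward_inputs = list(user_inputs) + initializer_params
print(len(ort_forward_inputs))  # 3: x, fc.weight, fc.bias
```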


@@ -173,22 +173,22 @@ class TrainingManager(GraphExecutionManager):
 
                 # Append gradients of initializer to results
                 # Go over each initializer, check if it required grad and append to results accordingly
-                initializer_names_to_train_set = set(self._graph_info.initializer_names_to_train)
                 initializer_index = num_user_input_grads
-                for initializer_name in self._graph_info.initializer_names:
-                    if initializer_name in initializer_names_to_train_set:
+                for initializer_name, _ in self._flattened_module.named_parameters():
+                    if initializer_name in self._graph_info.initializer_names_to_train:
                         results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[initializer_index]))
                         initializer_index += 1
                     else:
                         results.append(None)
 
                 return tuple(results)
 
         return _io.unflatten_user_output(self._module_output_schema,
                                          self._graph_info.user_output_names,
                                          _ORTModuleFunction.apply(
                                              *_io._combine_input_buffers_initializers(
-                                                 self._flattened_module.named_parameters(),
+                                                 [param for name, param in self._flattened_module.named_parameters()
+                                                  if name in self._graph_info.initializer_names],
                                                  self._graph_info.user_input_names,
                                                  self._input_info,
                                                  self._flattened_module.named_buffers(),
@@ -238,7 +238,7 @@ class TrainingManager(GraphExecutionManager):
 
         initializer_names_to_train_set_user_model = {name for name, param in
                                                      self._flattened_module.named_parameters() if param.requires_grad}
-        initializer_names_to_train_set_onnx_graph = set(self._graph_info.initializer_names_to_train) \
+        initializer_names_to_train_set_onnx_graph = self._graph_info.initializer_names_to_train \
             if self._graph_info else None
 
         # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder
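
The backward path above now walks `named_parameters()` and hands back `None` for any parameter that is not a trainable initializer in the graph (including unused ones). A rough, self-contained sketch of that pairing logic with made-up data:

```python
import torch

# Hypothetical module parameters; fc2.* is not a trainable initializer in the ONNX graph.
named_params = [("fc1.weight", torch.ones(2, 2)),
                ("fc1.bias", torch.ones(2)),
                ("fc2.weight", torch.ones(2, 2))]
initializer_names_to_train = {"fc1.weight", "fc1.bias"}

# Gradients produced by the backward graph, ordered like the trainable initializers.
backward_outputs = [torch.full((2, 2), 0.1), torch.full((2,), 0.2)]

results, initializer_index = [], 0
for name, _ in named_params:
    if name in initializer_names_to_train:
        results.append(backward_outputs[initializer_index])  # consume the next gradient
        initializer_index += 1
    else:
        results.append(None)  # no gradient for parameters outside the training graph

print([g.shape if g is not None else None for g in results])
# [torch.Size([2, 2]), torch.Size([2]), None]
```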


@@ -160,7 +160,7 @@ def assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_param
         assert pt_name in ort_name
 
         if pt_name in none_pt_params:
             assert pt_param.grad is None
-            assert not torch.is_nonzero(torch.count_nonzero(ort_param.grad))
+            assert ort_param.grad is None or not torch.is_nonzero(torch.count_nonzero(ort_param.grad))
         else:
             assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)


@@ -2362,3 +2362,53 @@ def test_model_with_registered_buffer_and_dropped_parameters():
 
     # Ensure that no exceptions are raised
     out = model(bool_argument, x)
+
+def test_unused_parameters():
+    class UnusedParameterNet(torch.nn.Module):
+        def __init__(self, input_size, hidden_size, num_classes):
+            super(UnusedParameterNet, self).__init__()
+            self.fc1 = torch.nn.Linear(input_size, hidden_size)
+            self.relu = torch.nn.ReLU()
+            # fc2 is an unused initializer which will be dropped after export
+            self.fc2 = torch.nn.Linear(hidden_size, num_classes)
+            self.register_buffer("buffer", torch.ones(hidden_size))
+
+        def forward(self, input1):
+            out = self.fc1(input1)
+            out = self.relu(out)
+            out = out + self.buffer
+            return out
+
+    device = 'cuda'
+    N, D_in, H, D_out = 64, 784, 500, 10
+    model = UnusedParameterNet(D_in, H, D_out).to(device)
+    ort_model = ORTModule(copy.deepcopy(model))
+
+    # Make sure model runs without any exception
+    for _ in range(5):
+        x = torch.randn(N, D_in, device=device)
+        y = copy.deepcopy(x)
+
+        out_pt = model(x)
+        out_ort = ort_model(y)
+
+        loss_pt = out_pt.sum()
+        loss_pt.backward()
+        loss_ort = out_ort.sum()
+        loss_ort.backward()
+
+        _test_helpers.assert_values_are_close(out_ort, out_pt)
+        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model,
+                                                                none_pt_params=['fc2.weight', 'fc2.bias'])
+
+    # Also try in eval mode
+    model.eval()
+    ort_model.eval()
+
+    x = torch.randn(N, D_in, device=device)
+    y = copy.deepcopy(x)
+
+    # Make sure model runs without any exception
+    out_pt = model(x)
+    out_ort = ort_model(y)
+
+    _test_helpers.assert_values_are_close(out_ort, out_pt)