Regain performance by caching initializer names in ORTModule (#7685)

baijumeswani 2021-05-13 20:54:49 -07:00 committed by GitHub
parent 19704aedbb
commit 37f69fcee5
6 changed files with 81 additions and 37 deletions

View file

@@ -39,12 +39,10 @@ Status OrtModuleGraphBuilder::Initialize(std::istream& model_istream,
graph_info_.user_output_names.emplace_back(node_arg->Name());
}
graph_info_.initializer_names_to_train = std::unordered_set<std::string>(
config_.initializer_names_to_train.begin(),
config_.initializer_names_to_train.end());
graph_info_.initializer_names = std::unordered_set<std::string>(
config_.initializer_names.begin(),
config_.initializer_names.end());
graph_info_.initializer_names_to_train.assign(config.initializer_names_to_train.begin(),
config.initializer_names_to_train.end());
graph_info_.initializer_names.assign(config.initializer_names.begin(),
config.initializer_names.end());
std::vector<const NodeArg*> input_args;
for (const auto& input_name : graph_info_.user_input_names) {
@@ -320,7 +318,7 @@ void OrtModuleGraphBuilder::ReorderOutputs() {
std::string initializer_gradient_name = GradientBuilderBase::GradientName(initializer_name);
ORT_ENFORCE(gradient_output_arg_map.find(initializer_gradient_name) != gradient_output_arg_map.end(),
"Trainable initializer grad is not found on gradient graph.");
graph_info_.initializer_grad_names_to_train.emplace(initializer_gradient_name);
graph_info_.initializer_grad_names_to_train.emplace_back(initializer_gradient_name);
new_output_args.emplace_back(gradient_output_arg_map[initializer_gradient_name]);
}
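The change from unordered_set to vector above is load-bearing rather than cosmetic: ReorderOutputs appends one gradient output per trainable initializer in iteration order, and the Python frontend later pairs those outputs with parameter names by position, so the containers must preserve order. A minimal Python sketch of that positional pairing (names are illustrative, not the actual ORT API):

# Sketch: pair gradient outputs with trainable initializer names by position.
# This is only well defined because both sequences preserve order, which is
# why GraphInfo now stores vectors instead of unordered_sets.
def pair_gradients(trainable_names, backward_outputs):
    assert len(trainable_names) == len(backward_outputs)
    return dict(zip(trainable_names, backward_outputs))

names = ["fc1.weight", "fc1.bias"]              # order fixed by the graph builder
grads = ["grad(fc1.weight)", "grad(fc1.bias)"]  # appended in the same order
print(pair_gradients(names, grads))
# With an unordered container the iteration order is unspecified, so a
# positional zip could attach a gradient to the wrong parameter.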

View file

@@ -44,11 +44,11 @@ struct GraphInfo {
// Map from user input names to corresponding user input grad names for those user inputs that require grad.
std::unordered_map<std::string, std::string> user_input_grad_names{};
// All initializers (trainable as well as non-trainable).
std::unordered_set<std::string> initializer_names{};
std::vector<std::string> initializer_names{};
// Trainable initializers.
std::unordered_set<std::string> initializer_names_to_train{};
std::vector<std::string> initializer_names_to_train{};
// Trainable initializer grad names, ordered according to initializer_names_to_train.
std::unordered_set<std::string> initializer_grad_names_to_train{};
std::vector<std::string> initializer_grad_names_to_train{};
// The user outputs.
std::vector<std::string> user_output_names{};
// Indices of output grads that are non-differentiable.

View file

@@ -52,6 +52,8 @@ class GraphExecutionManager(ABC):
self._optimized_onnx_model = None
self._graph_builder = None
self._graph_info = None
self._graph_initializer_names = None
self._graph_initializer_names_to_train = None
# TrainingAgent or InferenceAgent
self._execution_agent = None
@@ -156,6 +158,11 @@ class GraphExecutionManager(ABC):
self._optimized_onnx_model = onnx.load_model_from_string(self._graph_builder.get_model())
self._graph_info = self._graph_builder.get_graph_info()
# TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train
# sets (unordered_set in the backend) that do not require a copy on each reference.
self._graph_initializer_names = set(self._graph_info.initializer_names)
self._graph_initializer_names_to_train = set(self._graph_info.initializer_names_to_train)
def _get_session_config(self):
"""Creates and returns the session configuration to be used for the ExecutionAgent"""
providers = None
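The TODO above spells out the cost being avoided: each read of self._graph_info.initializer_names (and initializer_names_to_train) hands back a copy of the backing container, so the manager now snapshots both into Python sets once per graph build. A rough sketch of the pattern, with FakeGraphInfo as a hypothetical stand-in for the pybind-backed graph info object:

# Caching sketch; FakeGraphInfo mimics the copy-per-reference behavior
# described in the TODO above and is not part of ORT.
class FakeGraphInfo:
    @property
    def initializer_names(self):
        return ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]  # fresh copy each access

    @property
    def initializer_names_to_train(self):
        return ["fc1.weight", "fc1.bias"]

graph_info = FakeGraphInfo()

# Snapshot once; later membership tests are O(1) set lookups and never
# touch graph_info again.
graph_initializer_names = set(graph_info.initializer_names)
graph_initializer_names_to_train = set(graph_info.initializer_names_to_train)

print("fc1.weight" in graph_initializer_names_to_train)  # True
print("fc2.weight" in graph_initializer_names_to_train)  # False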

View file

@@ -90,7 +90,7 @@ class InferenceManager(GraphExecutionManager):
self._device,
*_io._combine_input_buffers_initializers(
[param for name, param in self._flattened_module.named_parameters()
if name in self._graph_info.initializer_names],
if name in self._graph_initializer_names],
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
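Both managers now test membership against the cached set when selecting the parameters to hand to the execution agent; with the old code, each evaluation of the comprehension condition went back through self._graph_info and, per the TODO earlier, copied the container again. A small usage sketch with a throwaway torch module (module and names are illustrative):

# Usage sketch: filter named_parameters against the cached name set.
import torch

module = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
graph_initializer_names = {"0.weight", "0.bias", "1.weight", "1.bias"}

# Same shape as the call above: keep only parameters that survived export
# as graph initializers.
params = [param for name, param in module.named_parameters()
          if name in graph_initializer_names]
print(len(params))  # 4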

View file

@@ -174,8 +174,8 @@ class TrainingManager(GraphExecutionManager):
# Append gradients of initializer to results
# Go over each initializer, check if it requires grad and append to results accordingly
initializer_index = num_user_input_grads
for initializer_name, _ in self._flattened_module.named_parameters():
if initializer_name in self._graph_info.initializer_names_to_train:
for initializer_name in self._graph_info.initializer_names:
if initializer_name in self._graph_initializer_names_to_train:
results.append(_utils._ortvalue_to_torch_tensor(backward_outputs[initializer_index]))
initializer_index += 1
else:
@@ -188,7 +188,7 @@ class TrainingManager(GraphExecutionManager):
_ORTModuleFunction.apply(
*_io._combine_input_buffers_initializers(
[param for name, param in self._flattened_module.named_parameters()
if name in self._graph_info.initializer_names],
if name in self._graph_initializer_names],
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
@@ -236,13 +236,11 @@ class TrainingManager(GraphExecutionManager):
initializer_names_to_train_set_user_model = {name for name, param in
self._flattened_module.named_parameters() if param.requires_grad}
initializer_names_to_train_set_onnx_graph = self._graph_info.initializer_names_to_train \
if self._graph_info else None
# If inputs requiring gradient change from one forward call to the next, the module_gradient_graph_builder
# needs to be reinitialized so it can compute the backward output for the new inputs that require_grad
if input_info.require_grad_names != self._input_info.require_grad_names or \
initializer_names_to_train_set_user_model != initializer_names_to_train_set_onnx_graph:
initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train:
self._input_info = input_info
self._initialize_graph_builder(training=True)
return True
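In backward, the result list is now built by walking the graph initializers in order and taking the next gradient output whenever the name is in the cached trainable set. A condensed sketch with strings standing in for tensors; the num_user_input_grads offset is dropped, and the truncated else branch is assumed to append None for non-trainable initializers:

# Condensed sketch of the gradient-assembly loop in TrainingManager.backward.
initializer_names = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
graph_initializer_names_to_train = {"fc1.weight", "fc1.bias"}
backward_outputs = ["grad(fc1.weight)", "grad(fc1.bias)"]  # trainable grads, in order

results = []
initializer_index = 0
for initializer_name in initializer_names:
    if initializer_name in graph_initializer_names_to_train:
        results.append(backward_outputs[initializer_index])
        initializer_index += 1
    else:
        results.append(None)  # assumption: non-trainable initializers get no gradient

print(results)  # ['grad(fc1.weight)', 'grad(fc1.bias)', None, None]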

View file

@@ -197,6 +197,59 @@ class NeuralNetPartialNoGradModel(torch.nn.Module):
out = self.fc2(out)
return out
class UnusedEndParameterNet(torch.nn.Module):
def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
super(UnusedEndParameterNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size1)
self.relu = torch.nn.ReLU()
# fc2 is an unused initializer (which is at the end of the initializer list)
# which will be dropped after export
self.fc2 = torch.nn.Linear(hidden_size1, hidden_size2)
self.register_buffer("buffer", torch.ones(hidden_size1))
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = out + self.buffer
return out
class UnusedBeginParameterNet(torch.nn.Module):
def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
super(UnusedBeginParameterNet, self).__init__()
# fc1 is an unused initializer (which is at the beginning of the initializer list)
# which will be dropped after export
self.fc1 = torch.nn.Linear(input_size, hidden_size1)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(input_size, hidden_size2)
self.register_buffer("buffer", torch.ones(hidden_size2))
def forward(self, input1):
out = self.fc2(input1)
out = self.relu(out)
out = out + self.buffer
return out
class UnusedMiddleParameterNet(torch.nn.Module):
def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
super(UnusedMiddleParameterNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size1)
self.relu = torch.nn.ReLU()
# fc2 is an unused initializer (which is in the middle of the initializer list)
# which will be dropped after export
self.fc2 = torch.nn.Linear(hidden_size1, hidden_size2)
self.fc3 = torch.nn.Linear(hidden_size1, num_classes)
self.register_buffer("buffer", torch.ones(num_classes))
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = self.fc3(out)
out = out + self.buffer
return out
# TODO: This is a workaround for the problem that pytest is still cleaning up the previous test
# while the next test has already started.
@pytest.fixture(autouse=True)
@@ -2393,27 +2446,15 @@ def test_model_with_registered_buffer_and_dropped_parameters():
# Ensure that no exceptions are raised
out = model(bool_argument, x)
def test_unused_parameters():
class UnusedParameterNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(UnusedParameterNet, self).__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.relu = torch.nn.ReLU()
# fc2 is an unused initializer which will be dropped after export
self.fc2 = torch.nn.Linear(hidden_size, num_classes)
self.register_buffer("buffer", torch.ones(hidden_size))
def forward(self, input1):
out = self.fc1(input1)
out = self.relu(out)
out = out + self.buffer
return out
@pytest.mark.parametrize("model, none_pt_params",
[(UnusedBeginParameterNet(784, 500, 400, 10), ['fc1.weight', 'fc1.bias']),
(UnusedMiddleParameterNet(784, 500, 400, 10), ['fc2.weight', 'fc2.bias']),
(UnusedEndParameterNet(784, 500, 400, 10), ['fc2.weight', 'fc2.bias'])])
def test_unused_parameters(model, none_pt_params):
device = 'cuda'
N, D_in, H, D_out = 64, 784, 500, 10
model = UnusedParameterNet(D_in, H, D_out).to(device)
N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10
model = model.to(device)
ort_model = ORTModule(copy.deepcopy(model))
# Make sure model runs without any exception
@@ -2429,7 +2470,7 @@ def test_unused_parameters():
loss_ort.backward()
_test_helpers.assert_values_are_close(out_ort, out_pt)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model,
none_pt_params=['fc2.weight', 'fc2.bias'])
none_pt_params=none_pt_params)
# Also try in eval mode
model.eval()
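The none_pt_params argument exists because parameters that never participate in the forward pass receive no gradient in plain PyTorch, so the comparison helper is told which PyTorch gradients will be absent. A toy, CPU-only check of that behavior (TinyUnusedParamNet is hypothetical and only mirrors the unused-fc2 idea above):

# Toy check: an unused parameter keeps grad == None after backward.
import torch

class TinyUnusedParamNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 3)
        self.fc2 = torch.nn.Linear(3, 2)  # defined but never used in forward

    def forward(self, x):
        return self.fc1(x)

model = TinyUnusedParamNet()
model(torch.randn(2, 4)).sum().backward()
print(model.fc1.weight.grad is None)  # False: used parameter got a gradient
print(model.fc2.weight.grad is None)  # True: unused parameter has no gradient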