Remove dead code

This commit is contained in:
Thiago Crepaldi 2020-11-19 17:46:08 -08:00
parent 41b88ce91d
commit e986ae5f86

View file

@ -392,31 +392,3 @@ class ORTModule(torch.nn.Module):
split_graphs_info = module_gradient_graph_builder.get_split_graphs_info()
return gradient_model, forward_model, backward_model, split_graphs_info
@staticmethod
def _get_io_info_from_onnx_graph(model, graphs_info):
type_map = {key: None for key in [
*graphs_info.user_input_names,
*graphs_info.initializer_names_to_train,
*graphs_info.initializer_grad_names_to_train,
*graphs_info.user_output_names,
*graphs_info.intermediate_tensor_names,
*graphs_info.user_output_grad_names
]}
for input in model.graph.input:
if input.name in type_map and type_map[input.name] is None:
type_map[input.name] = input.type
input_grad_name = input.name + '_grad'
if input_grad_name in type_map and type_map[input_grad_name] is None:
type_map[input_grad_name] = input.type
for output in model.graph.output:
if output.name in type_map and type_map[output.name] is None:
type_map[output.name] = output.type
output_grad_name = output.name + '_grad'
if output_grad_name in type_map and type_map[output_grad_name] is None:
type_map[output_grad_name] = output.type
return type_map