diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 7cd3768218..b088f5bb2a 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -199,6 +199,45 @@ def dtype_torch_to_numpy(torch_dtype): else: raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) +class model_loss_cls(torch.nn.Module): + def __init__(self, model, loss_fn): + super(model_loss_cls, self).__init__() + self.model_ = model + self.loss_fn_ = loss_fn + + def forward(self, *inputs): + # here we assume input can be unpacked into input and label + input, label = inputs[:-1], inputs[-1] + preds = self.model_(*input) + return self.loss_fn_(preds, label), preds + +class WrapModel(torch.nn.Module): + def __init__(self, model, loss_fn, input_names): + super(WrapModel, self).__init__() + self.model_ = model + self.loss_fn_ = loss_fn + self.input_names_ = input_names + + def forward(self, *inputs): + import inspect + # *inputs is given by torch trace. It is in the order of input_names. + # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. + sig = inspect.signature(self.model_.forward) + ordered_list_keys = list(sig.parameters.keys()) + + input_dict = {} + for key in sig.parameters.keys(): + if key in self.input_names_: + input_dict[key] = inputs[self.input_names_.index(key)] + + model_out = self.model_(**input_dict) + if self.loss_fn_ is None: + return model_out + + label = inputs[-1] + preds = model_out + return self.loss_fn_(preds, label), preds + def wrap_for_input_match(model, loss_fn, input_names): import inspect sig = inspect.signature(model.forward) @@ -211,18 +250,6 @@ def wrap_for_input_match(model, loss_fn, input_names): # label shall be the second input to loss_fn. ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]] - class model_loss_cls(torch.nn.Module): - def __init__(self, model, loss_fn): - super(model_loss_cls, self).__init__() - self.model_ = model - self.loss_fn_ = loss_fn - - def forward(self, *inputs): - # here we assume input can be unpacked into input and label - input, label = inputs[:-1], inputs[-1] - preds = self.model_(*input) - return self.loss_fn_(preds, label), preds - # name match is needed only when input_names are a subset # of expected inputs (inputs to model and loss_fn combined). if len(input_names) > len(ordered_list_keys): @@ -248,32 +275,6 @@ def wrap_for_input_match(model, loss_fn, input_names): if match: return model_loss_cls(model, loss_fn) if loss_fn else model - class WrapModel(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super(WrapModel, self).__init__() - self.model_ = model - self.loss_fn_ = loss_fn - self.input_names_ = input_names - - def forward(self, *inputs): - # *inputs is given by torch trace. It is in the order of input_names. - # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. - sig = inspect.signature(self.model_.forward) - ordered_list_keys = list(sig.parameters.keys()) - - input_dict = {} - for key in sig.parameters.keys(): - if key in self.input_names_: - input_dict[key] = inputs[self.input_names_.index(key)] - - model_out = self.model_(**input_dict) - if self.loss_fn_ is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn_(preds, label), preds - model = WrapModel(model, loss_fn, input_names) return model @@ -362,23 +363,16 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op onnx_model = onnx.load_model_from_string(f.getvalue()) # Remove 'model_.' prefix introduced by model wrapper for initializers. - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith('model_.'): - replace_name_dict[n.name] = n.name[len('model_.'):] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - # onnx model initializer may contain non-trainable registered buffers that are not part - # of pytorch model named parameteres. - named_parameters = model.model_.named_parameters() if hasattr(model, 'model_') else model.named_parameters() - assert set([n for n, t in named_parameters]).issubset( - set([n.name for n in onnx_model.graph.initializer])), \ - "Initializer names do not match between PyTorch model and ONNX model, " \ - "please report a bug to ONNX Runtime." + if isinstance(model, WrapModel) or isinstance(model, model_loss_cls): + replace_name_dict = {} + for n in onnx_model.graph.initializer: + if n.name.startswith('model_.'): + replace_name_dict[n.name] = n.name[len('model_.'):] + n.name = replace_name_dict[n.name] + for n in onnx_model.graph.node: + for i, name in enumerate(n.input): + if name in replace_name_dict: + n.input[i] = replace_name_dict[name] return onnx_model