mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
[ORTModule] ATen Support for aten::upsample_nearest (#13364)
ATen support for aten::upsample_nearest, which is required for Huggingface's diffusers model training using ORTModule.
This commit is contained in:
parent
b6b3f41636
commit
67150baa8d
5 changed files with 130 additions and 32 deletions
|
|
@ -212,6 +212,9 @@ class SymbolicShapeInference:
|
|||
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
|
||||
"numpy_T": self._infer_Transpose,
|
||||
"native_group_norm": self._infer_aten_group_norm,
|
||||
"upsample_nearest1d": self._infer_aten_upsample_nearest,
|
||||
"upsample_nearest2d": self._infer_aten_upsample_nearest,
|
||||
"upsample_nearest3d": self._infer_aten_upsample_nearest,
|
||||
}
|
||||
self.run_ = True
|
||||
self.suggested_merge_ = {}
|
||||
|
|
@ -1366,12 +1369,30 @@ class SymbolicShapeInference:
|
|||
node.output[i],
|
||||
output_dtype,
|
||||
[
|
||||
N if N is not None else self._new_symbolic_dim_from_output(node, i, 0),
|
||||
as_scalar(group) if group is not None else self._new_symbolic_dim_from_output(node, i, 1),
|
||||
N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)),
|
||||
as_scalar(group)
|
||||
if group is not None
|
||||
else str(self._new_symbolic_dim_from_output(node, i, 1)),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
def _infer_aten_upsample_nearest(self, node):
|
||||
new_shape = None
|
||||
input_shape = self._get_shape(node, 0)
|
||||
if input_shape is not None:
|
||||
new_shape = input_shape[:2]
|
||||
output_size = self._try_get_value(node, 1)
|
||||
if output_size is not None:
|
||||
new_shape += [dim_size.item() for dim_size in output_size]
|
||||
else:
|
||||
rank = len(input_shape)
|
||||
new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
|
||||
if node.output[0] and new_shape is not None:
|
||||
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
|
||||
|
||||
def _infer_BatchNormalization(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,42 +6,24 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
template <typename T>
|
||||
template <typename TSrc, typename TDst>
|
||||
c10::IValue ToIValue(const DLManagedTensor* dlpack, bool is_optional) {
|
||||
TORCH_INTERNAL_ASSERT((dlpack->dl_tensor.ndim == 0 && dlpack->dl_tensor.shape == nullptr) ||
|
||||
(dlpack->dl_tensor.ndim == 1 && dlpack->dl_tensor.shape[0] == 1));
|
||||
T value = *reinterpret_cast<const T*>(dlpack->dl_tensor.data);
|
||||
return is_optional ? c10::IValue(c10::optional<T>(value)) : c10::IValue(value);
|
||||
TDst value = static_cast<TDst>(*reinterpret_cast<const TSrc*>(dlpack->dl_tensor.data));
|
||||
return is_optional ? c10::IValue(c10::optional<TDst>(value)) : c10::IValue(value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename TSrc, typename TDst>
|
||||
c10::IValue ToListIValue(const DLManagedTensor* dlpack, bool is_optional) {
|
||||
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.ndim == 1);
|
||||
const T* p_data = reinterpret_cast<const T*>(dlpack->dl_tensor.data);
|
||||
c10::List<T> list_value;
|
||||
const TSrc* p_data = reinterpret_cast<const TSrc*>(dlpack->dl_tensor.data);
|
||||
c10::List<TDst> list_value;
|
||||
size_t len = static_cast<size_t>(dlpack->dl_tensor.shape[0]);
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
list_value.emplace_back(p_data[i]);
|
||||
list_value.emplace_back(static_cast<TDst>(p_data[i]));
|
||||
}
|
||||
return is_optional ? c10::IValue(c10::optional<c10::List<T>>(list_value)) : c10::IValue(list_value);
|
||||
}
|
||||
|
||||
c10::IValue Int64ToBoolIValue(const DLManagedTensor* dlpack, bool is_list, bool is_optional) {
|
||||
if (is_list) {
|
||||
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.ndim == 1);
|
||||
const int64_t* p_data = reinterpret_cast<const int64_t*>(dlpack->dl_tensor.data);
|
||||
c10::List<bool> list_value;
|
||||
size_t len = static_cast<size_t>(dlpack->dl_tensor.shape[0]);
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
list_value.emplace_back(static_cast<bool>(p_data[i]));
|
||||
}
|
||||
return is_optional ? c10::IValue(c10::optional<c10::List<bool>>(list_value)) : c10::IValue(list_value);
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT((dlpack->dl_tensor.ndim == 0 && dlpack->dl_tensor.shape == nullptr) ||
|
||||
(dlpack->dl_tensor.ndim == 1 && dlpack->dl_tensor.shape[0] == 1));
|
||||
bool value = static_cast<bool>(*reinterpret_cast<const int64_t*>(dlpack->dl_tensor.data));
|
||||
return is_optional ? c10::IValue(c10::optional<bool>(value)) : c10::IValue(value);
|
||||
return is_optional ? c10::IValue(c10::optional<c10::List<TDst>>(list_value)) : c10::IValue(list_value);
|
||||
}
|
||||
|
||||
struct ATenOperator {
|
||||
|
|
@ -78,22 +60,26 @@ struct ATenOperator {
|
|||
case c10::TypeKind::IntType: {
|
||||
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLInt &&
|
||||
dlpack->dl_tensor.dtype.bits == 64);
|
||||
i_value = is_list ? ToListIValue<int64_t>(dlpack, is_optional) : ToIValue<int64_t>(dlpack, is_optional);
|
||||
i_value = is_list ? ToListIValue<int64_t, int64_t>(dlpack, is_optional)
|
||||
: ToIValue<int64_t, int64_t>(dlpack, is_optional);
|
||||
} break;
|
||||
case c10::TypeKind::FloatType: {
|
||||
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLFloat &&
|
||||
dlpack->dl_tensor.dtype.bits == 32);
|
||||
i_value = is_list ? ToListIValue<float>(dlpack, is_optional) : ToIValue<float>(dlpack, is_optional);
|
||||
// PyTorch's IValue doesn't support float, so we convert it to double.
|
||||
i_value =
|
||||
is_list ? ToListIValue<float, double>(dlpack, is_optional) : ToIValue<float, double>(dlpack, is_optional);
|
||||
} break;
|
||||
case c10::TypeKind::BoolType: {
|
||||
// In torch 1.8.1, exporter has bug which exports bool constant to int64 type tensor.
|
||||
// This bug has been fixed since torch 1.9.0. To make torch 1.8.1 work, add special handling here.
|
||||
if (dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLInt && dlpack->dl_tensor.dtype.bits == 64) {
|
||||
i_value = Int64ToBoolIValue(dlpack, is_list, is_optional);
|
||||
i_value =
|
||||
is_list ? ToListIValue<int64_t, bool>(dlpack, is_optional) : ToIValue<int64_t, bool>(dlpack, is_optional);
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLUInt &&
|
||||
dlpack->dl_tensor.dtype.bits == 8);
|
||||
i_value = is_list ? ToListIValue<bool>(dlpack, is_optional) : ToIValue<bool>(dlpack, is_optional);
|
||||
i_value = is_list ? ToListIValue<bool, bool>(dlpack, is_optional) : ToIValue<bool, bool>(dlpack, is_optional);
|
||||
}
|
||||
} break;
|
||||
default: // TODO: will add more type support if needed.
|
||||
|
|
|
|||
|
|
@ -235,3 +235,33 @@ def native_group_norm_gradient():
|
|||
{"operator": {"value": "native_group_norm_backward", "dtype": "string"}},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _upsample_nearest_gradient(backward_fn):
|
||||
return [
|
||||
("Shape", ["I(0)"], ["Shape_X"]),
|
||||
(
|
||||
("ATen", "org.pytorch.aten"),
|
||||
["GO(0)", "I(1)", "Shape_X", "I(2)"],
|
||||
["GI(0)"],
|
||||
{
|
||||
"operator": {"value": backward_fn, "dtype": "string"},
|
||||
"overload_name": {"value": "vec", "dtype": "string"},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest1d", "vec")
def upsample_nearest1d_gradient():
    """Gradient definition for the ``vec`` overload of ATen ``upsample_nearest1d``."""
    backward_op = "upsample_nearest1d_backward"
    return _upsample_nearest_gradient(backward_op)
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest2d", "vec")
def upsample_nearest2d_gradient():
    """Gradient definition for the ``vec`` overload of ATen ``upsample_nearest2d``."""
    backward_op = "upsample_nearest2d_backward"
    return _upsample_nearest_gradient(backward_op)
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
def upsample_nearest3d_gradient():
    """Gradient definition for the ``vec`` overload of ATen ``upsample_nearest3d``."""
    backward_op = "upsample_nearest3d_backward"
    return _upsample_nearest_gradient(backward_op)
|
||||
|
|
|
|||
|
|
@ -693,3 +693,29 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
|
|||
operator_s="native_group_norm",
|
||||
outputs=3,
|
||||
)[0]
|
||||
|
||||
|
||||
def _upsample_nearest(g, input, output_size, scale_factors, forward_fn):
|
||||
return g.op(
|
||||
"org.pytorch.aten::ATen",
|
||||
input,
|
||||
output_size,
|
||||
scale_factors,
|
||||
operator_s=forward_fn,
|
||||
overload_name_s="vec",
|
||||
)
|
||||
|
||||
|
||||
@register_symbolic("upsample_nearest1d")
def upsample_nearest1d(g, input, output_size, scale_factors):
    """Symbolic for ``upsample_nearest1d``: delegates to ``_upsample_nearest``."""
    return _upsample_nearest(
        g,
        input,
        output_size,
        scale_factors,
        forward_fn="upsample_nearest1d",
    )
|
||||
|
||||
|
||||
@register_symbolic("upsample_nearest2d")
def upsample_nearest2d(g, input, output_size, scale_factors):
    """Symbolic for ``upsample_nearest2d``: delegates to ``_upsample_nearest``."""
    return _upsample_nearest(
        g,
        input,
        output_size,
        scale_factors,
        forward_fn="upsample_nearest2d",
    )
|
||||
|
||||
|
||||
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
    """Symbolic for ``upsample_nearest3d``: delegates to ``_upsample_nearest``."""
    return _upsample_nearest(
        g,
        input,
        output_size,
        scale_factors,
        forward_fn="upsample_nearest3d",
    )
|
||||
|
|
|
|||
|
|
@ -1691,6 +1691,41 @@ def test_aten_group_norm():
|
|||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input_rank", (3, 4, 5))
@pytest.mark.parametrize("use_factor", (True, False))
def test_aten_upsample_nearest(input_rank, use_factor):
    """Compare ORTModule against PyTorch for nearest-mode ``interpolate``.

    Runs one forward/backward step on a plain module and on its ORTModule
    wrapped copy, for inputs of rank 3, 4 and 5, using either
    ``scale_factor=2.0`` or ``size=12``, then checks that the predictions and
    the input gradients are close.
    """

    class _NeuralNetUpsampleNearest(torch.nn.Module):
        def __init__(self):
            super(_NeuralNetUpsampleNearest, self).__init__()

        def forward(self, input):
            # Select the interpolate call form according to the test parameter.
            if use_factor:
                return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
            return torch.nn.functional.interpolate(input, size=12, mode="nearest")

    device = "cuda"
    pt_model = _NeuralNetUpsampleNearest().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, input):
        output = model(input)
        output.sum().backward()
        return output

    # reset manual seed to reset the generator
    torch.manual_seed(2333)

    # Axis k of the input has size 2 * (k + 1).
    input_size = [2 * (axis + 1) for axis in range(input_rank)]
    pt_input = torch.randn(input_size, dtype=torch.float, device=device, requires_grad=True)
    ort_input = copy.deepcopy(pt_input)

    pt_prediction = run_step(pt_model, pt_input)
    ort_prediction = run_step(ort_model, ort_input)

    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
    _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
def test_gradient_correctness_cast_chain():
|
||||
class NeuralNetCast(torch.nn.Module):
|
||||
def __init__(self, D):
|
||||
|
|
|
|||
Loading…
Reference in a new issue