[ORTModule] ATen Support for aten::upsample_nearest (#13364)

Add ATen support for aten::upsample_nearest (1d, 2d and 3d), which is required for
training Hugging Face's diffusers models with ORTModule.
This commit is contained in:
Vincent Wang 2022-10-20 08:30:04 +08:00 committed by GitHub
parent b6b3f41636
commit 67150baa8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 130 additions and 32 deletions

View file

@ -212,6 +212,9 @@ class SymbolicShapeInference:
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
"numpy_T": self._infer_Transpose,
"native_group_norm": self._infer_aten_group_norm,
"upsample_nearest1d": self._infer_aten_upsample_nearest,
"upsample_nearest2d": self._infer_aten_upsample_nearest,
"upsample_nearest3d": self._infer_aten_upsample_nearest,
}
self.run_ = True
self.suggested_merge_ = {}
@ -1366,12 +1369,30 @@ class SymbolicShapeInference:
node.output[i],
output_dtype,
[
N if N is not None else self._new_symbolic_dim_from_output(node, i, 0),
as_scalar(group) if group is not None else self._new_symbolic_dim_from_output(node, i, 1),
N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)),
as_scalar(group)
if group is not None
else str(self._new_symbolic_dim_from_output(node, i, 1)),
],
)
)
def _infer_aten_upsample_nearest(self, node):
new_shape = None
input_shape = self._get_shape(node, 0)
if input_shape is not None:
new_shape = input_shape[:2]
output_size = self._try_get_value(node, 1)
if output_size is not None:
new_shape += [dim_size.item() for dim_size in output_size]
else:
rank = len(input_shape)
new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
if node.output[0] and new_shape is not None:
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
def _infer_BatchNormalization(self, node):
self._propagate_shape_and_type(node)

View file

@ -6,42 +6,24 @@
#include <unordered_map>
#include <vector>
template <typename T>
template <typename TSrc, typename TDst>
c10::IValue ToIValue(const DLManagedTensor* dlpack, bool is_optional) {
TORCH_INTERNAL_ASSERT((dlpack->dl_tensor.ndim == 0 && dlpack->dl_tensor.shape == nullptr) ||
(dlpack->dl_tensor.ndim == 1 && dlpack->dl_tensor.shape[0] == 1));
T value = *reinterpret_cast<const T*>(dlpack->dl_tensor.data);
return is_optional ? c10::IValue(c10::optional<T>(value)) : c10::IValue(value);
TDst value = static_cast<TDst>(*reinterpret_cast<const TSrc*>(dlpack->dl_tensor.data));
return is_optional ? c10::IValue(c10::optional<TDst>(value)) : c10::IValue(value);
}
template <typename T>
template <typename TSrc, typename TDst>
c10::IValue ToListIValue(const DLManagedTensor* dlpack, bool is_optional) {
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.ndim == 1);
const T* p_data = reinterpret_cast<const T*>(dlpack->dl_tensor.data);
c10::List<T> list_value;
const TSrc* p_data = reinterpret_cast<const TSrc*>(dlpack->dl_tensor.data);
c10::List<TDst> list_value;
size_t len = static_cast<size_t>(dlpack->dl_tensor.shape[0]);
for (size_t i = 0; i < len; i++) {
list_value.emplace_back(p_data[i]);
list_value.emplace_back(static_cast<TDst>(p_data[i]));
}
return is_optional ? c10::IValue(c10::optional<c10::List<T>>(list_value)) : c10::IValue(list_value);
}
c10::IValue Int64ToBoolIValue(const DLManagedTensor* dlpack, bool is_list, bool is_optional) {
if (is_list) {
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.ndim == 1);
const int64_t* p_data = reinterpret_cast<const int64_t*>(dlpack->dl_tensor.data);
c10::List<bool> list_value;
size_t len = static_cast<size_t>(dlpack->dl_tensor.shape[0]);
for (size_t i = 0; i < len; i++) {
list_value.emplace_back(static_cast<bool>(p_data[i]));
}
return is_optional ? c10::IValue(c10::optional<c10::List<bool>>(list_value)) : c10::IValue(list_value);
}
TORCH_INTERNAL_ASSERT((dlpack->dl_tensor.ndim == 0 && dlpack->dl_tensor.shape == nullptr) ||
(dlpack->dl_tensor.ndim == 1 && dlpack->dl_tensor.shape[0] == 1));
bool value = static_cast<bool>(*reinterpret_cast<const int64_t*>(dlpack->dl_tensor.data));
return is_optional ? c10::IValue(c10::optional<bool>(value)) : c10::IValue(value);
return is_optional ? c10::IValue(c10::optional<c10::List<TDst>>(list_value)) : c10::IValue(list_value);
}
struct ATenOperator {
@ -78,22 +60,26 @@ struct ATenOperator {
case c10::TypeKind::IntType: {
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLInt &&
dlpack->dl_tensor.dtype.bits == 64);
i_value = is_list ? ToListIValue<int64_t>(dlpack, is_optional) : ToIValue<int64_t>(dlpack, is_optional);
i_value = is_list ? ToListIValue<int64_t, int64_t>(dlpack, is_optional)
: ToIValue<int64_t, int64_t>(dlpack, is_optional);
} break;
case c10::TypeKind::FloatType: {
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLFloat &&
dlpack->dl_tensor.dtype.bits == 32);
i_value = is_list ? ToListIValue<float>(dlpack, is_optional) : ToIValue<float>(dlpack, is_optional);
// PyTorch's IValue doesn't support float, so we convert it to double.
i_value =
is_list ? ToListIValue<float, double>(dlpack, is_optional) : ToIValue<float, double>(dlpack, is_optional);
} break;
case c10::TypeKind::BoolType: {
// In torch 1.8.1, exporter has bug which exports bool constant to int64 type tensor.
// This bug has been fixed since torch 1.9.0. To make torch 1.8.1 work, add special handling here.
if (dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLInt && dlpack->dl_tensor.dtype.bits == 64) {
i_value = Int64ToBoolIValue(dlpack, is_list, is_optional);
i_value =
is_list ? ToListIValue<int64_t, bool>(dlpack, is_optional) : ToIValue<int64_t, bool>(dlpack, is_optional);
} else {
TORCH_INTERNAL_ASSERT(dlpack->dl_tensor.dtype.code == DLDataTypeCode::kDLUInt &&
dlpack->dl_tensor.dtype.bits == 8);
i_value = is_list ? ToListIValue<bool>(dlpack, is_optional) : ToIValue<bool>(dlpack, is_optional);
i_value = is_list ? ToListIValue<bool, bool>(dlpack, is_optional) : ToIValue<bool, bool>(dlpack, is_optional);
}
} break;
default: // TODO: will add more type support if needed.

View file

@ -235,3 +235,33 @@ def native_group_norm_gradient():
{"operator": {"value": "native_group_norm_backward", "dtype": "string"}},
),
]
def _upsample_nearest_gradient(backward_fn):
return [
("Shape", ["I(0)"], ["Shape_X"]),
(
("ATen", "org.pytorch.aten"),
["GO(0)", "I(1)", "Shape_X", "I(2)"],
["GI(0)"],
{
"operator": {"value": backward_fn, "dtype": "string"},
"overload_name": {"value": "vec", "dtype": "string"},
},
),
]
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest1d", "vec")
def upsample_nearest1d_gradient():
    """Gradient definition registered for the ``vec`` overload of ATen ``upsample_nearest1d``."""
    backward_fn = "upsample_nearest1d_backward"
    return _upsample_nearest_gradient(backward_fn)
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest2d", "vec")
def upsample_nearest2d_gradient():
    """Gradient definition registered for the ``vec`` overload of ATen ``upsample_nearest2d``."""
    backward_fn = "upsample_nearest2d_backward"
    return _upsample_nearest_gradient(backward_fn)
@register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec")
def upsample_nearest3d_gradient():
    """Gradient definition registered for the ``vec`` overload of ATen ``upsample_nearest3d``."""
    backward_fn = "upsample_nearest3d_backward"
    return _upsample_nearest_gradient(backward_fn)

View file

@ -693,3 +693,29 @@ def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
operator_s="native_group_norm",
outputs=3,
)[0]
def _upsample_nearest(g, input, output_size, scale_factors, forward_fn):
return g.op(
"org.pytorch.aten::ATen",
input,
output_size,
scale_factors,
operator_s=forward_fn,
overload_name_s="vec",
)
@register_symbolic("upsample_nearest1d")
def upsample_nearest1d(g, input, output_size, scale_factors):
    """Symbolic registered for ``upsample_nearest1d``; delegates to ``_upsample_nearest``."""
    return _upsample_nearest(g, input, output_size, scale_factors, forward_fn="upsample_nearest1d")
@register_symbolic("upsample_nearest2d")
def upsample_nearest2d(g, input, output_size, scale_factors):
    """Symbolic registered for ``upsample_nearest2d``; delegates to ``_upsample_nearest``."""
    return _upsample_nearest(g, input, output_size, scale_factors, forward_fn="upsample_nearest2d")
@register_symbolic("upsample_nearest3d")
def upsample_nearest3d(g, input, output_size, scale_factors):
    """Symbolic registered for ``upsample_nearest3d``; delegates to ``_upsample_nearest``."""
    return _upsample_nearest(g, input, output_size, scale_factors, forward_fn="upsample_nearest3d")

View file

@ -1691,6 +1691,41 @@ def test_aten_group_norm():
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
@pytest.mark.parametrize("input_rank", (3, 4, 5))
@pytest.mark.parametrize("use_factor", (True, False))
def test_aten_upsample_nearest(input_rank, use_factor):
    """Compare ORTModule against PyTorch for nearest-neighbour ``interpolate``.

    Runs one forward/backward step on a rank-``input_rank`` CUDA tensor, using either
    ``scale_factor=2.0`` or ``size=12`` depending on ``use_factor``, and checks that
    both the predictions and the input gradients are close.
    """

    class _NeuralNetUpsampleNearest(torch.nn.Module):
        def __init__(self):
            super(_NeuralNetUpsampleNearest, self).__init__()

        def forward(self, input):
            if use_factor:
                return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
            return torch.nn.functional.interpolate(input, size=12, mode="nearest")

    device = "cuda"
    pt_model = _NeuralNetUpsampleNearest().to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, input):
        # Reduce to a scalar so backward() populates input.grad.
        prediction = model(input)
        prediction.sum().backward()
        return prediction

    # reset manual seed to reset the generator
    torch.manual_seed(2333)
    # Dim sizes are 2, 4, ..., 2 * input_rank.
    input_size = list(range(2, 2 * input_rank + 1, 2))
    pt_input = torch.randn(input_size, dtype=torch.float, device=device, requires_grad=True)
    ort_input = copy.deepcopy(pt_input)

    pt_prediction = run_step(pt_model, pt_input)
    ort_prediction = run_step(ort_model, ort_input)

    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
    _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
def test_gradient_correctness_cast_chain():
class NeuralNetCast(torch.nn.Module):
def __init__(self, D):