mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
ConvTransposeGrad CUDA Kernel (#17201)
This commit is contained in:
parent
33415b9da4
commit
fca81cc5d5
15 changed files with 1475 additions and 268 deletions
|
|
@ -201,6 +201,10 @@ set(training_ops_excluded_files
|
|||
"reduction/reduction_ops.cc" # no double type support
|
||||
"cuda_training_kernels.cc"
|
||||
"cuda_training_kernels.h"
|
||||
"nn/conv_shared.cc"
|
||||
"nn/conv_shared.h"
|
||||
"nn/conv_transpose_grad.cc"
|
||||
"nn/conv_transpose_grad.h"
|
||||
)
|
||||
|
||||
function(auto_set_source_files_hip_language)
|
||||
|
|
|
|||
|
|
@ -2070,5 +2070,22 @@ IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) {
|
|||
{GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())};
|
||||
}
|
||||
|
||||
// Gradient builder for ConvTranspose: emits a single ConvTransposeGrad node
// (kMSDomain, version 1) that consumes the output gradient plus the first two
// forward inputs, and forwards the source node's attributes unchanged.
IMPLEMENT_GRADIENT_BUILDER(GetConvTransposeGradient) {
  const auto src_input_count = GetSrcNodeInputSize();

  // One output slot per forward input. Inputs that do not need a gradient get
  // an empty-named ArgDef so the remaining slots keep their positions.
  std::vector<ArgDef> grad_outputs;
  for (int idx = 0; idx < src_input_count; ++idx) {
    grad_outputs.push_back(IsGradientRequiredForSrcNodeInput(idx) ? GI(idx) : ArgDef("", nullptr));
  }

  return std::vector<NodeDef>{
      NodeDef(OpDef{"ConvTransposeGrad", kMSDomain, 1},
              {GO(0), I(0), I(1)},
              grad_outputs,
              SrcNodeAttributes())};
}
|
||||
|
||||
} // namespace training
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ DECLARE_GRADIENT_BUILDER(GetLSTMGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetGRUGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
|
||||
|
||||
DECLARE_GRADIENT_BUILDER(GetExternalGradient)
|
||||
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient);
|
||||
REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
|
||||
REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
|
||||
|
||||
REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4908,6 +4908,21 @@ Return true if all elements are true and false otherwise.
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ConvTransposeGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.Input(0, "dY", "Gradient of output Y", "T")
|
||||
.Input(1, "X", "Input tensor", "T")
|
||||
.Input(2, "W", "Weight tensor", "T")
|
||||
.Output(0, "dX", "Gradient of X", "T", OpSchema::Optional)
|
||||
.Output(1, "dW", "Gradient of W", "T", OpSchema::Optional)
|
||||
.Output(2, "dB", "Gradient of B", "T", OpSchema::Optional)
|
||||
.AllowUncheckedAttributes()
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)"},
|
||||
"Constrain input and output types to float tensors.");
|
||||
}
|
||||
|
||||
} // namespace training
|
||||
|
|
|
|||
|
|
@ -3039,6 +3039,204 @@ TEST(GradientCheckerTest, LeakyReluGrad) {
|
|||
UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer);
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
void ConvTransposeGradientCheckerTest(std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers) {
|
||||
float max_error;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{"ConvTranspose"};
|
||||
|
||||
float error_tolerance = 1e-1f;
|
||||
|
||||
// 1D convolution
|
||||
{
|
||||
TensorShape x_shape({2, 2, 5});
|
||||
TensorShape w_shape({2, 2, 3});
|
||||
TensorShape b_shape({2});
|
||||
TensorShape y_shape({2, 2, 5});
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(
|
||||
op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3}), MakeAttribute("pads", std::vector<int64_t>{1, 1})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 1D strided convolution
|
||||
{
|
||||
TensorShape x_shape({2, 1, 7});
|
||||
TensorShape w_shape({1, 1, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 13});
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(
|
||||
op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3}), MakeAttribute("pads", std::vector<int64_t>{1, 1}),
|
||||
MakeAttribute("strides", std::vector<int64_t>{2})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 1D pointwise convolution (with padding)
|
||||
{
|
||||
TensorShape x_shape({2, 1, 5});
|
||||
TensorShape w_shape({1, 1, 1});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 3});
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(
|
||||
op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{1}), MakeAttribute("pads", std::vector<int64_t>{1, 1})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 1D pointwise convolution (no padding)
|
||||
{
|
||||
TensorShape x_shape({2, 1, 5});
|
||||
TensorShape w_shape({1, 1, 1});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 5});
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(
|
||||
op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{1}), MakeAttribute("pads", std::vector<int64_t>{0, 0})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D convolution
|
||||
{
|
||||
TensorShape x_shape({1, 1, 3, 3});
|
||||
TensorShape w_shape({1, 1, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({1, 1, 3, 3});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D convolution
|
||||
{
|
||||
TensorShape x_shape({2, 1, 5, 5});
|
||||
TensorShape w_shape({1, 1, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 5, 5});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D pointwise convolution (with padding)
|
||||
{
|
||||
TensorShape x_shape({1, 1, 3, 3});
|
||||
TensorShape w_shape({1, 1, 1, 1});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({1, 1, 1, 1});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{1, 1}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D pointwise convolution (no padding)
|
||||
{
|
||||
TensorShape x_shape({1, 1, 3, 3});
|
||||
TensorShape w_shape({1, 1, 1, 1});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({1, 1, 3, 3});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{1, 1}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{0, 0, 0, 0})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D strided convolution
|
||||
{
|
||||
TensorShape x_shape({2, 1, 7, 5});
|
||||
TensorShape w_shape({1, 1, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 13, 9});
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(
|
||||
op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1}), MakeAttribute("strides", std::vector<int64_t>{2, 2})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D dilated convolution (no padding)
|
||||
{
|
||||
TensorShape x_shape({2, 1, 5, 5});
|
||||
TensorShape w_shape({1, 1, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 9, 9});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{0, 0, 0, 0}),
|
||||
MakeAttribute("dilations", std::vector<int64_t>{2, 2})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 2D dilated convolution (with padding)
|
||||
{
|
||||
TensorShape x_shape({2, 1, 7, 5});
|
||||
TensorShape w_shape({1, 1, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 9, 7});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1}),
|
||||
MakeAttribute("dilations", std::vector<int64_t>{2, 2})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 3D convolution
|
||||
{
|
||||
TensorShape x_shape({2, 1, 5, 5, 5});
|
||||
TensorShape w_shape({1, 1, 3, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 5, 5, 5});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1, 1, 1})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
|
||||
// 3D strided convolution
|
||||
{
|
||||
TensorShape x_shape({2, 1, 7, 5, 5});
|
||||
TensorShape w_shape({1, 1, 3, 3, 3});
|
||||
TensorShape b_shape({1});
|
||||
TensorShape y_shape({2, 1, 13, 9, 9});
|
||||
ASSERT_STATUS_OK(
|
||||
gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error,
|
||||
{MakeAttribute("kernel_shape", std::vector<int64_t>{3, 3, 3}),
|
||||
MakeAttribute("pads", std::vector<int64_t>{1, 1, 1, 1, 1, 1}),
|
||||
MakeAttribute("strides", std::vector<int64_t>{2, 2, 2})},
|
||||
false, false, execution_providers));
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the ConvTranspose gradient-checker suite with the default CUDA
// execution provider as the only provider in the list.
TEST(GradientCheckerTest, ConvTransposeGrad) {
  std::vector<std::unique_ptr<IExecutionProvider>> cuda_providers;
  cuda_providers.emplace_back(DefaultCudaExecutionProvider());

  ConvTransposeGradientCheckerTest(&cuda_providers);
}
|
||||
#endif // USE_CUDA
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
|||
|
|
@ -6096,7 +6096,6 @@ def test_ortmodule_log_level_control(log_level, caplog):
|
|||
found_missing_inference_log = False
|
||||
for record in caplog.records:
|
||||
msg = record.getMessage()
|
||||
print(msg)
|
||||
if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg:
|
||||
found_missing_inference_log = True
|
||||
break
|
||||
|
|
@ -6205,3 +6204,167 @@ def test_leakyrelu_gradient():
|
|||
_test_helpers.assert_values_are_close(pt_prediction, ort_prediction)
|
||||
_test_helpers.assert_values_are_close(pt_loss, ort_loss)
|
||||
_test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm"
)
@pytest.mark.parametrize("use_fp16", [False, True])
@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"])
def test_conv_transpose_gradient(use_fp16, conv_algo_search):
    """Compare PyTorch and ORTModule gradients through chained 1D/2D/3D transposed convolutions.

    Args:
        use_fp16: passed to ``amp.autocast`` around the forward pass; when true the
            gradients are compared with ``atol=rtol=1e-3`` instead of the defaults.
        conv_algo_search: when not ``None``, exported as ``ORTMODULE_CONV_ALGO_SEARCH``
            before the ORTModule wrapper is built (presumably selects the convolution
            algorithm search mode - confirm against ORTModule's option handling).

    The environment variable is restored to its previous state in a ``finally``
    block so a failing assertion cannot leak the setting into later tests.
    """

    class ChainedTransposedConv(nn.Module):
        def __init__(self):
            super().__init__()

            # Transposed Convolution 1D
            self.conv1d_transpose = nn.ConvTranspose1d(
                in_channels=4, out_channels=2, kernel_size=3, stride=2, padding=1
            )
            self.relu1 = nn.ReLU()

            # Transposed Convolution 2D
            self.conv2d_transpose = nn.ConvTranspose2d(
                in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1
            )
            self.relu2 = nn.ReLU()

            # Transposed Convolution 3D
            self.conv3d_transpose = nn.ConvTranspose3d(
                in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1
            )
            self.relu3 = nn.ReLU()

        def forward(self, x):
            # Each stage inserts a new axis at position 2 so the next, higher-rank
            # transposed convolution can consume the previous stage's output.
            out1d = self.relu1(self.conv1d_transpose(x))
            out2d = self.relu2(self.conv2d_transpose(out1d.unsqueeze(2)))
            out3d = self.relu3(self.conv3d_transpose(out2d.unsqueeze(2)))
            return out3d.squeeze(2)

    def run_step(model, x):
        # Forward under autocast, backward on the summed output, then collect the
        # input gradient and every layer's weight/bias gradients.
        with amp.autocast(use_fp16):
            loss = model(x).sum()
        loss.backward()

        return (
            x.grad,
            model.conv1d_transpose.weight.grad,
            model.conv1d_transpose.bias.grad,
            model.conv2d_transpose.weight.grad,
            model.conv2d_transpose.bias.grad,
            model.conv3d_transpose.weight.grad,
            model.conv3d_transpose.bias.grad,
        )

    env_key = "ORTMODULE_CONV_ALGO_SEARCH"
    previous_value = os.environ.get(env_key)
    if conv_algo_search is not None:
        os.environ[env_key] = conv_algo_search

    try:
        device = "cuda"
        pt_model = ChainedTransposedConv().to(device)
        ort_model = ORTModule(copy.deepcopy(pt_model))

        pt_x = torch.randn(1, 4, 8, requires_grad=True, device=device)
        ort_x = copy.deepcopy(pt_x)

        pt_grads = run_step(pt_model, pt_x)
        ort_grads = run_step(ort_model, ort_x)

        for pt_grad, ort_grad in zip(pt_grads, ort_grads):
            if use_fp16:
                assert torch.allclose(pt_grad, ort_grad, atol=1e-3, rtol=1e-3)
            else:
                assert torch.allclose(pt_grad, ort_grad)
    finally:
        # Undo only what this test changed, even when an assertion above failed.
        if conv_algo_search is not None:
            if previous_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_value
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm"
)
@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"])
def test_conv_transpose_gradient_with_groups(conv_algo_search):
    """Compare PyTorch and ORTModule gradients for a grouped 3D transposed convolution.

    Args:
        conv_algo_search: when not ``None``, exported as ``ORTMODULE_CONV_ALGO_SEARCH``
            before the ORTModule wrapper is built.

    The environment variable is restored to its previous state in a ``finally``
    block so a failing assertion cannot leak the setting into later tests.
    """

    class TransposedConv3DWithGroups(nn.Module):
        def __init__(self):
            super().__init__()
            # in_channels, out_channels, kernel_size, stride, padding
            self.conv_transpose = nn.ConvTranspose3d(
                in_channels=6, out_channels=4, kernel_size=3, stride=2, padding=1, groups=2
            )

        def forward(self, x):
            return self.conv_transpose(x)

    def run_step(model, x):
        # Backpropagate the summed output and collect input/weight/bias gradients.
        loss = model(x).sum()
        loss.backward()

        return (
            x.grad,
            model.conv_transpose.weight.grad,
            model.conv_transpose.bias.grad,
        )

    env_key = "ORTMODULE_CONV_ALGO_SEARCH"
    previous_value = os.environ.get(env_key)
    if conv_algo_search is not None:
        os.environ[env_key] = conv_algo_search

    try:
        device = "cuda"
        pt_model = TransposedConv3DWithGroups().to(device)
        ort_model = ORTModule(copy.deepcopy(pt_model))

        pt_x = torch.randn(1, 6, 8, 16, 16, requires_grad=True, device=device)
        ort_x = copy.deepcopy(pt_x)

        pt_grads = run_step(pt_model, pt_x)
        ort_grads = run_step(ort_model, ort_x)

        for pt_grad, ort_grad in zip(pt_grads, ort_grads):
            assert torch.allclose(pt_grad, ort_grad)
    finally:
        # Undo only what this test changed, even when an assertion above failed.
        if conv_algo_search is not None:
            if previous_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_value
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm"
)
@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"])
def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_search):
    """Compare PyTorch and ORTModule gradients for a 3D transposed convolution with
    per-axis stride, padding and dilation, whose output is multiplied by a parameter.

    Args:
        conv_algo_search: when not ``None``, exported as ``ORTMODULE_CONV_ALGO_SEARCH``
            before the ORTModule wrapper is built.

    Gradients are compared with ``atol=rtol=1e-2``. The environment variable is
    restored to its previous state in a ``finally`` block so a failing assertion
    cannot leak the setting into later tests.
    """

    class ConvTransposeComplexModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv_transpose = nn.ConvTranspose3d(
                16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2), dilation=(1, 2, 1)
            )
            # Multiplied element-wise with the convolution output in forward().
            self.param = nn.Parameter(torch.randn(20, 33, 21, 50, 97))

        def forward(self, x):
            return self.conv_transpose(x) * self.param

    def run_step(model, x):
        # Backpropagate the summed output and collect input/weight/bias gradients.
        loss = model(x).sum()
        loss.backward()

        return (
            x.grad,
            model.conv_transpose.weight.grad,
            model.conv_transpose.bias.grad,
        )

    env_key = "ORTMODULE_CONV_ALGO_SEARCH"
    previous_value = os.environ.get(env_key)
    if conv_algo_search is not None:
        os.environ[env_key] = conv_algo_search

    try:
        device = "cuda"
        pt_model = ConvTransposeComplexModel().to(device)
        ort_model = ORTModule(copy.deepcopy(pt_model)).to(device)

        pt_x = torch.randn(20, 16, 10, 50, 100, requires_grad=True, device=device)
        ort_x = copy.deepcopy(pt_x)

        pt_grads = run_step(pt_model, pt_x)
        ort_grads = run_step(ort_model, ort_x)

        for pt_grad, ort_grad in zip(pt_grads, ort_grads):
            assert torch.allclose(pt_grad, ort_grad, atol=1e-2, rtol=1e-2)
    finally:
        # Undo only what this test changed, even when an assertion above failed.
        if conv_algo_search is not None:
            if previous_value is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_value
|
||||
|
|
|
|||
|
|
@ -0,0 +1,360 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime::contrib::test {
|
||||
|
||||
using namespace onnxruntime::test;
|
||||
|
||||
#if USE_CUDA
|
||||
namespace {
|
||||
|
||||
struct ConvTransposeGradOpAttributes {
|
||||
std::vector<int64_t> dilations;
|
||||
int64_t group;
|
||||
std::vector<int64_t> kernel_shape;
|
||||
std::vector<int64_t> pads;
|
||||
std::vector<int64_t> strides;
|
||||
};
|
||||
|
||||
void TestConvTransposeGradOp(const ConvTransposeGradOpAttributes& attributes,
|
||||
const std::vector<std::vector<float>>& inputs,
|
||||
const std::vector<std::vector<int64_t>>& input_shapes,
|
||||
const std::vector<std::vector<float>>& outputs,
|
||||
const std::vector<std::vector<int64_t>>& output_shapes,
|
||||
bool is_half = false) {
|
||||
OpTester test("ConvTransposeGrad", 1, kMSDomain);
|
||||
test.AddAttribute("group", attributes.group);
|
||||
test.AddAttribute("kernel_shape", attributes.kernel_shape);
|
||||
test.AddAttribute("pads", attributes.pads);
|
||||
|
||||
if (!attributes.dilations.empty()) {
|
||||
test.AddAttribute("dilations", attributes.dilations);
|
||||
}
|
||||
|
||||
if (!attributes.strides.empty()) {
|
||||
test.AddAttribute("strides", attributes.strides);
|
||||
}
|
||||
|
||||
if (is_half) {
|
||||
std::vector<MLFloat16> dY_half(inputs[0].size());
|
||||
ConvertFloatToMLFloat16(inputs[0].data(), dY_half.data(), static_cast<int>(inputs[0].size()));
|
||||
test.AddInput<MLFloat16>("dY", input_shapes[0], dY_half);
|
||||
|
||||
std::vector<MLFloat16> X_half(inputs[1].size());
|
||||
ConvertFloatToMLFloat16(inputs[1].data(), X_half.data(), static_cast<int>(inputs[1].size()));
|
||||
test.AddInput<MLFloat16>("X", input_shapes[1], X_half);
|
||||
|
||||
std::vector<MLFloat16> W_half(inputs[2].size());
|
||||
ConvertFloatToMLFloat16(inputs[2].data(), W_half.data(), static_cast<int>(inputs[2].size()));
|
||||
test.AddInput<MLFloat16>("W", input_shapes[2], W_half);
|
||||
|
||||
std::vector<MLFloat16> dX_half(outputs[0].size());
|
||||
ConvertFloatToMLFloat16(outputs[0].data(), dX_half.data(), static_cast<int>(outputs[0].size()));
|
||||
test.AddOutput<MLFloat16>("dX", output_shapes[0], dX_half);
|
||||
|
||||
std::vector<MLFloat16> dW_half(outputs[1].size());
|
||||
ConvertFloatToMLFloat16(outputs[1].data(), dW_half.data(), static_cast<int>(outputs[1].size()));
|
||||
test.AddOutput<MLFloat16>("dW", output_shapes[1], dW_half);
|
||||
|
||||
if (outputs.size() >= 3) {
|
||||
std::vector<MLFloat16> dB_half(outputs[2].size());
|
||||
ConvertFloatToMLFloat16(outputs[2].data(), dB_half.data(), static_cast<int>(outputs[2].size()));
|
||||
test.AddOutput<MLFloat16>("dB", output_shapes[2], dB_half);
|
||||
}
|
||||
} else {
|
||||
test.AddInput<float>("dY", input_shapes[0], inputs[0]);
|
||||
test.AddInput<float>("X", input_shapes[1], inputs[1]);
|
||||
test.AddInput<float>("W", input_shapes[2], inputs[2]);
|
||||
|
||||
test.AddOutput<float>("dX", output_shapes[0], outputs[0]);
|
||||
test.AddOutput<float>("dW", output_shapes[1], outputs[1]);
|
||||
|
||||
if (outputs.size() >= 3) {
|
||||
test.AddOutput<float>("dB", output_shapes[2], outputs[2]);
|
||||
}
|
||||
}
|
||||
|
||||
test.Run();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// 1D ConvTransposeGrad with unit dilation/stride, no padding, a single group
// and an all-ones dY. Expected dX/dW/dB are hard-coded reference values; the
// case is run once with float tensors and once with MLFloat16 tensors.
TEST(ConvTransposeGradTest, ConvTranspose1DDefaultAttributes) {
  ConvTransposeGradOpAttributes op_attrs{};
  op_attrs.dilations = {1};
  op_attrs.group = 1;
  op_attrs.kernel_shape = {2};
  op_attrs.pads = {0, 0};
  op_attrs.strides = {1};

  // Shapes.
  const std::vector<int64_t> dY_shape = {1, 2, 6};
  const std::vector<int64_t> X_shape = {1, 2, 5};
  const std::vector<int64_t> W_shape = {2, 2, 2};
  const std::vector<int64_t> dX_shape = X_shape;
  const std::vector<int64_t> dW_shape = W_shape;
  const std::vector<int64_t> dB_shape = {2};

  // Inputs.
  const std::vector<float> dY(12, 1.0f);
  const std::vector<float> X = {0.1868f, -0.1679f, 1.2677f, 2.1288f, -0.0331f,
                                1.0454f, 0.7722f, 0.2963f, -0.8684f, -0.0547f};
  const std::vector<float> W = {0.0847f, -0.0066f,
                                0.1212f, 0.2317f,
                                -0.4975f, 0.2762f,
                                -0.2644f, 0.3210f};

  // Expected gradients.
  const std::vector<float> dX = {0.4309f, 0.4309f, 0.4309f, 0.4309f, 0.4309f,
                                 -0.1647f, -0.1647f, -0.1647f, -0.1647f, -0.1647f};
  const std::vector<float> dW = {3.3823f, 3.3823f,
                                 3.3823f, 3.3823f,
                                 1.1908f, 1.1908f,
                                 1.1908f, 1.1908f};
  const std::vector<float> dB = {6.f, 6.f};

  const std::vector<std::vector<float>> input_data = {dY, X, W};
  const std::vector<std::vector<int64_t>> input_dims = {dY_shape, X_shape, W_shape};
  const std::vector<std::vector<float>> expected_data = {dX, dW, dB};
  const std::vector<std::vector<int64_t>> expected_dims = {dX_shape, dW_shape, dB_shape};

  for (const bool use_half : {false, true}) {
    TestConvTransposeGradOp(op_attrs, input_data, input_dims, expected_data, expected_dims, use_half);
  }
}
|
||||
|
||||
// 1D ConvTransposeGrad with stride 2 and padding 2 on both ends, a single
// group and an all-ones dY. Expected dX/dW/dB are hard-coded reference values;
// the case is run once with float tensors and once with MLFloat16 tensors.
TEST(ConvTransposeGradTest, ConvTranspose1DStrideAndPadding) {
  ConvTransposeGradOpAttributes op_attrs{};
  op_attrs.dilations = {1};
  op_attrs.group = 1;
  op_attrs.kernel_shape = {2};
  op_attrs.pads = {2, 2};
  op_attrs.strides = {2};

  // Shapes.
  const std::vector<int64_t> dY_shape = {1, 2, 6};
  const std::vector<int64_t> X_shape = {1, 2, 5};
  const std::vector<int64_t> W_shape = {2, 2, 2};
  const std::vector<int64_t> dX_shape = X_shape;
  const std::vector<int64_t> dW_shape = W_shape;
  const std::vector<int64_t> dB_shape = {2};

  // Inputs.
  const std::vector<float> dY(12, 1.0f);
  const std::vector<float> X = {-0.0254f, -1.4303f, -0.1568f, 1.2318f, -0.8365f,
                                2.0836f, -1.0181f, -0.7539f, 0.4484f, -0.5799f};
  const std::vector<float> W = {-0.1438f, 0.2386f,
                                -0.3085f, 0.1149f,
                                -0.1653f, -0.0707f,
                                -0.1479f, -0.0918f};

  // Expected gradients.
  const std::vector<float> dX = {0.0000f, -0.0988f, -0.0988f, -0.0988f, 0.0000f,
                                 0.0000f, -0.4757f, -0.4757f, -0.4757f, 0.0000f};
  const std::vector<float> dW = {-0.3553f, -0.3553f,
                                 -0.3553f, -0.3553f,
                                 -1.3236f, -1.3236f,
                                 -1.3236f, -1.3236f};
  const std::vector<float> dB = {6.f, 6.f};

  const std::vector<std::vector<float>> input_data = {dY, X, W};
  const std::vector<std::vector<int64_t>> input_dims = {dY_shape, X_shape, W_shape};
  const std::vector<std::vector<float>> expected_data = {dX, dW, dB};
  const std::vector<std::vector<int64_t>> expected_dims = {dX_shape, dW_shape, dB_shape};

  for (const bool use_half : {false, true}) {
    TestConvTransposeGradOp(op_attrs, input_data, input_dims, expected_data, expected_dims, use_half);
  }
}
|
||||
|
||||
// 1D ConvTransposeGrad with dilation 2, stride 2, padding 2 on both ends,
// two groups and an all-ones dY. Expected dX/dW/dB are hard-coded reference
// values; the case is run once with float tensors and once with MLFloat16.
TEST(ConvTransposeGradTest, ConvTranspose1D) {
  ConvTransposeGradOpAttributes op_attrs{};
  op_attrs.dilations = {2};
  op_attrs.group = 2;
  op_attrs.kernel_shape = {3};
  op_attrs.pads = {2, 2};
  op_attrs.strides = {2};

  // Shapes.
  const std::vector<int64_t> dY_shape = {1, 2, 19};
  const std::vector<int64_t> X_shape = {1, 2, 10};
  const std::vector<int64_t> W_shape = {2, 1, 3};
  const std::vector<int64_t> dX_shape = X_shape;
  const std::vector<int64_t> dW_shape = W_shape;
  const std::vector<int64_t> dB_shape = {2};

  // Inputs.
  const std::vector<float> dY(38, 1.0f);
  const std::vector<float> X = {
      0.2816f, 1.4660f, 0.1002f, -0.2460f, -0.1027f, 0.1228f, -0.8516f, -1.0246f, -0.6576f, -1.0280f,
      0.1093f, 0.1447f, 1.1279f, 0.1085f, -0.3438f, -0.6224f, -0.0902f, 2.2791f, -2.1910f, 1.9736f};
  const std::vector<float> W = {-0.1050f, -0.0622f, -0.3632f,
                                -0.3861f, -0.0134f, -0.0277f};

  // Expected gradients.
  const std::vector<float> dX = {
      -0.4254f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.1672f,
      -0.0411f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.3995f};
  const std::vector<float> dW = {-2.2215f, -1.9400f, -0.9120f,
                                 2.3863f, 2.4956f, 0.5220f};
  const std::vector<float> dB = {19.f, 19.f};

  const std::vector<std::vector<float>> input_data = {dY, X, W};
  const std::vector<std::vector<int64_t>> input_dims = {dY_shape, X_shape, W_shape};
  const std::vector<std::vector<float>> expected_data = {dX, dW, dB};
  const std::vector<std::vector<int64_t>> expected_dims = {dX_shape, dW_shape, dB_shape};

  for (const bool use_half : {false, true}) {
    TestConvTransposeGradOp(op_attrs, input_data, input_dims, expected_data, expected_dims, use_half);
  }
}
|
||||
|
||||
// 2D ConvTransposeGrad with unit dilation/stride, no padding, a single group
// and an all-ones dY. Expected dX/dW/dB are hard-coded reference values; the
// case is run once with float tensors and once with MLFloat16 tensors.
TEST(ConvTransposeGradTest, ConvTranspose2DDefaultAttributes) {
  ConvTransposeGradOpAttributes op_attrs{};
  op_attrs.dilations = {1, 1};
  op_attrs.group = 1;
  op_attrs.kernel_shape = {3, 3};
  op_attrs.pads = {0, 0, 0, 0};
  op_attrs.strides = {1, 1};

  // Shapes.
  const std::vector<int64_t> dY_shape = {1, 2, 7, 7};
  const std::vector<int64_t> X_shape = {1, 2, 5, 5};
  const std::vector<int64_t> W_shape = {2, 2, 3, 3};
  const std::vector<int64_t> dX_shape = X_shape;
  const std::vector<int64_t> dW_shape = W_shape;
  const std::vector<int64_t> dB_shape = {2};

  // Inputs.
  const std::vector<float> dY(98, 1.0f);
  const std::vector<float> X = {
      1.1371f, -0.1498f, -1.7541f, -0.7585f, 1.6009f, -0.7496f, 0.1535f, -0.2533f, -1.0811f, 0.9760f,
      -0.2528f, 0.1820f, -1.7450f, 0.1632f, -0.3469f, 1.1150f, -2.6888f, -0.1632f, -0.3269f, 0.6904f,
      1.3036f, 0.7883f, 0.4459f, 0.1223f, 0.1576f, -0.8187f, 0.2281f, 1.5320f, 1.2643f, -0.5163f,
      1.0677f, -0.2141f, 1.2992f, -2.1865f, -0.6346f, 0.8938f, 0.8346f, -2.7397f, 0.9223f, 0.8166f,
      1.1736f, -1.3644f, 0.0316f, -1.2904f, 0.7062f, 0.2470f, 0.4559f, 0.8493f, 1.0519f, 0.9915f};
  const std::vector<float> W = {
      0.0761f, 0.0270f, -0.1677f, 0.1803f, -0.0824f, -0.0285f,
      0.2098f, -0.0569f, -0.1514f, 0.0338f, -0.1962f, -0.2169f,
      0.0432f, -0.1977f, -0.0814f, -0.1866f, -0.1574f, -0.0198f,
      0.0097f, 0.0019f, -0.1204f, 0.2018f, -0.1750f, -0.0549f,
      -0.0687f, -0.1269f, 0.1913f, 0.1331f, -0.0632f, 0.0821f,
      0.0127f, 0.1761f, -0.0883f, -0.1370f, 0.1472f, 0.0690f};

  // Expected gradients.
  const std::vector<float> dX = {
      -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f,
      -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f,
      -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f,
      0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f,
      0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f};
  const std::vector<float> dW = {
      -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f,
      -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f,
      -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f,
      4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f,
      4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f,
      4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f};
  const std::vector<float> dB = {49.f, 49.f};

  const std::vector<std::vector<float>> input_data = {dY, X, W};
  const std::vector<std::vector<int64_t>> input_dims = {dY_shape, X_shape, W_shape};
  const std::vector<std::vector<float>> expected_data = {dX, dW, dB};
  const std::vector<std::vector<int64_t>> expected_dims = {dX_shape, dW_shape, dB_shape};

  for (const bool use_half : {false, true}) {
    TestConvTransposeGradOp(op_attrs, input_data, input_dims, expected_data, expected_dims, use_half);
  }
}
|
||||
|
||||
// 2D ConvTransposeGrad with dilation 2, stride 2, padding 2 on every side,
// two groups and an all-ones dY. Expected dX/dW/dB are hard-coded reference
// values; the case is run once with float tensors and once with MLFloat16.
TEST(ConvTransposeGradTest, ConvTranspose2D) {
  ConvTransposeGradOpAttributes op_attrs{};
  op_attrs.dilations = {2, 2};
  op_attrs.group = 2;
  op_attrs.kernel_shape = {3, 3};
  op_attrs.pads = {2, 2, 2, 2};
  op_attrs.strides = {2, 2};

  // Shapes.
  const std::vector<int64_t> dY_shape = {1, 2, 9, 9};
  const std::vector<int64_t> X_shape = {1, 2, 5, 5};
  const std::vector<int64_t> W_shape = {2, 1, 3, 3};
  const std::vector<int64_t> dX_shape = X_shape;
  const std::vector<int64_t> dW_shape = W_shape;
  const std::vector<int64_t> dB_shape = {2};

  // Inputs.
  const std::vector<float> dY(162U, 1.0f);
  const std::vector<float> X = {
      -1.0158f, 0.1709f, -0.1660f, 0.3881f, 0.4017f, 1.5497f, 1.1205f, 0.2553f, -0.4359f, -0.0467f,
      1.1374f, -0.0713f, 0.2248f, 0.8915f, -0.7239f, 0.1679f, -1.5604f, -0.8521f, 0.8966f, 3.3743f,
      -0.5516f, 0.2516f, -0.4091f, -0.9868f, 0.3008f, 1.1066f, -0.7039f, -1.5273f, -0.3666f, 0.9392f,
      0.1264f, -1.6604f, -1.4810f, 0.6654f, -0.2007f, -1.0660f, -0.5420f, -0.7030f, 0.0411f, 2.1082f,
      -0.7995f, 0.2422f, 1.2848f, -0.1747f, 1.7935f, -0.1123f, -0.6668f, -2.2383f, 1.5419f, -2.7614f};
  const std::vector<float> W = {
      -0.2057f, -0.0411f, 0.0277f, 0.2221f, 0.1901f, 0.1435f,
      -0.2249f, 0.3299f, -0.2203f, -0.1013f, -0.3326f, 0.1005f,
      -0.0536f, 0.3067f, 0.3297f, 0.2728f, 0.1649f, -0.2548f};

  // Expected gradients.
  const std::vector<float> dX = {
      0.4431f, 0.4403f, 0.4403f, 0.4403f, 0.5171f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f,
      0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f,
      0.3202f, 0.3366f, 0.3366f, 0.3366f, 0.1654f, 0.5465f, 0.7658f, 0.7658f, 0.7658f, 0.6908f,
      0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f,
      0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.4043f, 0.2494f, 0.2494f, 0.2494f, -0.1808f};
  const std::vector<float> dW = {
      2.2293f, 4.5327f, 1.6281f, 3.0240f, 4.3115f, 1.0052f,
      3.8675f, 5.7067f, 2.7011f, -2.7512f, -4.6026f, -5.5423f,
      -4.4098f, -5.1546f, -7.0335f, -0.2852f, -0.9177f, -5.5580f};
  const std::vector<float> dB = {81.f, 81.f};

  const std::vector<std::vector<float>> input_data = {dY, X, W};
  const std::vector<std::vector<int64_t>> input_dims = {dY_shape, X_shape, W_shape};
  const std::vector<std::vector<float>> expected_data = {dX, dW, dB};
  const std::vector<std::vector<int64_t>> expected_dims = {dX_shape, dW_shape, dB_shape};

  for (const bool use_half : {false, true}) {
    TestConvTransposeGradOp(op_attrs, input_data, input_dims, expected_data, expected_dims, use_half);
  }
}
|
||||
|
||||
// Checks ConvTransposeGrad on 5-D tensors (dY {1, 2, 5, 5, 5}, X {1, 2, 4, 4, 4}) with group = 2.
// The hard-coded dX / dW / dB vectors are the expected outputs; the helper is run once with
// is_half = false and once with is_half = true.
TEST(ConvTransposeGradTest, ConvTranspose3D) {
  ConvTransposeGradOpAttributes attrs = {
      std::vector<int64_t>{2, 2, 2},           // dilations
      2,                                       // group
      std::vector<int64_t>{2, 2, 2},           // kernel_shape
      std::vector<int64_t>{2, 2, 2, 2, 2, 2},  // pads
      std::vector<int64_t>{2, 2, 2},           // strides
  };

  // Upstream gradient: 250 ones, i.e. 1 * 2 * 5 * 5 * 5 elements.
  std::vector<float> dY(250U, 1.0f);
  std::vector<int64_t> dY_shape = {1, 2, 5, 5, 5};

  std::vector<float> X = {-0.2396f, 0.4280f, -1.3505f, -0.4366f, -1.3296f, 0.3531f, 0.0645f, -1.5480f,
                          -1.7464f, -0.9160f, 1.5065f, -0.0788f, 0.0487f, 2.4641f, 0.3855f, 2.0499f,
                          0.7068f, -0.8076f, -0.4442f, 0.1003f, -0.5056f, -0.1430f, -0.3744f, -0.2637f,
                          -1.1012f, 1.0213f, 0.0503f, 0.0147f, -0.3664f, 0.8834f, -1.1478f, -0.8221f,
                          -0.5649f, -0.4224f, -0.6779f, -0.9363f, 1.1972f, 0.2094f, 0.5676f, -0.2718f,
                          -0.1678f, -0.4178f, -0.4672f, 0.2777f, -0.7953f, -0.5603f, -2.8694f, 1.5743f,
                          -0.5057f, -0.2529f, 0.5894f, -0.3980f, -0.6719f, -0.3425f, 0.0821f, 0.8672f,
                          0.7218f, 1.5519f, 1.6513f, -1.1956f, 0.8471f, 0.4295f, -1.3917f, -1.2202f,
                          0.1054f, -2.2191f, -0.9546f, 1.1750f, -2.3637f, 1.6297f, -0.5796f, 0.3850f,
                          0.9287f, -0.3492f, -0.7284f, 0.2987f, -0.7534f, 0.7747f, -1.3198f, -0.3633f,
                          1.8635f, -0.3187f, 0.9032f, -0.6083f, -0.4236f, -0.1929f, -1.1715f, -0.5591f,
                          -1.8290f, -1.1503f, 0.1430f, 0.6048f, -0.3148f, 1.0638f, -0.2946f, -0.4990f,
                          -1.4443f, -0.7757f, -1.5374f, -0.4567f, -0.2998f, 0.0521f, 1.6293f, -0.6720f,
                          -0.0102f, -0.6598f, 0.5005f, 0.4203f, 1.3911f, 1.5988f, 0.3991f, 1.4931f,
                          0.9741f, 0.3557f, 0.1088f, -1.1806f, 1.1115f, -1.3283f, 1.7235f, 0.4177f,
                          0.7992f, -1.7248f, -0.5339f, -0.3153f, 0.1379f, 0.7493f, 0.3028f, -0.9473f};
  std::vector<int64_t> X_shape = {1, 2, 4, 4, 4};

  std::vector<float> W = {-0.1093f, -0.0511f, 0.1132f, 0.3369f, -0.3531f, -0.1766f, 0.0628f, 0.2118f,
                          0.3068f, 0.3217f, -0.2903f, -0.1633f, -0.3261f, -0.0990f, 0.2497f, -0.1553f};
  std::vector<int64_t> W_shape = {2, 1, 2, 2, 2};

  // Expected gradient with respect to X (same shape as X).
  std::vector<float> dX = {0.2118f, 0.2746f, 0.2746f, 0.0628f, 0.0352f, -0.2550f, -0.2550f, -0.2902f,
                           0.0352f, -0.2550f, -0.2550f, -0.2902f, -0.1766f, -0.5297f, -0.5297f, -0.3531f,
                           0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f,
                           0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f,
                           0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f,
                           0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f,
                           0.3369f, 0.4501f, 0.4501f, 0.1132f, 0.2858f, 0.2897f, 0.2897f, 0.0038f,
                           0.2858f, 0.2897f, 0.2897f, 0.0038f, -0.0511f, -0.1604f, -0.1604f, -0.1093f,
                           -0.1553f, 0.0944f, 0.0944f, 0.2497f, -0.2542f, -0.3307f, -0.3307f, -0.0765f,
                           -0.2542f, -0.3307f, -0.3307f, -0.0765f, -0.0990f, -0.4251f, -0.4251f, -0.3261f,
                           -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f,
                           -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f,
                           -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f,
                           -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f,
                           -0.1633f, -0.4536f, -0.4536f, -0.2903f, 0.1584f, 0.1749f, 0.1749f, 0.0165f,
                           0.1584f, 0.1749f, 0.1749f, 0.0165f, 0.3217f, 0.6285f, 0.6285f, 0.3068f};
  std::vector<int64_t> dX_shape = X_shape;

  // Expected gradient with respect to W (same shape as W).
  std::vector<float> dW = {-2.3068f, -2.1096f, -0.4322f, 0.4820f, 1.5420f, -4.1569f, -4.9628f, -5.5716f,
                           1.0492f, 1.6683f, -6.3262f, -3.2359f, 2.4532f, -2.3299f, -5.1917f, -9.2525f};
  std::vector<int64_t> dW_shape = W_shape;

  // Expected bias gradient: 125 per channel, consistent with dY holding 5 * 5 * 5 ones per channel.
  std::vector<float> dB = {125.f, 125.f};
  std::vector<int64_t> dB_shape = {2};

  for (const bool is_half : {false, true}) {
    TestConvTransposeGradOp(
        attrs,                            // attributes
        {dY, X, W},                       // inputs
        {dY_shape, X_shape, W_shape},     // input shapes
        {dX, dW, dB},                     // outputs
        {dX_shape, dW_shape, dB_shape},   // output shapes
        is_half);
  }
}
|
||||
#endif // USE_CUDA
|
||||
|
||||
} // namespace onnxruntime::contrib::test
|
||||
|
|
@ -85,6 +85,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
// Forward declarations of CUDA training kernel classes in the kMSDomain domain, opset version 1.
// Typed kernels get one declaration per element type.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad);
|
||||
// ConvTransposeGrad is declared for the same three element types as ConvGrad: float, double, MLFloat16.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvTransposeGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvTransposeGrad);
|
||||
// Untyped kernels: a single declaration each.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropoutGrad);
|
||||
|
|
@ -346,6 +349,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvTransposeGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvTransposeGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DivGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, DivGrad)>,
|
||||
|
|
|
|||
|
|
@ -3,13 +3,6 @@
|
|||
|
||||
#include "orttraining/training_ops/cuda/nn/conv_grad.h"
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
|
||||
// The AlgoPerfCache and AlgoSearch used here for Conv/ConvGrad are modeled on PyTorch's implementation
|
||||
// in aten/src/ATen/native/cudnn/Conv_v7.cpp.
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
|
|
@ -22,229 +15,6 @@ REGISTER_GRADIENT_KERNEL_TYPED(float)
|
|||
REGISTER_GRADIENT_KERNEL_TYPED(double)
|
||||
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
// Short aliases for the cuDNN backward-data / backward-filter algorithm enums and their
// per-algorithm performance-result structs, used by the helpers in this file.
using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t;
|
||||
using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t;
|
||||
using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t;
|
||||
using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t;
|
||||
|
||||
// Asks cuDNN how much workspace the given backward-data algorithm needs for the
// descriptors held in `args`. The cuDNN status is returned unchanged; the size is
// written through `workspace_size`.
cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) {
  const cudnnStatus_t status = cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle, args.w_desc, args.y_tensor, args.conv_desc, args.x_tensor, algo, workspace_size);
  return status;
}

// Same query for a backward-filter algorithm.
cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) {
  const cudnnStatus_t status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle, args.x_tensor, args.y_tensor, args.conv_desc, args.w_desc, algo, workspace_size);
  return status;
}
|
||||
|
||||
template <typename T_Algo>
|
||||
size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) {
|
||||
// Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info.
|
||||
size_t free, total;
|
||||
CUDA_CALL_THROW(cudaMemGetInfo(&free, &total));
|
||||
// Assuming 10% of fragmentation.
|
||||
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
|
||||
size_t max_workspace_size = 0;
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
cudnnStatus_t status;
|
||||
size_t workspace_size;
|
||||
status = GetWorkspaceSize(args, algo[i], &workspace_size);
|
||||
if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size ||
|
||||
workspace_size > free)
|
||||
continue;
|
||||
max_workspace_size = workspace_size;
|
||||
}
|
||||
|
||||
return max_workspace_size;
|
||||
}
|
||||
|
||||
template <typename T_Perf>
|
||||
std::vector<T_Perf> GetValidAlgorithms(const T_Perf* perf_results, int n_algo) {
|
||||
std::vector<T_Perf> result;
|
||||
result.reserve(n_algo);
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
T_Perf perf = perf_results[i];
|
||||
if (perf.status == CUDNN_STATUS_SUCCESS) {
|
||||
result.emplace_back(perf);
|
||||
}
|
||||
}
|
||||
ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN");
|
||||
// TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups
|
||||
// when cuDNN version < 7.5. Need to add handling for such special case.
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ConvParamsHash {
|
||||
// ConvParams must be a POD because we read out its memory constant as char* when hashing.
|
||||
static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");
|
||||
size_t operator()(const ConvParams& conv_params) const {
|
||||
auto ptr = reinterpret_cast<const uint8_t*>(&conv_params);
|
||||
uint32_t value = 0x811C9DC5;
|
||||
for (int i = 0; i < static_cast<int>(sizeof(ConvParams)); ++i) {
|
||||
value ^= ptr[i];
|
||||
value *= 0x01000193;
|
||||
}
|
||||
return static_cast<size_t>(value);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvParamsEqual {
|
||||
// ConvParams must be a POD because we read out its memory constant as char* when hashing.
|
||||
static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");
|
||||
bool operator()(const ConvParams& a, const ConvParams& b) const {
|
||||
auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
|
||||
auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
|
||||
return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T_Perf>
|
||||
struct AlgoPerfCache {
|
||||
mutable OrtMutex mutex;
|
||||
std::unordered_map<ConvParams, T_Perf, ConvParamsHash, ConvParamsEqual> map;
|
||||
|
||||
bool Find(const ConvParams& params, T_Perf* result) {
|
||||
std::lock_guard<OrtMutex> guard(mutex);
|
||||
auto it = map.find(params);
|
||||
if (it == map.end()) {
|
||||
return false;
|
||||
}
|
||||
*result = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Insert(const ConvParams& params, const T_Perf& algo_perf) {
|
||||
std::lock_guard<OrtMutex> guard(mutex);
|
||||
map[params] = algo_perf;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Currently we use a global AlgoPerfCache for ConvGrad only. Conv's perf cache is still per node.
|
||||
// We need to apply such a global cache to Conv as well, and move some shared code from here to conv.h/cc.
|
||||
// One cache per algorithm family; each is keyed by ConvParams and guarded by its own mutex.
AlgoPerfCache<T_BwdDataPerf> bwd_data_algos;
|
||||
AlgoPerfCache<T_BwdFilterPerf> bwd_filter_algos;
|
||||
|
||||
// Primary template is intentionally empty. The specializations on a perf-result type provide
// DEFAULT_ALGO, Cache() and FindAlgorithms().
template <typename T_Perf>
|
||||
struct AlgoSearch {};
|
||||
|
||||
template <>
|
||||
struct AlgoSearch<T_BwdDataPerf> {
|
||||
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
static AlgoPerfCache<T_BwdDataPerf>& Cache() { return bwd_data_algos; }
|
||||
static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::vector<T_BwdDataPerf>& perf_results) {
|
||||
static const T_BwdDataAlgo algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||
ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms.");
|
||||
int perf_count;
|
||||
std::unique_ptr<T_BwdDataPerf[]> candidates = std::make_unique<T_BwdDataPerf[]>(num_algos);
|
||||
if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor,
|
||||
args.conv_desc, args.x_tensor, num_algos,
|
||||
&perf_count, candidates.get()));
|
||||
} else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) {
|
||||
size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
|
||||
: AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr<void>(allocator, max_workspace_size, true);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx(
|
||||
args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor,
|
||||
args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size));
|
||||
} else {
|
||||
ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode);
|
||||
}
|
||||
perf_results = GetValidAlgorithms<T_BwdDataPerf>(candidates.get(), perf_count);
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AlgoSearch<T_BwdFilterPerf> {
|
||||
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
||||
static AlgoPerfCache<T_BwdFilterPerf>& Cache() { return bwd_filter_algos; }
|
||||
static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::vector<T_BwdFilterPerf>& perf_results) {
|
||||
static const T_BwdFilterAlgo algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
|
||||
};
|
||||
|
||||
// NOTE: - 1 because ALGO_WINOGRAD is not implemented.
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
|
||||
ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms.");
|
||||
std::unique_ptr<T_BwdFilterPerf[]> candidates = std::make_unique<T_BwdFilterPerf[]>(num_algos);
|
||||
int perf_count;
|
||||
if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor,
|
||||
args.conv_desc, args.w_desc, num_algos,
|
||||
&perf_count, candidates.get()));
|
||||
} else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) {
|
||||
size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
|
||||
: AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr<void>(allocator, max_workspace_size, true);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
||||
args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc,
|
||||
args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size));
|
||||
} else {
|
||||
ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode);
|
||||
}
|
||||
perf_results = GetValidAlgorithms<T_BwdFilterPerf>(candidates.get(), perf_count);
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T_Perf>
|
||||
class AlgoIterator {
|
||||
public:
|
||||
AlgoIterator(const ConvArgs& args) : args_(args) {}
|
||||
|
||||
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
|
||||
perf_results.resize(1);
|
||||
perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
|
||||
if (args.params.data_type == CUDNN_DATA_HALF) {
|
||||
perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else {
|
||||
perf_results[0].mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function<Status(const T_Perf& perf)> f) {
|
||||
auto& cache = AlgoSearch<T_Perf>::Cache();
|
||||
|
||||
if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<T_Perf> perf_results;
|
||||
ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
|
||||
? OnlyDefaultAlgorithm(args_, perf_results)
|
||||
: AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
|
||||
for (auto& algo_perf : perf_results) {
|
||||
if (f(algo_perf) == Status::OK()) {
|
||||
cache.Insert(args_.params, algo_perf);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const ConvArgs& args_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status ConvGrad<T>::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX,
|
||||
Tensor* dW, cudnnHandle_t cudnn_handle) const {
|
||||
|
|
|
|||
|
|
@ -3,47 +3,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cpu/nn/conv_attributes.h"
|
||||
#include "core/providers/cuda/nn/conv.h"
|
||||
#include "orttraining/training_ops/cuda/nn/conv_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
// cuDNN only takes 4D or 5D x tensor.
|
||||
// Upper bound on the number of per-dimension entries stored in ConvParams.
static constexpr int MAX_DIM = 3;
|
||||
|
||||
// Plain-data description of one convolution configuration.
// ConvParamsHash and ConvParamsEqual static_assert that this type is POD and read its raw bytes,
// so every byte of the object, including any padding between members, takes part in hashing
// and comparison.
// NOTE(review): presumably instances are zero-filled before the fields are set - verify at the call sites.
struct ConvParams {
|
||||
int8_t device_id;
|
||||
cudnnDataType_t data_type;
|
||||
// 2 + MAX_DIM slots; presumably batch and channel plus the spatial dims - confirm against the code that fills it.
int input_size[2 + MAX_DIM];
|
||||
uint8_t input_dim;
|
||||
// Same 2 + MAX_DIM layout as input_size.
int weight_size[2 + MAX_DIM];
|
||||
// Two entries per dimension.
int padding[MAX_DIM * 2];
|
||||
int stride[MAX_DIM];
|
||||
int dilation[MAX_DIM];
|
||||
int64_t groups;
|
||||
// Compared against the OrtCudnnConvAlgoSearch* constants by the algorithm-search code.
int algo_mode;
|
||||
};
|
||||
|
||||
// Bundle of cuDNN handle, descriptors, parameters and raw data pointers that the workspace-size
// and algorithm-search helpers take as a single argument.
struct ConvArgs {
|
||||
// Update needed if x or w's dims changed.
|
||||
TensorShapeVector last_x_dims;
|
||||
TensorShapeVector last_w_dims;
|
||||
|
||||
cudnnHandle_t handle;
|
||||
// Used as the key of the algorithm perf caches.
ConvParams params;
|
||||
CudnnTensor x_tensor, y_tensor, b_tensor;
|
||||
CudnnFilterDescriptor w_desc;
|
||||
CudnnConvolutionDescriptor conv_desc;
|
||||
// Read-only buffers.
const void* x_data;
|
||||
const void* w_data;
|
||||
const void* dy_data;
|
||||
// Writable buffers; dx_data and dw_data are handed to cuDNN's Find...AlgorithmEx calls as outputs.
void* dx_data;
|
||||
void* dw_data;
|
||||
void* db_data;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ConvGrad final : public CudaKernel {
|
||||
public:
|
||||
|
|
|
|||
275
orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
Normal file
275
orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/nn/conv_shared.h"
|
||||
|
||||
#include "core/platform/ort_mutex.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
|
||||
namespace onnxruntime::cuda {
|
||||
|
||||
namespace {
|
||||
|
||||
// Asks cuDNN how much workspace the given backward-data algorithm needs for the
// descriptors held in `args`. The cuDNN status is returned unchanged; the size is
// written through `workspace_size`.
cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) {
  const cudnnStatus_t status = cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle, args.w_desc, args.y_tensor, args.conv_desc, args.x_tensor, algo, workspace_size);
  return status;
}

// Same query for a backward-filter algorithm.
cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) {
  const cudnnStatus_t status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle, args.x_tensor, args.y_tensor, args.conv_desc, args.w_desc, algo, workspace_size);
  return status;
}

// Same query for a forward algorithm.
cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_FwdAlgo algo, size_t* workspace_size) {
  const cudnnStatus_t status = cudnnGetConvolutionForwardWorkspaceSize(
      args.handle, args.x_tensor, args.w_desc, args.conv_desc, args.y_tensor, algo, workspace_size);
  return status;
}
|
||||
|
||||
template <typename T_Algo>
|
||||
size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) {
|
||||
// Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info.
|
||||
size_t free, total;
|
||||
CUDA_CALL_THROW(cudaMemGetInfo(&free, &total));
|
||||
// Assuming 10% of fragmentation.
|
||||
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
|
||||
size_t max_workspace_size = 0;
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
cudnnStatus_t status;
|
||||
size_t workspace_size;
|
||||
status = GetWorkspaceSize(args, algo[i], &workspace_size);
|
||||
if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size ||
|
||||
workspace_size > free)
|
||||
continue;
|
||||
max_workspace_size = workspace_size;
|
||||
}
|
||||
|
||||
return max_workspace_size;
|
||||
}
|
||||
|
||||
template <typename T_Perf>
|
||||
std::vector<T_Perf> GetValidAlgorithms(const T_Perf* perf_results, int n_algo) {
|
||||
std::vector<T_Perf> result;
|
||||
result.reserve(n_algo);
|
||||
for (int i = 0; i < n_algo; i++) {
|
||||
T_Perf perf = perf_results[i];
|
||||
if (perf.status == CUDNN_STATUS_SUCCESS) {
|
||||
result.emplace_back(perf);
|
||||
}
|
||||
}
|
||||
ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN");
|
||||
// TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups
|
||||
// when cuDNN version < 7.5. Need to add handling for such special case.
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T_Perf>
|
||||
struct AlgoPerfCache {
|
||||
mutable OrtMutex mutex;
|
||||
std::unordered_map<ConvParams, T_Perf, ConvParamsHash, ConvParamsEqual> map;
|
||||
|
||||
bool Find(const ConvParams& params, T_Perf* result) {
|
||||
std::lock_guard<OrtMutex> guard(mutex);
|
||||
auto it = map.find(params);
|
||||
if (it == map.end()) {
|
||||
return false;
|
||||
}
|
||||
*result = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
void Insert(const ConvParams& params, const T_Perf& algo_perf) {
|
||||
std::lock_guard<OrtMutex> guard(mutex);
|
||||
map[params] = algo_perf;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Currently we use a global AlgoPerfCache for ConvGrad and ConvTransposeGrad only.
|
||||
// Conv's perf cache is still per node.
|
||||
// We need to apply such a global cache to Conv as well, and move some shared code from here to conv.h/cc.
|
||||
// One cache per algorithm family (backward data, backward filter, forward); each is keyed by
// ConvParams and guarded by its own mutex.
AlgoPerfCache<T_BwdDataPerf> bwd_data_algos;
|
||||
AlgoPerfCache<T_BwdFilterPerf> bwd_filter_algos;
|
||||
AlgoPerfCache<T_FwdPerf> fwd_algos;
|
||||
|
||||
// Primary template is intentionally empty. The specializations on a perf-result type provide
// DEFAULT_ALGO, Cache() and FindAlgorithms().
template <typename T_Perf>
|
||||
struct AlgoSearch {};
|
||||
|
||||
template <>
|
||||
struct AlgoSearch<T_BwdDataPerf> {
|
||||
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
static AlgoPerfCache<T_BwdDataPerf>& Cache() { return bwd_data_algos; }
|
||||
static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::vector<T_BwdDataPerf>& perf_results) {
|
||||
static const T_BwdDataAlgo algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||
ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms.");
|
||||
int perf_count;
|
||||
std::unique_ptr<T_BwdDataPerf[]> candidates = std::make_unique<T_BwdDataPerf[]>(num_algos);
|
||||
if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor,
|
||||
args.conv_desc, args.x_tensor, num_algos,
|
||||
&perf_count, candidates.get()));
|
||||
} else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) {
|
||||
size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
|
||||
: AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr<void>(allocator, max_workspace_size, true);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx(
|
||||
args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor,
|
||||
args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size));
|
||||
} else {
|
||||
ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode);
|
||||
}
|
||||
perf_results = GetValidAlgorithms<T_BwdDataPerf>(candidates.get(), perf_count);
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AlgoSearch<T_BwdFilterPerf> {
|
||||
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
||||
static AlgoPerfCache<T_BwdFilterPerf>& Cache() { return bwd_filter_algos; }
|
||||
static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::vector<T_BwdFilterPerf>& perf_results) {
|
||||
static const T_BwdFilterAlgo algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
|
||||
};
|
||||
|
||||
// NOTE: - 1 because ALGO_WINOGRAD is not implemented.
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
|
||||
ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms.");
|
||||
std::unique_ptr<T_BwdFilterPerf[]> candidates = std::make_unique<T_BwdFilterPerf[]>(num_algos);
|
||||
int perf_count;
|
||||
if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor,
|
||||
args.conv_desc, args.w_desc, num_algos,
|
||||
&perf_count, candidates.get()));
|
||||
} else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) {
|
||||
size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
|
||||
: AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr<void>(allocator, max_workspace_size, true);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
||||
args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc,
|
||||
args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size));
|
||||
} else {
|
||||
ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode);
|
||||
}
|
||||
perf_results = GetValidAlgorithms<T_BwdFilterPerf>(candidates.get(), perf_count);
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AlgoSearch<T_FwdPerf> {
|
||||
static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||
static AlgoPerfCache<T_FwdPerf>& Cache() { return fwd_algos; }
|
||||
static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::vector<T_FwdPerf>& perf_results) {
|
||||
static const T_FwdAlgo algos[] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
};
|
||||
|
||||
static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms.");
|
||||
std::unique_ptr<T_FwdPerf[]> candidates = std::make_unique<T_FwdPerf[]>(num_algos);
|
||||
int perf_count;
|
||||
if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) {
|
||||
CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(args.handle, args.x_tensor, args.w_desc,
|
||||
args.conv_desc, args.y_tensor, num_algos,
|
||||
&perf_count, candidates.get()));
|
||||
} else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) {
|
||||
size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos)
|
||||
: AlgoSearchWorkspaceSize;
|
||||
// Use GetTransientScratchBuffer() so the workspace can be freed instead of cached.
|
||||
// Because the benchmarking uses a huge amount of memory, e.g. a few GBs.
|
||||
IAllocatorUniquePtr<void> workspace = max_workspace_size == 0
|
||||
? nullptr
|
||||
: IAllocator::MakeUniquePtr<void>(allocator, max_workspace_size, true);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
|
||||
args.handle, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, args.y_tensor,
|
||||
args.y_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size));
|
||||
} else {
|
||||
ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode);
|
||||
}
|
||||
perf_results = GetValidAlgorithms<T_FwdPerf>(candidates.get(), perf_count);
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Hashes the raw bytes of `conv_params` with a 32-bit FNV-1a style mix
// (XOR the byte in, then multiply by the FNV prime).
size_t ConvParamsHash::operator()(const ConvParams& conv_params) const {
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&conv_params);
  uint32_t hash = 0x811C9DC5;  // 32-bit FNV offset basis
  for (size_t offset = 0; offset < sizeof(ConvParams); ++offset) {
    hash = (hash ^ bytes[offset]) * 0x01000193;  // 32-bit FNV prime
  }
  return static_cast<size_t>(hash);
}
|
||||
|
||||
// Two ConvParams compare equal exactly when their object representations match byte-for-byte.
// This is the equality counterpart of the raw-byte hash used for the algorithm cache.
bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const {
  const void* lhs = reinterpret_cast<const uint8_t*>(&a);
  const void* rhs = reinterpret_cast<const uint8_t*>(&b);
  const int diff = memcmp(lhs, rhs, sizeof(ConvParams));
  return diff == 0;
}
|
||||
|
||||
template <typename T_Perf>
|
||||
Status AlgoIterator<T_Perf>::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results) {
|
||||
perf_results.resize(1);
|
||||
perf_results[0].algo = AlgoSearch<T_Perf>::DEFAULT_ALGO;
|
||||
if (args.params.data_type == CUDNN_DATA_HALF) {
|
||||
perf_results[0].mathType = CUDNN_TENSOR_OP_MATH;
|
||||
} else {
|
||||
perf_results[0].mathType = CUDNN_DEFAULT_MATH;
|
||||
}
|
||||
CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T_Perf>
|
||||
Status AlgoIterator<T_Perf>::TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::function<Status(const T_Perf& perf)> f) {
|
||||
auto& cache = AlgoSearch<T_Perf>::Cache();
|
||||
|
||||
if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<T_Perf> perf_results;
|
||||
ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault
|
||||
? OnlyDefaultAlgorithm(args_, perf_results)
|
||||
: AlgoSearch<T_Perf>::FindAlgorithms(args_, provider, allocator, perf_results));
|
||||
for (auto& algo_perf : perf_results) {
|
||||
if (f(algo_perf) == Status::OK()) {
|
||||
cache.Insert(args_.params, algo_perf);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template class AlgoIterator<T_BwdDataPerf>;
|
||||
template class AlgoIterator<T_BwdFilterPerf>;
|
||||
template class AlgoIterator<T_FwdPerf>;
|
||||
|
||||
} // namespace onnxruntime::cuda
|
||||
84
orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
Normal file
84
orttraining/orttraining/training_ops/cuda/nn/conv_shared.h
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cudnn_common.h"
|
||||
#include "core/providers/cuda/nn/conv.h"
|
||||
|
||||
// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad/ConvTransposeGrad is adapted from PyTorch's implementation
|
||||
// in aten/src/ATen/native/cudnn/Conv_v7.cpp.
|
||||
|
||||
namespace onnxruntime::cuda {
|
||||
|
||||
using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t;
|
||||
using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t;
|
||||
using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t;
|
||||
using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t;
|
||||
using T_FwdAlgo = cudnnConvolutionFwdAlgo_t;
|
||||
using T_FwdPerf = cudnnConvolutionFwdAlgoPerf_t;
|
||||
|
||||
// cuDNN only takes 4D or 5D x tensor.
|
||||
static constexpr int MAX_DIM = 3;
|
||||
|
||||
struct ConvParams {
|
||||
int8_t device_id;
|
||||
cudnnDataType_t data_type;
|
||||
int input_size[2 + MAX_DIM];
|
||||
uint8_t input_dim;
|
||||
int weight_size[2 + MAX_DIM];
|
||||
int padding[MAX_DIM * 2];
|
||||
int stride[MAX_DIM];
|
||||
int dilation[MAX_DIM];
|
||||
int64_t groups;
|
||||
int algo_mode;
|
||||
};
|
||||
|
||||
struct ConvArgs {
|
||||
// Update needed if x or w's dims changed.
|
||||
TensorShapeVector last_x_dims; // Input to the convolution
|
||||
TensorShapeVector last_w_dims; // Weights of the convolution
|
||||
|
||||
cudnnHandle_t handle;
|
||||
ConvParams params;
|
||||
CudnnTensor x_tensor, y_tensor, b_tensor;
|
||||
CudnnFilterDescriptor w_desc;
|
||||
CudnnConvolutionDescriptor conv_desc;
|
||||
const void* x_data;
|
||||
const void* w_data;
|
||||
const void* dy_data;
|
||||
void* y_data;
|
||||
void* dx_data;
|
||||
void* dw_data;
|
||||
void* db_data;
|
||||
};
|
||||
|
||||
struct ConvParamsHash {
|
||||
// ConvParams must be a POD because we read out its memory constant as char* when hashing.
|
||||
static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");
|
||||
|
||||
size_t operator()(const ConvParams& conv_params) const;
|
||||
};
|
||||
|
||||
struct ConvParamsEqual {
|
||||
// ConvParams must be a POD because we read out its memory constant as char* when hashing.
|
||||
static_assert(std::is_pod<ConvParams>::value, "ConvParams is not POD");
|
||||
|
||||
bool operator()(const ConvParams& a, const ConvParams& b) const;
|
||||
};
|
||||
|
||||
template <typename T_Perf>
|
||||
class AlgoIterator {
|
||||
public:
|
||||
AlgoIterator(const ConvArgs& args) : args_(args) {}
|
||||
|
||||
Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator,
|
||||
std::function<Status(const T_Perf& perf)> f);
|
||||
|
||||
static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector<T_Perf>& perf_results);
|
||||
|
||||
private:
|
||||
const ConvArgs& args_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime::cuda
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "orttraining/training_ops/cuda/nn/conv_transpose_grad.h"
|
||||
|
||||
namespace onnxruntime::cuda {
|
||||
|
||||
#define REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTransposeGrad, kMSDomain, 1, T, kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
ConvTransposeGrad<T>);
|
||||
|
||||
REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(float)
|
||||
REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(double)
|
||||
REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
template <typename T>
|
||||
Status ConvTransposeGrad<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* dY = context->Input<Tensor>(0);
|
||||
const Tensor* X = context->Input<Tensor>(1);
|
||||
const Tensor* W = context->Input<Tensor>(2);
|
||||
Tensor* dX = context->Output(0, X->Shape());
|
||||
Tensor* dW = context->Output(1, W->Shape());
|
||||
Tensor* dB = context->Output(2, {W->Shape()[1] * conv_attrs_.group});
|
||||
|
||||
if (dX) {
|
||||
ORT_RETURN_IF_ERROR(PrepareConvForwardArgs(*dY, *W, *dX, GetCudnnHandle(context), args_dx_));
|
||||
ORT_RETURN_IF_ERROR(ComputeInputGradient(context->GetComputeStream(), args_dx_));
|
||||
}
|
||||
|
||||
if (dW || dB) {
|
||||
ORT_RETURN_IF_ERROR(PrepareConvBackwardFilterArgs(*dY, *W, *X, dW, dB, GetCudnnHandle(context), args_dw_));
|
||||
if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient(context->GetComputeStream(), args_dw_));
|
||||
if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient(args_dw_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvTransposeGrad<T>::ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const {
|
||||
return AlgoIterator<T_FwdPerf>(args).TryAll(
|
||||
static_cast<const CUDAExecutionProvider*>(Info().GetExecutionProvider()),
|
||||
Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
|
||||
[&](const T_FwdPerf& algo_perf) -> Status {
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(algo_perf.memory, stream);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(
|
||||
args.handle, &one, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc,
|
||||
algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data));
|
||||
return Status::OK();
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvTransposeGrad<T>::ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const {
|
||||
return AlgoIterator<T_BwdFilterPerf>(args).TryAll(
|
||||
static_cast<const CUDAExecutionProvider*>(Info().GetExecutionProvider()),
|
||||
Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
|
||||
[&](const T_BwdFilterPerf& algo_perf) -> Status {
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(algo_perf.memory, stream);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType));
|
||||
CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardFilter(
|
||||
args.handle, &one, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc,
|
||||
algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data));
|
||||
return Status::OK();
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvTransposeGrad<T>::ComputeBiasGradient(const ConvArgs& args) const {
|
||||
const auto one = Consts<CudaT>::One;
|
||||
const auto zero = Consts<CudaT>::Zero;
|
||||
CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardBias(args.handle, &one, args.x_tensor, args.x_data, &zero,
|
||||
args.b_tensor, args.db_data));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvTransposeGrad<T>::PrepareConvForwardArgs(const Tensor& X, const Tensor& W,
|
||||
Tensor& Y, cudnnHandle_t cudnn_handle,
|
||||
ConvArgs& args) const {
|
||||
const TensorShape& x_shape = X.Shape();
|
||||
auto x_dims = x_shape.AsShapeVector();
|
||||
args.x_data = reinterpret_cast<const CudaT*>(X.template Data<T>());
|
||||
|
||||
const TensorShape& w_shape = W.Shape();
|
||||
auto w_dims = w_shape.AsShapeVector();
|
||||
args.w_data = reinterpret_cast<const CudaT*>(W.template Data<T>());
|
||||
|
||||
const TensorShape& y_shape = Y.Shape();
|
||||
auto y_dims = y_shape.AsShapeVector();
|
||||
args.y_data = reinterpret_cast<CudaT*>(Y.template MutableData<T>());
|
||||
|
||||
args.dy_data = nullptr;
|
||||
args.db_data = nullptr;
|
||||
args.dx_data = nullptr;
|
||||
args.dw_data = nullptr;
|
||||
|
||||
bool x_dims_changed = (args.last_x_dims != x_dims);
|
||||
bool w_dims_changed = (args.last_w_dims != w_dims);
|
||||
if (x_dims_changed || w_dims_changed) {
|
||||
if (x_dims_changed) args.last_x_dims = x_dims;
|
||||
if (w_dims_changed) args.last_w_dims = w_dims;
|
||||
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W));
|
||||
|
||||
TensorShapeVector kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape));
|
||||
auto rank = kernel_shape.size();
|
||||
|
||||
ConvPadVector pads(conv_attrs_.pads);
|
||||
if (pads.empty()) {
|
||||
pads.resize(rank * 2, 0);
|
||||
}
|
||||
|
||||
TensorShapeVector dilations(conv_attrs_.dilations);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(rank, 1);
|
||||
}
|
||||
|
||||
TensorShapeVector strides(conv_attrs_.strides);
|
||||
if (strides.empty()) {
|
||||
strides.resize(rank, 1);
|
||||
}
|
||||
|
||||
const CUDAExecutionProvider* cuda_ep =
|
||||
static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
|
||||
|
||||
if (rank < 2) {
|
||||
if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
|
||||
x_dims.insert(x_dims.begin() + 2, 1);
|
||||
y_dims.insert(y_dims.begin() + 2, 1);
|
||||
w_dims.insert(w_dims.begin() + 2, 1);
|
||||
pads.insert(pads.begin() + rank, 0);
|
||||
pads.insert(pads.begin(), 0);
|
||||
kernel_shape.insert(kernel_shape.begin(), 1);
|
||||
strides.insert(strides.begin(), 1);
|
||||
dilations.insert(dilations.begin(), 1);
|
||||
} else {
|
||||
x_dims.push_back(1);
|
||||
y_dims.push_back(1);
|
||||
w_dims.push_back(1);
|
||||
pads.insert(pads.begin() + rank, 0);
|
||||
pads.insert(pads.end(), 0);
|
||||
kernel_shape.push_back(1);
|
||||
strides.push_back(1);
|
||||
dilations.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
memset(&args.params, 0, sizeof(ConvParams));
|
||||
args.params.device_id = static_cast<int8_t>(cuda_ep->GetDeviceId());
|
||||
args.params.data_type = CudnnTensor::GetDataType<CudaT>();
|
||||
args.params.input_dim = static_cast<uint8_t>(x_dims.size());
|
||||
for (size_t i = 0; i < x_dims.size(); i++) {
|
||||
args.params.input_size[i] = static_cast<int>(x_dims[i]);
|
||||
args.params.weight_size[i] = static_cast<int>(w_dims[i]);
|
||||
}
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
args.params.padding[i] = static_cast<int>(pads[i]);
|
||||
args.params.padding[i + rank] = static_cast<int>(pads[i + rank]);
|
||||
args.params.stride[i] = static_cast<int>(strides[i]);
|
||||
args.params.dilation[i] = static_cast<int>(dilations[i]);
|
||||
}
|
||||
args.params.groups = conv_attrs_.group;
|
||||
int algo_mode = cuda_ep->GetCudnnConvAlgo();
|
||||
ORT_ENFORCE(algo_mode > -1 && algo_mode < 3,
|
||||
"Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode);
|
||||
args.params.algo_mode = algo_mode;
|
||||
|
||||
args.handle = cudnn_handle;
|
||||
ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
|
||||
args.params.data_type));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ConvTransposeGrad<T>::PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY,
|
||||
Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle,
|
||||
ConvArgs& args) const {
|
||||
const TensorShape& x_shape = X.Shape();
|
||||
auto x_dims = x_shape.AsShapeVector();
|
||||
args.x_data = reinterpret_cast<const CudaT*>(X.template Data<T>());
|
||||
|
||||
const TensorShape& y_shape = dY.Shape();
|
||||
auto y_dims = y_shape.AsShapeVector();
|
||||
args.dy_data = reinterpret_cast<const CudaT*>(dY.template Data<T>());
|
||||
|
||||
const TensorShape& w_shape = W.Shape();
|
||||
auto w_dims = w_shape.AsShapeVector();
|
||||
|
||||
args.y_data = nullptr;
|
||||
args.dw_data = dW ? reinterpret_cast<CudaT*>(dW->template MutableData<T>()) : nullptr;
|
||||
args.db_data = dB ? reinterpret_cast<CudaT*>(dB->template MutableData<T>()) : nullptr;
|
||||
args.dx_data = nullptr;
|
||||
args.w_data = nullptr;
|
||||
|
||||
bool x_dims_changed = (args.last_x_dims != x_dims);
|
||||
bool w_dims_changed = (args.last_w_dims != w_dims);
|
||||
if (x_dims_changed || w_dims_changed) {
|
||||
if (x_dims_changed) args.last_x_dims = x_dims;
|
||||
if (w_dims_changed) args.last_w_dims = w_dims;
|
||||
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W));
|
||||
|
||||
TensorShapeVector kernel_shape;
|
||||
ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape));
|
||||
auto rank = kernel_shape.size();
|
||||
|
||||
ConvPadVector pads(conv_attrs_.pads);
|
||||
if (pads.empty()) {
|
||||
pads.resize(rank * 2, 0);
|
||||
}
|
||||
|
||||
TensorShapeVector dilations(conv_attrs_.dilations);
|
||||
if (dilations.empty()) {
|
||||
dilations.resize(rank, 1);
|
||||
}
|
||||
|
||||
TensorShapeVector strides(conv_attrs_.strides);
|
||||
if (strides.empty()) {
|
||||
strides.resize(rank, 1);
|
||||
}
|
||||
|
||||
const CUDAExecutionProvider* cuda_ep =
|
||||
static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
|
||||
|
||||
if (rank < 2) {
|
||||
if (cuda_ep->GetCudnnConv1dPadToNc1d()) {
|
||||
x_dims.insert(x_dims.begin() + 2, 1);
|
||||
y_dims.insert(y_dims.begin() + 2, 1);
|
||||
w_dims.insert(w_dims.begin() + 2, 1);
|
||||
pads.insert(pads.begin() + rank, 0);
|
||||
pads.insert(pads.begin(), 0);
|
||||
kernel_shape.insert(kernel_shape.begin(), 1);
|
||||
strides.insert(strides.begin(), 1);
|
||||
dilations.insert(dilations.begin(), 1);
|
||||
} else {
|
||||
x_dims.push_back(1);
|
||||
y_dims.push_back(1);
|
||||
w_dims.push_back(1);
|
||||
pads.insert(pads.begin() + rank, 0);
|
||||
pads.insert(pads.end(), 0);
|
||||
kernel_shape.push_back(1);
|
||||
strides.push_back(1);
|
||||
dilations.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
memset(&args.params, 0, sizeof(ConvParams));
|
||||
args.params.device_id = static_cast<int8_t>(cuda_ep->GetDeviceId());
|
||||
args.params.data_type = CudnnTensor::GetDataType<CudaT>();
|
||||
args.params.input_dim = static_cast<uint8_t>(x_dims.size());
|
||||
for (size_t i = 0; i < x_dims.size(); i++) {
|
||||
args.params.input_size[i] = static_cast<int>(x_dims[i]);
|
||||
args.params.weight_size[i] = static_cast<int>(w_dims[i]);
|
||||
}
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
args.params.padding[i] = static_cast<int>(pads[i]);
|
||||
args.params.padding[i + rank] = static_cast<int>(pads[i + rank]);
|
||||
args.params.stride[i] = static_cast<int>(strides[i]);
|
||||
args.params.dilation[i] = static_cast<int>(dilations[i]);
|
||||
}
|
||||
args.params.groups = conv_attrs_.group;
|
||||
int algo_mode = cuda_ep->GetCudnnConvAlgo();
|
||||
ORT_ENFORCE(algo_mode > -1 && algo_mode < 3,
|
||||
"Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode);
|
||||
args.params.algo_mode = algo_mode;
|
||||
|
||||
args.handle = cudnn_handle;
|
||||
ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type));
|
||||
ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
|
||||
gsl::narrow_cast<int>(conv_attrs_.group), CUDNN_CROSS_CORRELATION,
|
||||
args.params.data_type));
|
||||
|
||||
if (dB) {
|
||||
const auto& b_shape = dB->Shape();
|
||||
ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
|
||||
TensorShapeVector b_dims(2 + kernel_shape.size());
|
||||
b_dims[0] = 1; // N
|
||||
b_dims[1] = b_shape[0]; // C
|
||||
for (size_t i = 0; i < kernel_shape.size(); i++)
|
||||
b_dims[2 + i] = 1;
|
||||
|
||||
ORT_RETURN_IF_ERROR(args.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime::cuda
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
|
||||
#include "core/providers/cpu/nn/conv_attributes.h"
|
||||
#include "orttraining/training_ops/cuda/nn/conv_shared.h"
|
||||
|
||||
namespace onnxruntime::cuda {
|
||||
|
||||
template <typename T>
|
||||
class ConvTransposeGrad final : public CudaKernel {
|
||||
public:
|
||||
using CudaT = typename ToCudaType<T>::MappedType;
|
||||
|
||||
ConvTransposeGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
Status ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const;
|
||||
Status ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const;
|
||||
Status ComputeBiasGradient(const ConvArgs& args) const;
|
||||
|
||||
Status PrepareConvForwardArgs(const Tensor& X, const Tensor& W,
|
||||
Tensor& Y, cudnnHandle_t cudnn_handle,
|
||||
ConvArgs& args) const;
|
||||
|
||||
Status PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY,
|
||||
Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle,
|
||||
ConvArgs& args) const;
|
||||
|
||||
ConvAttributes conv_attrs_;
|
||||
mutable ConvArgs args_dx_;
|
||||
mutable ConvArgs args_dw_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime::cuda
|
||||
Loading…
Reference in a new issue