diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index e07de1fc34..9f020477eb 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -674,7 +674,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Gather); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); @@ -767,7 +767,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, uint8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, CumSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_float_int32_t, OneHot); @@ -775,7 +775,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int32_t_MLFloat16_int32_t, OneHot); // OpSet 12 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool); @@ -989,6 +989,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Loop); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, LRN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Identity); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1252,9 +1265,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1277,26 +1290,26 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1307,7 +1320,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1328,7 +1341,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1338,12 +1351,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1361,9 +1374,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1372,7 +1385,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1411,7 +1424,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1442,17 +1455,17 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1460,7 +1473,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // OpSet 12 - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, @@ -1674,6 +1687,19 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/orttraining/orttraining/training_ops/cuda/math/div_grad.cc b/orttraining/orttraining/training_ops/cuda/math/div_grad.cc index 75477dcc93..21b3dd9c8d 100644 --- a/orttraining/orttraining/training_ops/cuda/math/div_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/math/div_grad.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "div_grad.h" -#include "div_grad_impl.h" +#include "orttraining/training_ops/cuda/math/div_grad.h" +#include "orttraining/training_ops/cuda/math/div_grad_impl.h" #include "core/providers/cuda/math/binary_elementwise_ops.h" using namespace onnxruntime::common; diff --git a/orttraining/orttraining/training_ops/rocm/math/div_grad.cc b/orttraining/orttraining/training_ops/rocm/math/div_grad.cc new file mode 100644 index 0000000000..dae76670b6 --- /dev/null +++ b/orttraining/orttraining/training_ops/rocm/math/div_grad.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/rocm/math/div_grad.h" +#include "orttraining/training_ops/rocm/math/div_grad_impl.h" +#include "core/providers/rocm/math/binary_elementwise_ops.h" + +using namespace onnxruntime::common; +namespace onnxruntime { +namespace rocm { + +#define DIVGRAD_REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DivGrad, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DivGrad); + +DIVGRAD_REGISTER_KERNEL_TYPED(MLFloat16) +DIVGRAD_REGISTER_KERNEL_TYPED(float) +// DIVGRAD_REGISTER_KERNEL_TYPED(double) + +std::vector prepended_dimension_1(const TensorShape& shape, size_t total_rank) { + size_t input_rank = shape.NumDimensions(); + if (input_rank == total_rank) + return shape.GetDims(); + + std::vector dims(total_rank, 1); + + // https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md + // for property 3 of Multidirectional Broadcasting, we need to prepended with a dimension of length 1. + if (input_rank > 0) + std::copy(shape.GetDims().begin(), shape.GetDims().end(), &dims[total_rank - input_rank]); + return dims; +} + +template +Status DivGrad::ComputeInternal(OpKernelContext* context) const { + typedef typename ToHipType::MappedType HipT; + + const Tensor* dy_tensor = context->Input(0); + const Tensor* a_tensor = context->Input(1); + const Tensor* b_tensor = context->Input(2); + const TensorShape& a_shape = a_tensor->Shape(); + const TensorShape& b_shape = b_tensor->Shape(); + const TensorShape& dy_shape = dy_tensor->Shape(); + + // output shapes shall match its corresponding inputs + Tensor* da_output_tensor = context->Output(0, a_shape); + Tensor* db_output_tensor = context->Output(1, b_shape); + if (!da_output_tensor && !db_output_tensor) + return Status::OK(); + + BinaryElementwisePreparation prepare; + ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(a_tensor, b_tensor, + // TODO: BinaryElementwiseBroadcastPrepare shall take dy_tensor as const Tensor*. + const_cast(dy_tensor), &prepare)); + const HipT* prepare_a_data = reinterpret_cast(prepare.lhs_tensor->template Data()); + const HipT* prepare_b_data = reinterpret_cast(prepare.rhs_tensor->template Data()); + const HipT* prepare_dy_data = reinterpret_cast(prepare.output_tensor->template Data()); + T* da_data = da_output_tensor ? da_output_tensor->template MutableData() : nullptr; + T* db_data = db_output_tensor ? db_output_tensor->template MutableData() : nullptr; + + switch (prepare.output_rank_or_simple_broadcast) { + case static_cast(SimpleBroadcast::NoBroadcast): + ImplDivGradSimple( + Stream(), + SimpleBroadcast::NoBroadcast, + prepare_a_data, + prepare_b_data, + prepare_dy_data, + dy_shape.Size(), + reinterpret_cast(da_data), + reinterpret_cast(db_data)); + break; + case static_cast(SimpleBroadcast::LeftScalar): { + T* temp_da_data = nullptr; + IAllocatorUniquePtr temp_da_allocator; + if (da_output_tensor) { + temp_da_allocator = GetScratchBuffer(dy_shape.Size()); + temp_da_data = temp_da_allocator.get(); + } + + ImplDivGradSimple( + Stream(), + SimpleBroadcast::LeftScalar, + prepare_a_data, + prepare_b_data, + prepare_dy_data, + dy_shape.Size(), + reinterpret_cast(temp_da_data), + reinterpret_cast(db_data)); + + if (da_output_tensor) { + std::vector a_output_dims = prepended_dimension_1(a_shape, dy_shape.NumDimensions()); + ReduceKernelShared( + temp_da_data, + dy_shape, + da_data, + TensorShape({}), + MIOPEN_REDUCE_TENSOR_ADD, + a_output_dims); + } + break; + } + case static_cast(SimpleBroadcast::RightScalar): { + T* temp_db_data = nullptr; + IAllocatorUniquePtr temp_db_allocator; + if (db_output_tensor) { + temp_db_allocator = GetScratchBuffer(dy_shape.Size()); + temp_db_data = temp_db_allocator.get(); + } + ImplDivGradSimple( + Stream(), + SimpleBroadcast::RightScalar, + prepare_a_data, + prepare_b_data, + prepare_dy_data, + dy_shape.Size(), + reinterpret_cast(da_data), + reinterpret_cast(temp_db_data)); + + if (db_output_tensor) { + std::vector b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions()); + ReduceKernelShared( + temp_db_data, + dy_shape, + db_data, + TensorShape({}), + MIOPEN_REDUCE_TENSOR_ADD, + b_output_dims); + } + break; + } + case static_cast(SimpleBroadcast::RightPerChannelBatch1): + case static_cast(SimpleBroadcast::RightPerChannelBatchN): { + T* temp_db_data = nullptr; + IAllocatorUniquePtr temp_db_allocator; + if (db_output_tensor) { + temp_db_allocator = GetScratchBuffer(dy_shape.Size()); + temp_db_data = temp_db_allocator.get(); + } + if (prepare.output_rank_or_simple_broadcast == static_cast(SimpleBroadcast::RightPerChannelBatch1)) { + // lhs(1,C,H) and rhs (C,1) + ImplDivGradRhsPerChannelBatch1( + Stream(), + prepare_a_data, + prepare_b_data, + prepare_dy_data, + dy_shape.Size(), + prepare.fdm_H, + reinterpret_cast(da_data), + reinterpret_cast(temp_db_data)); + } else { + // lhs(N,C,H) and rhs (C,1) + ImplDivGradRhsPerChannelBatchN( + Stream(), + prepare_a_data, + prepare_b_data, + prepare_dy_data, + dy_shape.Size(), + prepare.fdm_H, + prepare.fdm_C, + reinterpret_cast(da_data), + reinterpret_cast(temp_db_data)); + } + + if (db_output_tensor) { + std::vector b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions()); + ReduceKernelShared( + temp_db_data, + dy_shape, + db_data, + b_shape, + MIOPEN_REDUCE_TENSOR_ADD, + b_output_dims); + } + break; + } + default: { + bool need_reduce_da = da_output_tensor && a_shape.Size() != dy_shape.Size(); + bool need_reduce_db = db_output_tensor && b_shape.Size() != dy_shape.Size(); + IAllocatorUniquePtr temp_da_allocator, temp_db_allocator; + T* da_data_ref = nullptr; + if (da_output_tensor) + if (need_reduce_da) { + temp_da_allocator = GetScratchBuffer(dy_shape.Size()); + da_data_ref = temp_da_allocator.get(); + } else { + da_data_ref = da_data; + } + T* db_data_ref = nullptr; + if (db_output_tensor) + if (need_reduce_db) { + temp_db_allocator = GetScratchBuffer(dy_shape.Size()); + db_data_ref = temp_db_allocator.get(); + } else { + db_data_ref = db_data; + } + + ImplDivGrad( + Stream(), + prepare.output_rank_or_simple_broadcast, + &prepare.lhs_padded_strides, + prepare_a_data, + &prepare.rhs_padded_strides, + prepare_b_data, + prepare_dy_data, + dy_shape.Size(), + &prepare.fdm_output_strides, + reinterpret_cast(da_data_ref), + reinterpret_cast(db_data_ref)); + + if (need_reduce_da) { + std::vector a_output_dims = prepended_dimension_1(a_shape, dy_shape.NumDimensions()); + ReduceKernelShared( + da_data_ref, + dy_shape, + da_data, + a_shape, + MIOPEN_REDUCE_TENSOR_ADD, + a_output_dims); + } + + if (need_reduce_db) { + std::vector b_output_dims = prepended_dimension_1(b_shape, dy_shape.NumDimensions()); + ReduceKernelShared( + db_data_ref, + dy_shape, + db_data, + b_shape, + MIOPEN_REDUCE_TENSOR_ADD, + b_output_dims); + } + } + } + return Status::OK(); +} + +} // namespace rocm +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 5497666ada..74788a50ad 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -197,9 +197,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -234,7 +234,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 4816783f25..897f4b6425 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -79,20 +79,10 @@ provider_excluded_files = [ 'controlflow/scan.cc', 'controlflow/scan.h', 'cu_inc/common.cuh', - 'generator/constant_of_shape.cc', - 'generator/constant_of_shape.h', - 'generator/range.cc', - 'generator/range.h', - 'generator/range_impl.cu', - 'generator/range_impl.h', 'math/einsum_utils/einsum_auxiliary_ops.cc', 'math/einsum_utils/einsum_auxiliary_ops.h', 'math/einsum_utils/einsum_auxiliary_ops_diagonal.cu', 'math/einsum_utils/einsum_auxiliary_ops_diagonal.h', - 'math/cumsum.cc', - 'math/cumsum.h', - 'math/cumsum_impl.cu', - 'math/cumsum_impl.h', 'math/einsum.cc', 'math/einsum.h', 'math/gemm.cc', @@ -123,10 +113,6 @@ provider_excluded_files = [ 'nn/max_pool_with_index.h', 'nn/pool.cc', 'nn/pool.h', - 'nn/shrink.cc', - 'nn/shrink.h', - 'nn/shrink_impl.cu', - 'nn/shrink_impl.h', 'object_detection/non_max_suppression.cc', 'object_detection/non_max_suppression.h', 'object_detection/non_max_suppression_impl.cu', @@ -151,25 +137,7 @@ provider_excluded_files = [ 'shared_inc/fast_divmod.h', 'shared_inc/fpgeneric.h', 'shared_inc/integer_gemm.h', - 'tensor/compress.cc', - 'tensor/compress.h', - 'tensor/compress_impl.cu', - 'tensor/compress_impl.h', - 'tensor/eye_like.cc', - 'tensor/eye_like.h', - 'tensor/eye_like_impl.cu', - 'tensor/eye_like_impl.h', - 'tensor/flatten.cc', - 'tensor/flatten.h', - 'tensor/gather_elements.cc', - 'tensor/gather_elements.h', - 'tensor/gather_elements_impl.cu', - 'tensor/gather_elements_impl.h', 'tensor/gather_nd_impl.cu', - 'tensor/pad.cc', - 'tensor/pad.h', - 'tensor/pad_impl.cu', - 'tensor/pad_impl.h', 'tensor/quantize_linear.cc', 'tensor/quantize_linear.cu', 'tensor/quantize_linear.cuh', @@ -178,10 +146,6 @@ provider_excluded_files = [ 'tensor/resize.h', 'tensor/resize_impl.cu', 'tensor/resize_impl.h', - 'tensor/reverse_sequence.cc', - 'tensor/reverse_sequence.h', - 'tensor/reverse_sequence_impl.cu', - 'tensor/reverse_sequence_impl.h', 'tensor/transpose.cc', 'tensor/transpose.h', 'tensor/upsample.cc', @@ -234,9 +198,6 @@ training_ops_excluded_files = [ 'controlflow/wait.cc', 'controlflow/wait.h', 'math/div_grad.cc', - 'math/div_grad.h', - 'math/div_grad_impl.cu', - 'math/div_grad_impl.h', 'math/softmax_grad_impl.cu', 'math/softmax_grad.cc', 'nn/batch_norm_grad.cc', @@ -246,8 +207,6 @@ training_ops_excluded_files = [ 'optimizer/lamb.cc', 'reduction/reduction_all.cc', 'reduction/reduction_ops.cc', - 'tensor/gather_elements_grad.cc', - 'tensor/gather_elements_grad.h', 'tensor/gather_grad.cc', 'tensor/gather_grad_impl.cu', 'tensor/gather_grad_impl.h', diff --git a/tools/ci_build/github/pai/pai-excluded-tests.txt b/tools/ci_build/github/pai/pai-excluded-tests.txt index 8759def960..4100474341 100644 --- a/tools/ci_build/github/pai/pai-excluded-tests.txt +++ b/tools/ci_build/github/pai/pai-excluded-tests.txt @@ -12,13 +12,9 @@ OptimizerTest.LambOptimizerTestBaselineMixPrecision32_16 OptimizerTest.LambOptimizerTestScalarMixPrecision32_16 OptimizerTest.LambOptimizerTestScalarMixPrecision32_16_NoDefaultMaxNormClipping OptimizerTest.LambOptimizerTestLarge -CudaKernelTest.SoftmaxCrossEntropy_TinySizeTensor -CudaKernelTest.SoftmaxCrossEntropy_SmallSizeTensor -CudaKernelTest.SoftmaxCrossEntropy_MediumSizeTensor -CudaKernelTest.SoftmaxCrossEntropy_LargeSizeTensor CudaKernelTest.SparseSoftmaxCrossEntropy_LargeSizeTensor -CudaKernelTest.NegativeLogLikelihoodLoss_TinySizeTensor -CudaKernelTest.NegativeLogLikelihoodLoss_SmallSizeTensor +CudaKernelTest.NegativeLogLikelihoodLoss_TinySizeTensor +CudaKernelTest.NegativeLogLikelihoodLoss_SmallSizeTensor CudaKernelTest.NegativeLogLikelihoodLoss_MediumSizeTensor ReductionOpTest.ReductionVariationTest ReductionOpTest.ReduceL1_default_axes_keepdims @@ -76,39 +72,12 @@ ReductionOpTest.ReduceInfLogSumExp_double GatherOpTest.Gather_invalid_index_cpu Scatter.InvalidIndex LogSoftmaxOperator.LargeNumber -GatherElementsGrad.WithoutAxis -GatherElementsGrad.WithAxis -GatherElementsGrad.ThreeDimsWithAxis_0 -GatherElementsGrad.ThreeDimsWithAxis_2 -GatherElementsGrad.NegativeAxis -GatherElementsGrad.IndicesUpdatesDontMatch -GatherElementsGrad.ValidAxis -GatherElementsGrad.ValidNegativeIndex -GatherElementsGrad.SameUpdateWithoutAxis -GatherElementsGrad.SameUpdateWithAxis -GatherElementsGrad.SameUpdateWithNegativeAxis -GatherElementsGrad.SameUpdateWithoutAxisMLFloat16 -MathOpTest.Max_8_2inputbroadcast -MathOpTest.Less_broadcastBA -MathOpTest.Less_multidiretional_broadcastAB -MathOpTest.Less_multidiretional_broadcastBA -MathOpTest.Greater_broadcastBA -MathOpTest.Greater_multidiretional_broadcastAB -MathOpTest.Greater_multidiretional_broadcastBA MathOpTest.Pow_int64_float MathOpTest.Pow_int32_float MathOpTest.Pow_int64_double GradientCheckerTest.AddGrad GradientCheckerTest.SubGrad GradientCheckerTest.MulGrad -GradientCheckerTest.MatMulGrad GradientCheckerTest.ReduceMeanGrad GradientCheckerTest.ReduceL2Grad -GradientCheckerTest.SoftmaxCrossEntropyGrad -GradientCheckerTest.ExpandGrad GradientCheckerTest.DivGrad -GradientCheckerTest.GemmGrad -GradientCheckerTest.SplitGrad -GradientCheckerTest.SqueezeGrad -GradientCheckerTest.UnsqueezeGrad -GradientCheckerTest.ClipGrad