Changes to support TNLRV3 fine-tuning (#4639)

* added reducesumlogexp gradient
added test
fixed type mismatch when calling cudnnreduce kernel
fixed python frontend to remove redundant states to match pytorch state dict
This commit is contained in:
Tixxx 2020-07-29 19:17:59 -07:00 committed by GitHub
parent d8f3e46d45
commit f90a2d46ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 128 additions and 136 deletions

View file

@ -395,10 +395,12 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
}
CudnnReduceDescriptor reduce_desc;
if (std::is_same<T, MLFloat16>::value)
if (std::is_same<T, MLFloat16>::value) {
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, CudnnTensor::GetDataType<float>(), ReduceTensorIndices));
else
} else {
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, ReduceTensorIndices));
}
const auto one = Consts<CudaT>::One;
const auto zero = Consts<CudaT>::Zero;
CudnnTensor input_tensor;
@ -437,7 +439,11 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
} else {
// Reduce max -- Max/Min will output indices data
CudnnReduceDescriptor reduce_max_desc;
ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES));
cudnnDataType_t cudnn_reduce_max_type = cudnn_type_X;
if((std::is_same<T, MLFloat16>::value)) {
cudnn_reduce_max_type = CUDNN_DATA_FLOAT;
}
ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_reduce_max_type, CUDNN_REDUCE_TENSOR_NO_INDICES));
size_t indices_bytes_max = 0;
CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cuda_ep.PerThreadCudnnHandle(), reduce_max_desc,
input_tensor, output_tensor, &indices_bytes_max));

View file

@ -905,6 +905,40 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
return result;
}
// Reference computation is pytorch's logsumexp_backward
// dx_i = exp(xi) / reduceSum(exp(xi))
// O(0) = log(reduceSum(exp(xi)))
// Self_Sub_Result = I(0) - O(0) = xi - log(sum(exp(xi))) = log( xi / reduceSum(exp(xi)))
// Gradient computation is re-using output and input from forward op, can be a recomputation candidate.
IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
bool keepdims = true;
if (attributes.find("keepdims") != attributes.end() &&
attributes.at("keepdims").has_i()) {
keepdims = static_cast<bool>(attributes.at("keepdims").i());
}
ArgDef grad = GO(0);
if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unsqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
}
else {
result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
}
result.push_back(NodeDef("Exp", {IA("Self_Sub_Result")}, {IA("Self_Sub_Result_Exp")}));
result.push_back(NodeDef("Mul", {IA("Self_Sub_Result_Exp"), grad}, {GI(0)}));
return result;
}
IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();

View file

@ -25,6 +25,7 @@ DECLARE_GRADIENT_BUILDER(GetMulGradient)
DECLARE_GRADIENT_BUILDER(GetDivGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
DECLARE_GRADIENT_BUILDER(GetPowGradient)
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
DECLARE_GRADIENT_BUILDER(GetReshapeGradient)

View file

@ -51,6 +51,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Pow", GetPowGradient);
REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);

View file

@ -629,6 +629,8 @@ class ORTTrainer():
self.world_size = world_size
self.use_mixed_precision = use_mixed_precision
self.original_model_state_keys = list(model.state_dict().keys()) if hasattr(model, 'state_dict') else []
self.session = None
self.device_ = device
self.gradient_accumulation_steps = gradient_accumulation_steps
@ -773,7 +775,11 @@ class ORTTrainer():
if n.name not in torch_state:
torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
return torch_state
# Need to remove redundant initializers and name suffices to map back to original torch state names
torch_state_to_return = {key: torch_state[key] for key in self.original_model_state_keys if key in torch_state} \
if self.original_model_state_keys \
else torch_state
return torch_state_to_return
def load_state_dict(self, state_dict, strict=False):
# Note: It may happen ONNX model has not yet been initialized

View file

@ -7,6 +7,9 @@
namespace onnxruntime {
namespace test {
using TestDataVector = std::tuple<std::vector<std::vector<TensorInfo>>, // Input data
std::vector<std::vector<TensorInfo>>, // output data
std::vector<std::vector<onnx::AttributeProto>>>; //attribute
class GradientOpTester : public OpTester {
public:
@ -39,3 +42,4 @@ class GradientOpTester : public OpTester {
};
} // namespace test
} // namespace onnxruntime

View file

@ -38,6 +38,70 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
#define EXPECT_IS_TINY(max_error) \
EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)
static void RunReductionTests(const OpDef& op_def) {
TestDataVector test_data(
// Input X
{
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
},
// Input Y
{
{{1, 1, 1}},
{{}},
{{1, 3, 1}},
{{2}},
{{4, 1, 2}},
{{4, 3}},
{{4, 1, 2}},
{{4}}
},
// Attributes
{
// default
{},
// axes = [0, 1, 2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [0, 2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
// axes = [0, 1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [1], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [-2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [-2, -1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
MakeAttribute("keepdims", int64_t(0))}
});
GradientChecker<float, float, float> gradient_checker;
float max_error;
for (size_t i = 0; i < std::get<0>(test_data).size(); i++) {
max_error = 0;
gradient_checker.ComputeGradientError(op_def, std::get<0>(test_data)[i],
std::get<1>(test_data)[i], &max_error,
std::get<2>(test_data)[i]);
EXPECT_IS_TINY(max_error);
}
}
template <typename T>
void GenerateRandomDataWithOneHot(
std::vector<std::vector<float>>& x_datas,
@ -426,149 +490,24 @@ TEST(GradientCheckerTest, GemmGrad) {
}
TEST(GradientCheckerTest, ReduceMeanGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceMean", kOnnxDomain, 11};
// default
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
EXPECT_IS_TINY(max_error);
}
// TODO: Fix forward kernel behavior for default axes
// default axes, keepdims = 0
/*
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
{MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
*/
// axes = [0, 1, 2], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
// axes = [0, 2], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 2})});
EXPECT_IS_TINY(max_error);
}
// axes = [0, 1], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
// axes = [1], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))});
EXPECT_IS_TINY(max_error);
}
// axes = [2], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
// axes = [-2], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))});
EXPECT_IS_TINY(max_error);
}
// axes = [-2, -1], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
RunReductionTests(op_def);
}
TEST(GradientCheckerTest, ReduceSumGrad) {
float max_error;
GradientChecker<float, float, float> gradient_checker;
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceSum", kOnnxDomain, 11};
// default
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
EXPECT_IS_TINY(max_error);
}
RunReductionTests(op_def);
}
// axes = [0, 1, 2], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
// Attribute axes supports negative values from opset 11.
OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
// axes = [0, 2], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 2})});
EXPECT_IS_TINY(max_error);
}
// axes = [0, 1], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
// axes = [1], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))});
EXPECT_IS_TINY(max_error);
}
// axes = [2], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
// axes = [-2], keepdims = 1
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))});
EXPECT_IS_TINY(max_error);
}
// axes = [-1, -3], keepdims = 0
{
gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{3}}, &max_error,
{MakeAttribute("axes", std::vector<int64_t>{-1, -3}),
MakeAttribute("keepdims", int64_t(0))});
EXPECT_IS_TINY(max_error);
}
RunReductionTests(op_def);
}
#ifndef USE_CUDA
@ -1960,3 +1899,4 @@ TEST(GradientCheckerTest, ExpandGrad) {
} // namespace onnxruntime
#endif // NDEBUG