[functorch] updated with some more decompositions

This commit is contained in:
Horace He 2021-12-01 08:16:23 +00:00 committed by Jon Janzen
parent c51aceb877
commit ab64065e5e
3 changed files with 2254 additions and 1 deletion

View file

@ -1,5 +1,6 @@
import torch
from torch import Tensor
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union
from enum import Enum
aten = torch.ops.aten
@ -72,6 +73,24 @@ def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target:
x = self - target
return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
@register_decomposition(aten.slice_backward)
def slice_backward_decomposition(grad_output: Tensor, input_sizes: List[int], dim: int, start: int, end: int, step:int):
    """Backward of slice: scatter grad_output into a zero tensor of the original input's size.

    The gradient w.r.t. the (unsliced) input is zero everywhere except the
    sliced region, which receives grad_output unchanged.
    """
    # Zero gradient with the pre-slice shape, then write grad_output back
    # into the exact region the forward slice read from.
    zero_grad = aten.new_zeros(grad_output, input_sizes)
    return aten.slice_scatter(zero_grad, grad_output, dim, start, end, step)
@register_decomposition(aten.select_backward)
def select_backward_decomposition(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    """Backward of select: embed grad_output at `index` along `dim` in a zero tensor.

    All positions other than the selected index along `dim` get zero
    gradient; the selected slice gets grad_output.
    """
    # Build the zero gradient with the pre-select shape, then scatter the
    # incoming gradient into the single selected position.
    zero_grad = aten.new_zeros(grad_output, input_sizes)
    return aten.select_scatter(zero_grad, grad_output, dim, index)
# Currently not numerically identical for bfloat16
# @register_decomposition(aten._softmax_backward_data)
# def _softmax_backward_data(grad_output: Tensor, output: Tensor, dim: int, input_dtype: int):
# grad_input = output * (grad_output - aten.sum(grad_output * output, dim=dim, keepdim=True))
# import pdb; pdb.set_trace()
# print(grad_input - aten._softmax_backward_data(grad_output, output.elem, dim, input_dtype))
# return grad_input
# @register_decomposition(aten._fused_dropout)
# def _fused_dropout_decomposition(input, p, generator=None):
# mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)

File diff suppressed because it is too large Load diff

View file

@ -81,7 +81,7 @@ def gen_data(special_op_lists, analysis_name):
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
noncomposite_ops = all_ops - composite_ops
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
from collections import defaultdict