mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[functorch] updated with some more decompositions
This commit is contained in:
parent
c51aceb877
commit
ab64065e5e
3 changed files with 2254 additions and 1 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union
|
||||
from enum import Enum
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
|
@ -72,6 +73,24 @@ def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target:
|
|||
x = self - target
|
||||
return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
|
||||
|
||||
@register_decomposition(aten.slice_backward)
|
||||
def slice_backward_decomposition(grad_output: Tensor, input_sizes: List[int], dim: int, start: int, end: int, step:int):
|
||||
grad_input = aten.new_zeros(grad_output, input_sizes)
|
||||
return aten.slice_scatter(grad_input, grad_output, dim, start, end, step)
|
||||
|
||||
@register_decomposition(aten.select_backward)
|
||||
def select_backward_decomposition(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
|
||||
grad_input = aten.new_zeros(grad_output, input_sizes)
|
||||
return aten.select_scatter(grad_input, grad_output, dim, index)
|
||||
|
||||
# Currently not numerically identical for bfloat16
|
||||
# @register_decomposition(aten._softmax_backward_data)
|
||||
# def _softmax_backward_data(grad_output: Tensor, output: Tensor, dim: int, input_dtype: int):
|
||||
# grad_input = output * (grad_output - aten.sum(grad_output * output, dim=dim, keepdim=True))
|
||||
# import pdb; pdb.set_trace()
|
||||
# print(grad_input - aten._softmax_backward_data(grad_output, output.elem, dim, input_dtype))
|
||||
# return grad_input
|
||||
|
||||
# @register_decomposition(aten._fused_dropout)
|
||||
# def _fused_dropout_decomposition(input, p, generator=None):
|
||||
# mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
|
||||
|
|
|
|||
2234
functorch/op_analysis/decompositions
Normal file
2234
functorch/op_analysis/decompositions
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -81,7 +81,7 @@ def gen_data(special_op_lists, analysis_name):
|
|||
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
|
||||
noncomposite_ops = all_ops - composite_ops
|
||||
|
||||
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
|
||||
ops = yaml.load(open('../../pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
|
||||
|
||||
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
|
||||
from collections import defaultdict
|
||||
|
|
|
|||
Loading…
Reference in a new issue