[ORTModule] Bugfix of torch.chunk's Custom Symbolic when chunks==1 (#12249)

Handle the custom symbolic for torch.chunk when chunks == 1.
This commit is contained in:
Vincent Wang 2022-07-20 17:00:41 +08:00 committed by GitHub
parent a0074ba9bc
commit 3cdc6d7775
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 28 deletions

View file

@ -228,6 +228,8 @@ def squeeze(g, self, dim=None):
# exporting to Split with SplitGrad as gradient graph.
@register_symbolic("ConstantChunk", "prim")
def prim_ConstantChunk(g, self, chunks, dim):
if chunks == 1:
return self
input_shape_dim = g.op(
"Gather", g.op("Shape", self), g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)), axis_i=0
)

View file

@ -1326,51 +1326,56 @@ def test_gradient_correctness_reducesum(dim, keepdim):
# Before PyTorch 1.11.0, the exporter fails to register a symbolic function with a non-empty domain.
@pytest.mark.skipif(Version(torch.__version__) < Version("1.11.0"), reason="PyTorch 1.10 incompatible")
@pytest.mark.parametrize("dim", [0, 1, -1])
def test_gradient_correctness_chunk(dim):
@pytest.mark.parametrize("chunks", [1, 3])
def test_gradient_correctness_chunk(dim, chunks):
class NeuralNetChunk(torch.nn.Module):
def __init__(self, dim):
super(NeuralNetChunk, self).__init__()
self.dim = dim
def forward(self, input):
return input.chunk(3, dim=self.dim)
return input.chunk(chunks, dim=self.dim)
device = "cuda"
pt_model = NeuralNetChunk(dim).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="chunk_model"))
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=(chunks > 1), onnx_prefix="chunk_model"))
def run_step(model, input):
y1, y2, y3 = model(input)
loss = y1.sum() + y2.sum() + y3.sum()
results = model(input)
loss = results[0].sum()
for i in range(1, len(results)):
loss = loss + results[i].sum()
loss.backward()
return y1, y2, y3
return results
N, D, H = 16, 17, 18
for _ in range(10):
input = torch.rand((N, D, H), device=device, requires_grad=True)
pt_y1, pt_y2, pt_y3 = run_step(pt_model, input)
ort_y1, ort_y2, ort_y3 = run_step(ort_model, input)
pt_input = torch.rand((N, D, H), device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_results = run_step(pt_model, pt_input)
ort_results = run_step(ort_model, ort_input)
_test_helpers.assert_values_are_close(ort_y1, pt_y1)
_test_helpers.assert_values_are_close(ort_y2, pt_y2)
_test_helpers.assert_values_are_close(ort_y3, pt_y3)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
assert len(ort_results) == len(pt_results)
for i in range(len(ort_results)):
_test_helpers.assert_values_are_close(ort_results[i], pt_results[i])
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
has_split = False
for node in model.graph.node:
if node.op_type == "Split":
has_split = True
break
assert has_split
os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
if chunks > 1:
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
has_split = False
for node in model.graph.node:
if node.op_type == "Split":
has_split = True
break
assert has_split
os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
# In PyTorch 1.11.0, there is an issue during reduce node shape handling in the exporter, so any sub-graph that