mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[ORTModule] Bugfix of torch.chunk's Custom Symbolic when chunks==1 (#12249)
handle custom chunk with chunks==1
This commit is contained in:
parent
a0074ba9bc
commit
3cdc6d7775
2 changed files with 35 additions and 28 deletions
|
|
@ -228,6 +228,8 @@ def squeeze(g, self, dim=None):
|
|||
# exporting to Split with SplitGrad as gradient graph.
|
||||
@register_symbolic("ConstantChunk", "prim")
|
||||
def prim_ConstantChunk(g, self, chunks, dim):
|
||||
if chunks == 1:
|
||||
return self
|
||||
input_shape_dim = g.op(
|
||||
"Gather", g.op("Shape", self), g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)), axis_i=0
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1326,51 +1326,56 @@ def test_gradient_correctness_reducesum(dim, keepdim):
|
|||
# Before PyTorch 1.11.0, the exporter fails to register a symbolic with a non-empty domain.
|
||||
@pytest.mark.skipif(Version(torch.__version__) < Version("1.11.0"), reason="PyTorch 1.10 incompatible")
|
||||
@pytest.mark.parametrize("dim", [0, 1, -1])
|
||||
def test_gradient_correctness_chunk(dim):
|
||||
@pytest.mark.parametrize("chunks", [1, 3])
|
||||
def test_gradient_correctness_chunk(dim, chunks):
|
||||
class NeuralNetChunk(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(NeuralNetChunk, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, input):
|
||||
return input.chunk(3, dim=self.dim)
|
||||
return input.chunk(chunks, dim=self.dim)
|
||||
|
||||
device = "cuda"
|
||||
pt_model = NeuralNetChunk(dim).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="chunk_model"))
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=(chunks > 1), onnx_prefix="chunk_model"))
|
||||
|
||||
def run_step(model, input):
|
||||
y1, y2, y3 = model(input)
|
||||
loss = y1.sum() + y2.sum() + y3.sum()
|
||||
results = model(input)
|
||||
loss = results[0].sum()
|
||||
for i in range(1, len(results)):
|
||||
loss = loss + results[i].sum()
|
||||
loss.backward()
|
||||
return y1, y2, y3
|
||||
return results
|
||||
|
||||
N, D, H = 16, 17, 18
|
||||
for _ in range(10):
|
||||
input = torch.rand((N, D, H), device=device, requires_grad=True)
|
||||
pt_y1, pt_y2, pt_y3 = run_step(pt_model, input)
|
||||
ort_y1, ort_y2, ort_y3 = run_step(ort_model, input)
|
||||
pt_input = torch.rand((N, D, H), device=device, requires_grad=True)
|
||||
ort_input = copy.deepcopy(pt_input)
|
||||
pt_results = run_step(pt_model, pt_input)
|
||||
ort_results = run_step(ort_model, ort_input)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_y1, pt_y1)
|
||||
_test_helpers.assert_values_are_close(ort_y2, pt_y2)
|
||||
_test_helpers.assert_values_are_close(ort_y3, pt_y3)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
|
||||
assert len(ort_results) == len(pt_results)
|
||||
for i in range(len(ort_results)):
|
||||
_test_helpers.assert_values_are_close(ort_results[i], pt_results[i])
|
||||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
|
||||
model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
has_split = False
|
||||
for node in model.graph.node:
|
||||
if node.op_type == "Split":
|
||||
has_split = True
|
||||
break
|
||||
assert has_split
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
|
||||
if chunks > 1:
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
|
||||
assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
|
||||
model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
has_split = False
|
||||
for node in model.graph.node:
|
||||
if node.op_type == "Split":
|
||||
has_split = True
|
||||
break
|
||||
assert has_split
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx"))
|
||||
os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx"))
|
||||
|
||||
|
||||
# In PyTorch 1.11.0, there is an issue during reduce node shape handling in the exporter, so any sub-graph that
|
||||
|
|
|
|||
Loading…
Reference in a new issue