mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[ONNX] remove outdated ImplicitCastType QA in onnx.rst (#81268)
Extend work from: https://github.com/pytorch/pytorch/pull/80596 This PR removes outdated QA of ImplicitCastType , as the coverage is greatly increased with the introduction of onnx shape inference and scalar type analysis. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81268 Approved by: https://github.com/justinchuby, https://github.com/BowenBao
This commit is contained in:
parent
d68fed56ef
commit
69608fc598
1 changed files with 3 additions and 21 deletions
|
|
@ -570,27 +570,9 @@ Q: How to export models with primitive type inputs (e.g. int, float)?
|
|||
Q: Does ONNX support implicit scalar datatype casting?
|
||||
|
||||
No, but the exporter will try to handle that part. Scalars are exported as constant tensors.
|
||||
The exporter will try to figure out the right datatype for scalars. However when it is unable
|
||||
to do so, you will need to manually specify the datatype. This often happens with
|
||||
scripted models, where the datatypes are not recorded. For example::
|
||||
|
||||
class ImplicitCastType(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
# Exporter knows x is float32, will export "2" as float32 as well.
|
||||
y = x + 2
|
||||
# Currently the exporter doesn't know the datatype of y, so
|
||||
# "3" is exported as int64, which is wrong!
|
||||
return y + 3
|
||||
# To fix, replace the line above with:
|
||||
# return y + torch.tensor([3], dtype=torch.float32)
|
||||
|
||||
x = torch.tensor([1.0], dtype=torch.float32)
|
||||
torch.onnx.export(ImplicitCastType(), x, "implicit_cast.onnx",
|
||||
example_outputs=ImplicitCastType()(x))
|
||||
|
||||
We are trying to improve the datatype propagation in the exporter such that implicit casting
|
||||
is supported in more cases.
|
||||
The exporter will figure out the right data type for scalars. In rare cases when it is unable
|
||||
to do so, you will need to manually specify the datatype with e.g. `dtype=torch.float32`.
|
||||
If you see any errors, please [create a GitHub issue](https://github.com/pytorch/pytorch/issues).
|
||||
|
||||
Q: Are lists of Tensors exportable to ONNX?
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue