mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Fix Triton Compile Error for Codegened Dropout Code (#17899)
This commit is contained in:
parent
9d07ca3621
commit
fa0a79a921
1 changed files with 2 additions and 2 deletions
|
|
@ -280,7 +280,7 @@ class TritonCodegen(NodeVisitor):
|
|||
"Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n",
|
||||
"Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n",
|
||||
"Log": "{indent}{o0} = tl.log({i0})\n",
|
||||
"DropoutGrad": "{indent}p = 1 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n",
|
||||
"DropoutGrad": "{indent}p = 1.0 - {i2}\n{indent}{o0} = tl.where({i1}, {i0} / p, 0.0)\n",
|
||||
"Identity": "{indent}{o0} = {i0}\n",
|
||||
}
|
||||
|
||||
|
|
@ -420,7 +420,7 @@ class TritonCodegen(NodeVisitor):
|
|||
offset_str = f"{node.global_offset} + " if node.global_offset != sympy.Integer(0) else ""
|
||||
offset_str += self._get_offset_mask(node.offset_calc, node.inputs[0].name)[0]
|
||||
code_buffer += (
|
||||
f"{space_indent}p = 1 - {p_var_name}\n"
|
||||
f"{space_indent}p = 1.0 - {p_var_name}\n"
|
||||
f"{space_indent}random = tl.rand(t_seed_cuda, {offset_str})\n"
|
||||
f"{space_indent}{mask_var_name} = random < p\n"
|
||||
f"{space_indent}{output_var_name} = tl.where({mask_var_name}, {input_var_name} / p, 0.0)\n"
|
||||
|
|
|
|||
Loading…
Reference in a new issue