mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
support print in ort eager mode (#9825)
* fix reshape implementation in eager mode * test code * update opgen script to support fallback to cpu * enhance the eager backend to support torch cpu fallback * add more testes * disable the printensor test for now, as we need to erge a PR to pytorch first
This commit is contained in:
parent
1e9e57df3e
commit
37bf46eb19
10 changed files with 1142 additions and 494 deletions
|
|
@ -21,6 +21,24 @@ for schema in defs.get_all_schemas_with_history():
|
|||
onnx_ops[key].since_version < schema.since_version:
|
||||
onnx_ops[key] = schema
|
||||
|
||||
def convert_to_aten_type(onnx_type_strs):
|
||||
type_map = {'tensor(float16)' : 'at::kHalf',
|
||||
'tensor(float)' : 'at::kFloat',
|
||||
'tensor(double)' : 'at::kDouble',
|
||||
'tensor(bfloat16)' : 'at::kBFloat16',
|
||||
'tensor(int32)' : 'at::kInt',
|
||||
'tensor(int16)' : 'at::kShort',
|
||||
'tensor(int8)' : 'at::kByte',
|
||||
'tensor(int64)' : 'at::kLong',
|
||||
'tensor(bool)' : 'at::kBool',
|
||||
}
|
||||
result = set({})
|
||||
for onnx_type in onnx_type_strs:
|
||||
# ONNX has more types, like tensor(string), ignore those types at this momemnt
|
||||
if onnx_type in type_map:
|
||||
result.add(type_map[onnx_type])
|
||||
return result
|
||||
|
||||
with open(out_file, 'wt') as fp:
|
||||
def write(s): fp.write(s)
|
||||
def writeline(s = ''): fp.write(s + '\n')
|
||||
|
|
@ -54,9 +72,17 @@ with open(out_file, 'wt') as fp:
|
|||
|
||||
writeline('):')
|
||||
write(f' super().__init__(\'{schema.name}\', {len(schema.outputs)}')
|
||||
|
||||
writeline(',')
|
||||
write(' ')
|
||||
input_types = []
|
||||
for input in schema.inputs:
|
||||
write(f', {input.name}')
|
||||
input_types.append(convert_to_aten_type(input.types))
|
||||
write(str(input_types))
|
||||
if len(schema.inputs) > 0:
|
||||
writeline(',')
|
||||
input_names = ','.join([input.name for input in schema.inputs])
|
||||
write(f' {input_names}')
|
||||
|
||||
|
||||
if len(schema.attributes) > 0:
|
||||
writeline(',')
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from opgen.generator import \
|
|||
ORTGen as ORTGen, \
|
||||
ONNXOp as ONNXOp, \
|
||||
SignatureOnly as SignatureOnly, \
|
||||
MakeFallthrough as MakeFallthrough
|
||||
MakeTorchFallback as MakeTorchFallback
|
||||
|
||||
from opgen.onnxops import *
|
||||
|
||||
|
|
@ -12,17 +12,17 @@ kMSDomain = 'onnxruntime::kMSDomain'
|
|||
|
||||
class ReluGrad(ONNXOp):
|
||||
def __init__(self, dY, X):
|
||||
super().__init__('ReluGrad', 1, dY, X)
|
||||
super().__init__('ReluGrad', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], dY, X)
|
||||
self.domain = kMSDomain
|
||||
|
||||
class Gelu(ONNXOp):
|
||||
def __init__(self, X):
|
||||
super().__init__('Gelu', 1, X)
|
||||
super().__init__('Gelu', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], X)
|
||||
self.domain = kMSDomain
|
||||
|
||||
class GeluGrad(ONNXOp):
|
||||
def __init__(self, dY, X):
|
||||
super().__init__('GeluGrad', 1, dY, X)
|
||||
super().__init__('GeluGrad', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], dY, X)
|
||||
self.domain = kMSDomain
|
||||
|
||||
ops = {
|
||||
|
|
@ -33,6 +33,7 @@ ops = {
|
|||
'aten::copy_': SignatureOnly(),
|
||||
'aten::_reshape_alias': SignatureOnly(),
|
||||
'aten::view': SignatureOnly(),
|
||||
'aten::_copy_from_and_resize' : SignatureOnly(),
|
||||
|
||||
'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'),
|
||||
'aten::t': Transpose('self'),
|
||||
|
|
@ -48,7 +49,20 @@ ops = {
|
|||
'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd'
|
||||
'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'),
|
||||
'aten::gelu' : Gelu('self'),
|
||||
'aten::gelu_backward' : GeluGrad('grad', 'self')
|
||||
'aten::gelu_backward' : GeluGrad('grad', 'self'),
|
||||
'aten::max' : ReduceMax('self', keepdims=1),
|
||||
'aten::min' : ReduceMin('self', keepdims=1),
|
||||
|
||||
'aten::ne.Scalar':MakeTorchFallback(),
|
||||
'aten::ne.Scalar_out': MakeTorchFallback(),
|
||||
'aten::ne.Tensor_out': MakeTorchFallback(),
|
||||
'aten::eq.Tensor': MakeTorchFallback(),
|
||||
'aten::eq.Tensor_out':MakeTorchFallback(),
|
||||
'aten::bitwise_and.Tensor_out' : MakeTorchFallback(),
|
||||
'aten::masked_select' : MakeTorchFallback(),
|
||||
'aten::as_strided' : MakeTorchFallback(),
|
||||
'aten::_local_scalar_dense' : MakeTorchFallback(),
|
||||
'aten::gt.Scalar_out' : MakeTorchFallback(),
|
||||
}
|
||||
|
||||
for binary_op, onnx_op in {
|
||||
|
|
@ -64,7 +78,7 @@ for unary_op in [
|
|||
'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos',
|
||||
'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round',
|
||||
'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero',
|
||||
'sign', 'min', 'max', 'hardsigmoid', 'isinf', 'det']:
|
||||
'sign', 'hardsigmoid', 'isinf', 'det']:
|
||||
aten_name = f'aten::{unary_op}'
|
||||
onnx_op = onnx_ops[unary_op]('self')
|
||||
ops[aten_name] = onnx_op
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from opgen.generator import \
|
|||
ORTGen as ORTGen, \
|
||||
ONNXOp as ONNXOp, \
|
||||
SignatureOnly as SignatureOnly, \
|
||||
MakeFallthrough as MakeFallthrough
|
||||
MakeTorchFallback as MakeTorchFallback
|
||||
|
||||
from opgen.onnxops import *
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ class ONNXOp:
|
|||
def __init__(self,
|
||||
name: str,
|
||||
outputs: int,
|
||||
input_types: List,
|
||||
*inputs: Union[str, Outputs],
|
||||
**attributes: Optional[Union[str, Outputs]]):
|
||||
self.name = name
|
||||
|
|
@ -55,6 +56,7 @@ class ONNXOp:
|
|||
self.inputs = inputs
|
||||
self.attributes = attributes
|
||||
self.domain = None
|
||||
self.input_types = input_types
|
||||
|
||||
def eval(self, ctx: ONNXOpEvalContext):
|
||||
evaluated_inputs = []
|
||||
|
|
@ -71,10 +73,10 @@ class ONNXOp:
|
|||
return self.outputs
|
||||
|
||||
class SignatureOnly(ONNXOp):
|
||||
def __init__(self): super().__init__(None, 0)
|
||||
def __init__(self): super().__init__(None, 0, [])
|
||||
|
||||
class MakeFallthrough(ONNXOp):
|
||||
def __init__(self): super().__init__(None, 0)
|
||||
class MakeTorchFallback(ONNXOp):
|
||||
def __init__(self): super().__init__(None, 0, [])
|
||||
|
||||
class FunctionGenerationError(NotImplementedError):
|
||||
def __init__(self, cpp_func: ast.FunctionDecl, message: str):
|
||||
|
|
@ -88,13 +90,13 @@ class MappedOpFunction:
|
|||
onnx_op: ONNXOp,
|
||||
cpp_func: ast.FunctionDecl,
|
||||
signature_only: bool,
|
||||
make_fallthrough: bool):
|
||||
make_torch_fallback: bool):
|
||||
self.op_namespace = op_namespace
|
||||
self.mapped_op_name = mapped_op_name
|
||||
self.onnx_op = onnx_op
|
||||
self.cpp_func = cpp_func
|
||||
self.signature_only = signature_only
|
||||
self.make_fallthrough = make_fallthrough
|
||||
self.make_torch_fallback = make_torch_fallback
|
||||
|
||||
class ORTGen:
|
||||
_mapped_ops: Dict[str, ONNXOp]
|
||||
|
|
@ -126,9 +128,6 @@ class ORTGen:
|
|||
del self._mapped_ops[mapped_func.mapped_op_name]
|
||||
generated_funcs.append(mapped_func)
|
||||
|
||||
if mapped_func.make_fallthrough:
|
||||
continue
|
||||
|
||||
ns = mapped_func.op_namespace
|
||||
if current_ns and current_ns != ns:
|
||||
current_ns = None
|
||||
|
|
@ -173,6 +172,7 @@ class ORTGen:
|
|||
writer.writeline('#include "python/onnxruntime_pybind_state_common.h"')
|
||||
writer.writeline()
|
||||
writer.writeline('#include <torch/extension.h>')
|
||||
writer.writeline('#include <ATen/native/CPUFallback.h>')
|
||||
writer.writeline()
|
||||
writer.writeline('#include <core/providers/dml/OperatorAuthorHelper/Attributes.h>')
|
||||
writer.writeline()
|
||||
|
|
@ -206,6 +206,27 @@ class ORTGen:
|
|||
writer.pop_indent()
|
||||
writer.write(')')
|
||||
|
||||
def _write_cpu_fall_back(self,
|
||||
writer: writer.SourceWriter,
|
||||
mapped_func: MappedOpFunction):
|
||||
onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func
|
||||
#return at::native::call_fallback_fn<
|
||||
# &at::native::cpu_fallback,
|
||||
# ATEN_OP(eq_Tensor)>::call(self, other);
|
||||
writer.writeline('return native::call_fallback_fn<')
|
||||
writer.push_indent()
|
||||
writer.writeline('&native::cpu_fallback,')
|
||||
writer.write('ATEN_OP(')
|
||||
writer.write(cpp_func.identifier.value)
|
||||
writer.write(')>::call(')
|
||||
|
||||
params = ', '.join([p.member.identifier.value for p \
|
||||
in cpp_func.parameters if p.member.identifier])
|
||||
writer.write(params)
|
||||
writer.writeline(');')
|
||||
writer.pop_indent()
|
||||
|
||||
|
||||
def _write_function_body(
|
||||
self,
|
||||
writer: writer.SourceWriter,
|
||||
|
|
@ -214,6 +235,15 @@ class ORTGen:
|
|||
|
||||
assert(len(cpp_func.parameters) > 0)
|
||||
|
||||
# Debug Logging
|
||||
log_params = ', '.join([p.member.identifier.value for p \
|
||||
in cpp_func.parameters if p.member.identifier])
|
||||
writer.writeline(f'ORT_LOG_FN({log_params});')
|
||||
writer.writeline()
|
||||
|
||||
if mapped_func.make_torch_fallback:
|
||||
return self._write_cpu_fall_back(writer, mapped_func)
|
||||
|
||||
return_alias_info = self._get_alias_info(cpp_func.torch_func.return_type) if cpp_func.torch_func else None
|
||||
if return_alias_info and not return_alias_info.is_writable:
|
||||
return_alias_info = None
|
||||
|
|
@ -224,11 +254,32 @@ class ORTGen:
|
|||
onnx_op.eval(ctx)
|
||||
ctx.prepare_outputs()
|
||||
|
||||
# Debug Logging
|
||||
log_params = ', '.join([p.member.identifier.value for p \
|
||||
in cpp_func.parameters if p.member.identifier])
|
||||
writer.writeline(f'ORT_LOG_FN({log_params});')
|
||||
writer.writeline()
|
||||
# generate the type check
|
||||
need_type_check = False
|
||||
if not self._custom_ops:
|
||||
for onnx_op_index, onnx_op in enumerate(ctx.ops):
|
||||
for op_input in onnx_op.inputs:
|
||||
if not isinstance(op_input, Outputs):
|
||||
need_type_check = True
|
||||
break
|
||||
if need_type_check:
|
||||
writer.write('if (')
|
||||
i = 0
|
||||
for onnx_op_index, onnx_op in enumerate(ctx.ops):
|
||||
for idx, op_input in enumerate(onnx_op.inputs):
|
||||
if isinstance(op_input, Outputs):
|
||||
continue
|
||||
writer.writeline(' || ' if i > 0 else '')
|
||||
if i == 0:
|
||||
writer.push_indent()
|
||||
cpp_param = cpp_func.get_parameter(op_input)
|
||||
supported_types = ','.join([type for type in onnx_op.input_types[idx]])
|
||||
writer.write('!IsSupportedType(%s, {%s})' % (cpp_param.identifier.value, supported_types))
|
||||
i += 1
|
||||
writer.writeline(') {')
|
||||
self._write_cpu_fall_back(writer, mapped_func)
|
||||
writer.pop_indent()
|
||||
writer.writeline('}')
|
||||
|
||||
# Fetch the ORT invoker from an at::Tensor.device()
|
||||
# FIXME: find the first at::Tensor param anywhere in the signature
|
||||
|
|
@ -258,10 +309,10 @@ class ORTGen:
|
|||
continue
|
||||
# See if this input is aliased as an in-place tensor
|
||||
cpp_param = cpp_func.get_parameter(op_input)
|
||||
if return_alias_info and cpp_param and \
|
||||
len(cpp_param.torch_param) == 1 and \
|
||||
self._get_alias_info(cpp_param.torch_param[0]) == return_alias_info:
|
||||
in_place_param = cpp_param
|
||||
if return_alias_info and cpp_param:
|
||||
for torch_p in cpp_param.torch_param:
|
||||
if self._get_alias_info(torch_p) == return_alias_info:
|
||||
in_place_param = cpp_param
|
||||
|
||||
writer.write(f'auto ort_input_{op_input} = ')
|
||||
writer.writeline(f'create_ort_value(invoker, {op_input});')
|
||||
|
|
@ -367,18 +418,15 @@ class ORTGen:
|
|||
for mapped_func in generated_funcs:
|
||||
cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func
|
||||
|
||||
if mapped_func.make_fallthrough:
|
||||
reg_function_arg = 'torch::CppFunction::makeFallthrough()'
|
||||
|
||||
if mapped_func.op_namespace:
|
||||
reg_function_arg = f'{mapped_func.op_namespace}::'
|
||||
else:
|
||||
if mapped_func.op_namespace:
|
||||
reg_function_arg = f'{mapped_func.op_namespace}::'
|
||||
else:
|
||||
reg_function_arg = ''
|
||||
reg_function_arg += cpp_func.identifier.value
|
||||
reg_function_arg = ''
|
||||
reg_function_arg += cpp_func.identifier.value
|
||||
|
||||
writer.write('m.impl(')
|
||||
if not mapped_func.make_fallthrough:
|
||||
reg_function_arg = f'TORCH_FN({reg_function_arg})'
|
||||
reg_function_arg = f'TORCH_FN({reg_function_arg})'
|
||||
|
||||
writer.writeline(f'"{torch_func.identifier.value}", {reg_function_arg});')
|
||||
|
||||
|
|
@ -427,7 +475,7 @@ class ORTGen:
|
|||
op_namespace = None
|
||||
op_namewithoutnamespace = op_name
|
||||
|
||||
cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '__')
|
||||
cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '_')
|
||||
|
||||
onnx_op = self._mapped_ops.get(op_name)
|
||||
if not onnx_op:
|
||||
|
|
@ -439,7 +487,7 @@ class ORTGen:
|
|||
onnx_op,
|
||||
cpp_func,
|
||||
isinstance(onnx_op, SignatureOnly),
|
||||
isinstance(onnx_op, MakeFallthrough))
|
||||
isinstance(onnx_op, MakeTorchFallback))
|
||||
|
||||
def _parse_function_decls(self, cpp_parser: parser.CPPParser):
|
||||
# Parse the C++ declarations
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#include "ort_aten.h"
|
||||
#include "ort_tensor.h"
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <ATen/native/CPUFallback.h>
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
|
@ -158,13 +160,26 @@ onnx::AttributeProto create_ort_attribute(
|
|||
return attr;
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
|
||||
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
|
||||
return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
|
||||
}
|
||||
|
||||
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types){
|
||||
return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
|
||||
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
|
||||
}
|
||||
|
||||
//#pragma endregion
|
||||
|
||||
//#pragma region Hand-Implemented ATen Ops
|
||||
|
||||
namespace aten {
|
||||
|
||||
at::Tensor empty__memory_format(
|
||||
at::Tensor empty_memory_format(
|
||||
at::IntArrayRef size,
|
||||
// *,
|
||||
c10::optional<at::ScalarType> dtype_opt,
|
||||
|
|
@ -186,7 +201,7 @@ at::Tensor empty__memory_format(
|
|||
ort_scalar_type_from_aten(*dtype_opt),
|
||||
size.vec(),
|
||||
&ot);
|
||||
|
||||
|
||||
return aten_tensor_from_ort(
|
||||
std::move(ot),
|
||||
at::TensorOptions()
|
||||
|
|
@ -255,6 +270,29 @@ at::Tensor view(const at::Tensor& self, at::IntArrayRef size) {
|
|||
self.options());
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
|
||||
switch (dtype){
|
||||
case at::kFloat:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
|
||||
case at::kDouble:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
|
||||
case at::kHalf:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
|
||||
case at::kBFloat16:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
|
||||
case at::kInt:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
case at::kShort:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_INT16;
|
||||
case at::kLong:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_INT64;
|
||||
case at::kBool:
|
||||
return ONNX_NAMESPACE::TensorProto_DataType_BOOL;
|
||||
default:
|
||||
ORT_THROW("Unsupport aten scalar type: ", dtype);
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor& copy_(
|
||||
at::Tensor& self,
|
||||
const at::Tensor& src,
|
||||
|
|
@ -269,8 +307,45 @@ at::Tensor& copy_(
|
|||
: src.device());
|
||||
const auto ort_src = create_ort_value(invoker, src);
|
||||
auto ort_self = create_ort_value(invoker, self);
|
||||
if (self.scalar_type() != src.scalar_type()){
|
||||
// invoke cast first
|
||||
std::vector<OrtValue> ort_cast_output(1);
|
||||
onnxruntime::NodeAttributes attrs(1);
|
||||
attrs["to"] = create_ort_attribute(
|
||||
"to", (int64_t)GetONNXTensorProtoDataType(self.scalar_type()), at::kLong);
|
||||
|
||||
copy(invoker, ort_src, ort_self);
|
||||
auto status = invoker.Invoke("Cast", {
|
||||
std::move(ort_src),
|
||||
}, ort_cast_output, &attrs);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
copy(invoker, ort_cast_output[0], ort_self);
|
||||
}
|
||||
else{
|
||||
copy(invoker, ort_src, ort_self);
|
||||
}
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
at::Tensor _copy_from_and_resize(
|
||||
const at::Tensor& self,
|
||||
const at::Tensor& dst){
|
||||
ORT_LOG_FN(self, dst);
|
||||
|
||||
assert_tensor_supported(self);
|
||||
assert_tensor_supported(dst);
|
||||
|
||||
auto& invoker = GetORTInvoker(self.device().type() == at::kORT
|
||||
? self.device()
|
||||
: dst.device());
|
||||
const auto ort_self = create_ort_value(invoker, self);
|
||||
auto ort_dst = create_ort_value(invoker, dst);
|
||||
|
||||
copy(invoker, ort_self, ort_dst);
|
||||
|
||||
return self;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,5 +73,11 @@ onnx::AttributeProto create_ort_attribute(
|
|||
const char* name,
|
||||
const char* value);
|
||||
|
||||
bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace torch_ort
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <core/framework/ort_value.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace torch_ort {
|
||||
namespace eager {
|
||||
|
|
|
|||
|
|
@ -120,6 +120,13 @@ class OrtEPTests(unittest.TestCase):
|
|||
ort_device = torch_ort.device(1)
|
||||
assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext
|
||||
|
||||
#disable the print test for now as we need to merge a PR to pytorch first.
|
||||
#def test_print(self):
|
||||
# x = torch.ones(1, 2)
|
||||
# ort_x = x.to('ort')
|
||||
# with OutputGrabber() as out:
|
||||
# print(ort_x)
|
||||
# assert "tensor([[1., 1.]], device='ort:0')" in out.capturedtext
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -25,6 +25,14 @@ class OrtOpTests(unittest.TestCase):
|
|||
assert torch.allclose(
|
||||
torch.add(cpu_ones, cpu_ones, alpha=2.5),
|
||||
torch.add(ort_ones, ort_ones, alpha=2.5).cpu())
|
||||
|
||||
def test_mul_bool(self):
|
||||
device = self.get_device()
|
||||
cpu_ones = torch.ones(3, 3, dtype=bool)
|
||||
ort_ones = cpu_ones.to(device)
|
||||
assert torch.allclose(
|
||||
torch.mul(cpu_ones, cpu_ones),
|
||||
torch.mul(ort_ones, ort_ones).cpu())
|
||||
|
||||
def test_add_(self):
|
||||
device = self.get_device()
|
||||
|
|
@ -68,6 +76,20 @@ class OrtOpTests(unittest.TestCase):
|
|||
cpu_ans = cpu_ones * 4
|
||||
ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0)
|
||||
assert torch.allclose(cpu_ans, ort_ans.cpu())
|
||||
|
||||
def test_max(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
ort_tensor = cpu_tensor.to('ort')
|
||||
y = ort_tensor.max()
|
||||
x = cpu_tensor.max()
|
||||
assert torch.allclose(x, y.cpu())
|
||||
|
||||
def test_min(self):
|
||||
cpu_tensor = torch.rand(10, 10)
|
||||
ort_tensor = cpu_tensor.to('ort')
|
||||
y = ort_tensor.min()
|
||||
x = cpu_tensor.min()
|
||||
assert torch.allclose(x, y.cpu())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue