mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Reenable ConstantOfShape TypeTests (#15910)
ConstantOfShape TypeTests were previously broken due to a bug where the case for the uint64 test was being passed an int64_data_size. Changing the data type to uint64_data_size fixes the bug. TensorProto Int8 and Int16 tests are reenabled since they are now passing.
This commit is contained in:
parent
e5189330d5
commit
b473d3eee5
2 changed files with 3 additions and 14 deletions
|
|
@ -3027,7 +3027,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
CASE_PROTO(UINT8, uint8_t, int32_data_size);
|
||||
CASE_PROTO(UINT16, uint16_t, int32_data_size);
|
||||
CASE_PROTO(UINT32, uint32_t, uint64_data_size);
|
||||
CASE_PROTO(UINT64, uint64_t, int64_data_size);
|
||||
CASE_PROTO(UINT64, uint64_t, uint64_data_size);
|
||||
CASE_PROTO(FLOAT16, onnxruntime::MLFloat16, int32_data_size);
|
||||
default: ORT_THROW_HR(E_INVALIDARG);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -134,25 +134,14 @@ void RunTypedTest(TensorProto::DataType dt, T value) {
|
|||
}
|
||||
|
||||
TEST(ConstantOfShape, TypeTests) {
|
||||
// TODO: Unskip when fixed #41968513
|
||||
if (DefaultDmlExecutionProvider().get() != nullptr) {
|
||||
GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1876): Unspecified error";
|
||||
}
|
||||
|
||||
// bool can not be tested due to a shortcoming of
|
||||
// our test infrastructure which makes use of
|
||||
// std::vector<T> which has a specialization for bool
|
||||
// and does not have a continuous buffer implementation
|
||||
// RunTypedTest(TensorProto::BOOL, true);
|
||||
|
||||
// The following two types even though supported by the
|
||||
// operator cause a failure at
|
||||
// onnx\onnx\checker.cc tensor_checker() where these
|
||||
// two types are not listed among those that a tensor may
|
||||
// contain
|
||||
// RunTypedTest(TensorProto::INT8, int8_t(8));
|
||||
// RunTypedTest(TensorProto::INT16, int16_t(16));
|
||||
|
||||
RunTypedTest(TensorProto::INT8, int8_t(8));
|
||||
RunTypedTest(TensorProto::INT16, int16_t(16));
|
||||
RunTypedTest(TensorProto::FLOAT, 1.f);
|
||||
RunTypedTest(TensorProto::FLOAT16, MLFloat16(static_cast<uint16_t>(5)));
|
||||
RunTypedTest(TensorProto::DOUBLE, 1.0);
|
||||
|
|
|
|||
Loading…
Reference in a new issue