mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
TensorDesc::Placement test failure - cherry pick Vibranium fix. (#2328)
This commit is contained in:
parent
67ec626d88
commit
db454beacf
1 changed files with 31 additions and 8 deletions
|
|
@ -48,8 +48,8 @@ TensorDesc::TensorDesc(MLOperatorTensorDataType dataType)
|
|||
|
||||
TensorDesc::TensorDesc(
|
||||
MLOperatorTensorDataType dataType,
|
||||
gsl::span<const uint32_t> dimensions,
|
||||
gsl::span<const uint32_t> nonBroadcastDimensions,
|
||||
gsl::span<const uint32_t> dimensions, // Desired dimensions
|
||||
gsl::span<const uint32_t> nonBroadcastDimensions, // Actual physical dimensions
|
||||
uint32_t coerceAxis,
|
||||
int32_t placement, // Adjustment offset of the passed dimensions within the minDimensionCount.
|
||||
int32_t leftAlignedDimensionCount, // Number of dimensions that remain left aligned when expanded to minimum count (INT32_MAX means all, 0 means all right aligned).
|
||||
|
|
@ -63,11 +63,6 @@ TensorDesc::TensorDesc(
|
|||
m_bufferTensorDesc.DataType = GetDmlDataTypeFromMlDataType(dataType);
|
||||
ML_CHECK_VALID_ARGUMENT(ApiTraits::IsValidEnumValue(m_bufferTensorDesc.DataType));
|
||||
|
||||
// Flattening coercion isn't always possible when striding is used to broadcast tensors.
|
||||
// Also, placement and split alignment is not implemented in combination with broadcasting.
|
||||
ML_CHECK_VALID_ARGUMENT((nonBroadcastDimensions == dimensions)
|
||||
|| ((coerceAxis == TensorAxis::DoNotCoerce) && (placement == W) && (leftAlignedDimensionCount == RightAligned)));
|
||||
|
||||
gsl::span<const uint32_t> sizes;
|
||||
|
||||
// If needed, flatten the tensor dimensions to a 2D tensor of size [a_0 * ... * a_{coerceAxis-1}, a_{coerceAxis} * ... * a_{n-1}]
|
||||
|
|
@ -98,6 +93,9 @@ TensorDesc::TensorDesc(
|
|||
sizes = dimensions;
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
// Align dimensions
|
||||
|
||||
// Determine the number of dimensions that should be aligned to the left edge when promoting to the minimum dimension count.
|
||||
// Negative values mean align from the right.
|
||||
const int32_t rank = gsl::narrow_cast<int32_t>(sizes.size());
|
||||
|
|
@ -134,17 +132,39 @@ TensorDesc::TensorDesc(
|
|||
while (j < MaximumDimensionCount) { m_sizes[j++] = 1; }
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
// Coerce the physical shape to the desired shape.
|
||||
|
||||
// By default, assume strides are not necessary.
|
||||
bool useStrides = false;
|
||||
|
||||
if (dimensions != nonBroadcastDimensions)
|
||||
{
|
||||
// This broadcasting and subset logic is only applicable to the simple case where all
|
||||
// dimensions are contiguously right aligned, which means no flattening coercion,
|
||||
// placement offset, or split alignment. In such cases, the right side of m_sizes
|
||||
// should match the original dimensions.
|
||||
ML_CHECK_VALID_ARGUMENT(std::equal(
|
||||
dimensions.begin(),
|
||||
dimensions.end(),
|
||||
&m_sizes[m_bufferTensorDesc.DimensionCount - rank],
|
||||
&m_sizes[m_bufferTensorDesc.DimensionCount]
|
||||
));
|
||||
|
||||
// Stretch any dimensions with a single element.
|
||||
//
|
||||
// e.g. physical [2,1,4]
|
||||
// desired [2,3,4]
|
||||
// output [2,3,4]
|
||||
// strides [4,0,1]
|
||||
//
|
||||
// If broadcasting is used, then strides are used.
|
||||
useStrides = true;
|
||||
|
||||
// Walk backwards through both input shapes and broadcast or default each dimension
|
||||
// Walk backwards through both input shapes and broadcast or default each dimension.
|
||||
auto nonBroadcastDimsIter = nonBroadcastDimensions.rbegin();
|
||||
uint32_t elementCount = 1;
|
||||
|
||||
for (int descDimIndex = m_bufferTensorDesc.DimensionCount - 1; descDimIndex >= 0; --descDimIndex)
|
||||
{
|
||||
if (nonBroadcastDimsIter == nonBroadcastDimensions.rend() || (*nonBroadcastDimsIter == 1))
|
||||
|
|
@ -164,6 +184,9 @@ TensorDesc::TensorDesc(
|
|||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////
|
||||
// Handle 64-bit tensors.
|
||||
|
||||
uint64_t endPaddingInBytes = 0;
|
||||
|
||||
if (dataType == MLOperatorTensorDataType::UInt64 || dataType == MLOperatorTensorDataType::Int64)
|
||||
|
|
|
|||
Loading…
Reference in a new issue