// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. #pragma once #include "TensorBuffer.h" #include "MLValueHelpers.h" namespace Windows::AI::MachineLearning { template class Tensor { private: using TensorBuffer = TensorBuffer; using TensorBufferPtr = typename TensorBuffer::TensorBufferPtr; TensorBufferPtr m_buffer; std::vector m_shape; public: Tensor() = delete; Tensor( std::vector const& shape, winrt::Windows::Storage::Streams::IBuffer buffer) : m_shape(shape), m_buffer( TensorBuffer::Create( static_cast( std::accumulate( std::begin(shape), std::end(shape), static_cast(1), std::multiplies())), buffer)) {} Tensor(std::vector const& shape) : m_shape(shape), m_buffer( TensorBuffer::Create( static_cast( std::accumulate( std::begin(shape), std::end(shape), static_cast(1), std::multiplies())))) {} Tensor(std::vector const&& shape) : m_shape(std::move(shape)), m_buffer( TensorBuffer::Create( static_cast( std::accumulate( std::begin(shape), std::end(shape), static_cast(1), std::multiplies())))) {} auto size() const { return m_buffer->Size(); } auto buffer() { return m_buffer->Buffer(); } OrtValue MLValue() { // Get the shape onnxruntime::TensorShape shape(m_shape); // Get the data type auto type = onnxruntime::DataTypeImpl::GetType(); return MLValueHelpers::CreateMLValue(shape, type, buffer().second); } void set(uint32_t size, const T* pData) { m_buffer->Set(size, pData); } void set(std::vector&& other) { m_buffer->Set(other); } const std::vector& shape() const { return m_shape; } }; } // namespace Windows::AI::MachineLearning