// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include "NumericData.h" #include "StringData.h" // // the Tensor class is the actual object for CPU memory buffers. // TensorBase contains one of these to represent the raw memory // GetCpuResource() returns it // namespace _winml { inline size_t compute_size_of_shape(const std::vector& shape) { auto size_of_shape = static_cast( std::accumulate(std::begin(shape), std::end(shape), static_cast(1), std::multiplies()) ); return size_of_shape; } template inline auto create_data(const std::vector& shape, const wfc::IIterable& buffers) { return _winml::numeric_data::create(compute_size_of_shape(shape), sizeof(T), buffers); } template <> inline auto create_data(const std::vector& shape, const wfc::IIterable& /*buffers*/) { return _winml::string_data::create(compute_size_of_shape(shape)); } template class Tensor { private: std::shared_ptr<_winml::idata> data_; std::vector shape_; private: Tensor() = delete; public: Tensor(const std::vector& shape) : data_(create_data(shape, nullptr)), shape_(shape) {} Tensor(const std::vector& shape, const wfc::IIterable& buffers) : data_(create_data(shape, buffers)), shape_(shape) {} auto size_in_bytes() const { return data_->size_in_bytes(); } auto num_buffers() { return data_->num_buffers(); } auto& buffers() { return data_->buffers(); } gsl::span buffer(bool should_sync_buffer = true) { auto span = data_->buffer(should_sync_buffer); return gsl::span(reinterpret_cast(span.data()), data_->num_elements()); } auto flush() { return data_->flush(); } void set(size_t size, const T* data) { auto size_in_bytes = size * sizeof(T); data_->set(size_in_bytes, reinterpret_cast(data)); } const std::vector& shape() const { return shape_; } auto get_data() { return data_; } }; } // namespace _winml