#pragma once

#include <cstddef>
#include <cstdio>
#include <initializer_list>
#include <iosfwd>
#include <vector>

#include "element_type.hpp"
#include "tree.hpp"

namespace ngraph
{
    // Forward declarations. Both classes are defined below with qualified names
    // (`class ngraph::tensor_size { ... }`), which requires a prior declaration inside the
    // namespace, and each class mentions the other in its interface.
    class tensor_stride;
    class tensor_size;
}

//================================================================================================
// tensor_size
//
// Size description of a tensor: a scalar_tree of extents (constructible from an initializer list
// of scalar_tree values) together with the ElementType of the tensor's elements.
// tensor_stride is a friend, presumably so it can read m_tree when building strides.
//================================================================================================
class ngraph::tensor_size
{
    friend class tensor_stride;

public:
    /// Default-constructs a tensor_size; the initial tree and element type are set by the
    /// out-of-line definition.
    tensor_size();

    /// Constructs a tensor_size from a single extent.
    /// @param s   the extent.
    /// @param et  element type of the tensor (defaults to element_type_float).
    /// NOTE(review): this constructor is not `explicit`, so a std::size_t converts implicitly
    /// to tensor_size; confirm whether callers rely on that before tightening it.
    tensor_size(std::size_t s, ElementType et = element_type_float);

    /// Constructs a tensor_size from a list of scalar_tree values.
    /// @param list  the extents, one scalar_tree per entry.
    /// @param et    element type of the tensor (defaults to element_type_float).
    tensor_size(const std::initializer_list<scalar_tree>& list,
                ElementType                               et = element_type_float);

    /// @return the element type stored in this size.
    const ElementType& get_type() const noexcept { return m_element_type; }

    /// Stride descriptions derived from this size; the exact layouts are defined out of line.
    /// TODO(review): document how full_strides() differs from strides().
    tensor_stride full_strides() const;
    tensor_stride strides() const;

    /// @return a tensor_size describing these sizes (semantics defined out of line).
    tensor_size sizes() const;

    /// @return the sub-size at position `index` (presumably the index-th child of the tree;
    ///         bounds handling is defined out of line — TODO confirm).
    tensor_size operator[](std::size_t index) const;

    /// Stream-insertion operator for tensor_size.
    friend std::ostream& operator<<(std::ostream& out, const tensor_size& s);

private:
    /// Builds a size from a flat list of extents and an element type. Because it is private,
    /// only tensor_size itself and its friend tensor_stride can call it.
    tensor_size(const std::vector<std::size_t>&, const ElementType&);

    scalar_tree m_tree;         // extents, stored as a tree
    ElementType m_element_type; // element type returned by get_type()
};

//================================================================================================
// tensor_stride
//
// Stride counterpart of tensor_size: a scalar_tree (presumably of strides) plus the ElementType
// of the tensor's elements. Apart from the public default constructor, instances can only be
// built by tensor_stride itself and by its friend tensor_size, because the converting
// constructors are private.
//================================================================================================
class ngraph::tensor_stride
{
    friend class tensor_size;

public:
    /// Default-constructs a tensor_stride; initial state is set by the out-of-line definition.
    tensor_stride();

    /// @return the element type stored in this stride description.
    const ElementType& get_type() const noexcept { return m_element_type; }

    /// Derived stride descriptions; the exact layouts are defined out of line.
    /// TODO(review): document how full_strides(), strides() and reduce_strides() differ.
    tensor_stride full_strides() const;
    tensor_stride strides() const;
    tensor_stride reduce_strides() const;

    /// @return the sub-stride at position `index` (presumably the index-th child of the tree;
    ///         bounds handling is defined out of line — TODO confirm).
    tensor_stride operator[](std::size_t index) const;

    /// Stream-insertion operator for tensor_stride.
    friend std::ostream& operator<<(std::ostream& out, const tensor_stride& s);

private:
    /// Builds the strides for a given size. It is private, so only tensor_size (a friend) and
    /// tensor_stride can call it. It is deliberately left non-explicit in case the
    /// out-of-line code relies on the implicit conversion.
    tensor_stride(const tensor_size&);

    /// Builds a stride description from a flat list of values and an element type.
    tensor_stride(const std::vector<std::size_t>&, const ElementType&);

    scalar_tree m_tree;         // stride values, stored as a tree
    ElementType m_element_type; // element type returned by get_type()
};