#include <algorithm>
#include <iostream>

#include "strides.hpp"
#include "util.hpp"

using namespace std;

//================================================================================================
//
//================================================================================================

ngraph::tensor_size::tensor_size()
    : m_tree{}
    , m_element_type{element_type_float}
{
}

ngraph::tensor_size::tensor_size(size_t s, ElementType et)
    : m_tree{s}
    , m_element_type{et}
{
}

ngraph::tensor_size::tensor_size(const std::initializer_list<scalar_tree>& list, ElementType et)
    : m_tree{list}
    , m_element_type{et}
{
}

ngraph::tensor_size::tensor_size(const std::vector<size_t>& list, const ElementType& et)
    : m_tree{list}
    , m_element_type{et}
{
}

ngraph::tensor_stride ngraph::tensor_size::full_strides() const
{
    tensor_stride   result{*this};
    vector<size_t*> value_pointer_list;
    vector<size_t>  size_list;

    scalar_tree::traverse_tree(result.m_tree, [&](size_t* value) {
        value_pointer_list.push_back(value);
        size_list.push_back(*value);
    });
    int index                  = value_pointer_list.size() - 1;
    *value_pointer_list[index] = result.m_element_type.size();
    for (index--; index >= 0; index--)
    {
        *value_pointer_list[index] = *value_pointer_list[index + 1] * size_list[index + 1];
    }

    return result;
}

ngraph::tensor_stride ngraph::tensor_size::strides() const
{
    return full_strides().strides();
}

ngraph::tensor_size ngraph::tensor_size::sizes() const
{
    vector<size_t> tmp;
    if (m_tree.is_list())
    {
        for (auto s : m_tree.get_list())
        {
            tmp.push_back(s.reduce([](size_t a, size_t b) { return a * b; }));
        }
    }
    else
    {
        tmp.push_back(m_tree.get_value());
    }
    return tensor_size(tmp, m_element_type);
}

std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::tensor_size& s)
{
    out << s.m_tree;
    return out;
}

//================================================================================================
//
//================================================================================================

ngraph::tensor_stride::tensor_stride()
    : m_tree{}
    , m_element_type{element_type_float}
{
}

ngraph::tensor_stride::tensor_stride(const tensor_size& s)
    : m_tree{}
    , m_element_type{s.m_element_type}
{
    m_tree = s.m_tree;
}

ngraph::tensor_stride::tensor_stride(const std::vector<size_t>& list, const ElementType& et)
    : m_tree{}
    , m_element_type{et}
{
    m_tree = list;
}

ngraph::tensor_stride ngraph::tensor_stride::reduce_strides() const
{
    vector<size_t> tmp;
    if (m_tree.is_list())
    {
        for (auto s : m_tree.get_list())
        {
            tmp.push_back(s.reduce([](size_t a, size_t b) { return min(a, b); }));
        }
    }
    else
    {
        tmp.push_back(m_tree.get_value());
    }
    return tensor_stride(tmp, m_element_type);
}

ngraph::tensor_stride ngraph::tensor_stride::full_strides() const
{
    return *this;
}

ngraph::tensor_stride ngraph::tensor_stride::strides() const
{
    return reduce_strides();
}

std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::tensor_stride& s)
{
    out << s.m_tree;
    return out;
}