#pragma once #include <algorithm> #include <functional> #include <initializer_list> #include <iostream> #include <vector> #include "util.hpp" namespace ngraph { template <typename T> class tree; using scalar_tree = ngraph::tree<size_t>; } //================================================================================================ // //================================================================================================ template <typename T> class ngraph::tree { public: tree(T s) : m_list{} , m_value{s} , m_is_list{false} { } tree(const std::initializer_list<tree<T>>& list) : m_list{} , m_value{0} , m_is_list{true} { m_list = list; } tree(const std::vector<T>& list) : m_list{} , m_value{0} , m_is_list{true} { for (auto s : list) { m_list.push_back(tree(s)); } } bool is_list() const { return m_is_list; } T get_value() const { return m_value; } const std::vector<tree>& get_list() const { return m_list; } static void traverse_tree(tree& s, std::function<void(T*)> func) { if (s.is_list()) { for (tree& s1 : s.m_list) { traverse_tree(s1, func); } } else { func(&(s.m_value)); } } friend std::ostream& operator<<(std::ostream& out, const tree& s) { if (s.is_list()) { out << "(" << join(s.get_list(), ", ") << ")"; } else { out << s.get_value(); } return out; } T reduce(const std::function<T(T, T)>& func) const { size_t rc; if (is_list()) { switch (m_list.size()) { case 0: rc = 0; break; case 1: rc = m_list[0].reduce(func); break; default: rc = m_list[0].reduce(func); for (int i = 1; i < m_list.size(); i++) { rc = func(rc, m_list[i].reduce(func)); } break; } } else { rc = m_value; } return rc; } private: std::vector<tree> m_list; T m_value; bool m_is_list; };