// ndarray.hpp
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// Based on the Matrix class in
// The C++ Programming Language
// Fourth edition
// Bjarne Stroustrup
// Addison-Wesley, Boston, 2013.

#pragma once

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include "ngraph/log.hpp"
#include "ngraph/shape.hpp"

namespace ngraph
{
    namespace test
    {
        namespace init
        {
            // Maps an element type and a nesting depth to the corresponding nested
            // std::initializer_list type. This primary template is the recursive step:
            // a list of depth Depth is an initializer_list whose elements are lists of
            // depth Depth - 1.
            template <typename ElementT, size_t Depth>
            struct NestedInitializerListWrapper
            {
                using type = std::initializer_list<
                    typename NestedInitializerListWrapper<ElementT, Depth - 1>::type>;
            };

            // Depth one: an ordinary std::initializer_list of elements.
            template <typename ElementT>
            struct NestedInitializerListWrapper<ElementT, 1>
            {
                using type = std::initializer_list<ElementT>;
            };

            // Depth zero (scalar): no list at all, the type is the element type itself.
            template <typename ElementT>
            struct NestedInitializerListWrapper<ElementT, 0>
            {
                using type = ElementT;
            };

            // Shorthand for the Depth-level nested initializer list of ElementT values,
            // so callers need not spell out `typename ...Wrapper<...>::type`.
            template <typename ElementT, size_t Depth>
            using NestedInitializerList =
                typename NestedInitializerListWrapper<ElementT, Depth>::type;

            // Fill in a shape from a nested initializer list
            // For a scalar, nothing to do.
            template <typename T, size_t N>
            typename std::enable_if<(N == 0), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
            }

            // check_shape verifies that a nested initializer list agrees with a shape.
            // Scalar overload (N == 0): the only requirement is that the shape has no
            // dimensions. The check is an assert, so it is compiled out when NDEBUG is
            // defined. The initializer itself is not inspected.
            template <typename T, size_t N>
            typename std::enable_if<(N == 0)>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& /* inits */)
            {
                assert(shape.size() == 0);
            }

            // Innermost list overload (N == 1): the list length is the final dimension,
            // so it is appended to the shape and recursion stops.
            template <typename T, size_t N>
            typename std::enable_if<(N == 1), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                const auto length = inits.size();
                shape.push_back(length);
            }

            // Innermost list overload (N == 1): the list length must equal the last
            // dimension of the shape (index shape.size() - N).
            //
            // The mismatch is reported with an exception rather than an assert so the
            // check also runs when NDEBUG is defined; otherwise a ragged initializer
            // could go on to write more elements than the shape allocates.
            //
            // Throws std::invalid_argument if the length differs from the dimension.
            // If the shape has fewer than N dimensions the index wraps around and
            // shape.at() is expected to reject it (TODO confirm Shape::at is
            // bounds-checked like std::vector::at).
            template <typename T, size_t N>
            typename std::enable_if<(N == 1)>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                if (shape.at(shape.size() - N) != inits.size())
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
            }

            // General case (N > 1): append this level's length, then derive the
            // remaining dimensions from the first sub-list only. Sibling sub-lists are
            // not examined here; check_shape compares them against the result.
            //
            // An empty list has no first sub-list to inspect, so the remaining N - 1
            // dimensions are recorded as 0. This keeps the shape at full rank and
            // avoids dereferencing inits.begin() on an empty initializer_list, which
            // is undefined behavior.
            template <typename T, size_t N>
            typename std::enable_if<(N > 1), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                shape.push_back(inits.size());
                if (inits.size() == 0)
                {
                    for (size_t remaining = N - 1; remaining > 0; --remaining)
                    {
                        shape.push_back(0);
                    }
                    return;
                }
                fill_shape<T, N - 1>(shape, *inits.begin());
            }

            // General case (N > 1): this level's length must equal dimension
            // shape.size() - N, and every sub-list must in turn agree with the
            // remaining dimensions.
            //
            // The mismatch is reported with an exception rather than an assert so the
            // check also runs when NDEBUG is defined; otherwise a ragged initializer
            // could go on to write more elements than the shape allocates.
            //
            // Throws std::invalid_argument if this level's length differs from its
            // dimension; any exception raised while checking a sub-list propagates.
            template <typename T, size_t N>
            typename std::enable_if<(N > 1), void>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                if (shape.at(shape.size() - N) != inits.size())
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
                // Iterate by reference: nothing is modified and no copies are needed.
                for (const auto& sub_list : inits)
                {
                    check_shape<T, N - 1>(shape, sub_list);
                }
            }

            // Computes the shape described by an N-level nested initializer list.
            // The dimensions are first derived by fill_shape and the complete
            // initializer is then validated against them by check_shape.
            // Returns the derived shape by value.
            template <typename T, size_t N>
            Shape get_shape(const NestedInitializerList<T, N>& inits)
            {
                Shape deduced;
                fill_shape<T, N>(deduced, inits);
                check_shape<T, N>(deduced, inits);
                return deduced;
            }

            // flatten writes the elements of a nested initializer list, in order,
            // through an output iterator.
            // Innermost list overload (N == 1): copies the elements of inits to it.
            //
            // it     Destination iterator; must be able to accept inits.size() writes.
            // shape  Shape the initializer is expected to match; only consulted by the
            //        debug-time assert.
            // inits  The elements to copy.
            //
            // Returns the iterator one past the last element written.
            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N == 1), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(inits.size() == shape.at(shape.size() - N));
                // std::copy assigns straight from the list's storage, avoiding the
                // per-element temporary a by-value range-for would create.
                return std::copy(inits.begin(), inits.end(), it);
            }

            // General case (N > 1): flattens each sub-list in turn, threading the
            // output iterator through the recursive calls so that each sub-list is
            // written directly after the previous one.
            // Returns the iterator one past the last element written.
            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N > 1), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(inits.size() == shape.at(shape.size() - N));
                for (auto sub_list = inits.begin(); sub_list != inits.end(); ++sub_list)
                {
                    it = flatten<IT, T, N - 1>(it, shape, *sub_list);
                }
                return it;
            }

            // Scalar overload (N == 0): writes the single value through it and returns
            // the advanced iterator. N cannot be deduced from the arguments and has to
            // be supplied explicitly by the caller.
            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N == 0), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, 0>& init)
            {
                assert(shape.size() == 0);
                *it = init;
                ++it;
                return it;
            }
        }

        template <typename T>
        class NDArrayBase
        {
            using vtype = std::vector<T>;

        public:
            using type = T;
            using iterator = typename vtype::iterator;
            using const_iterator = typename vtype::const_iterator;

            NDArrayBase(const Shape& shape)
                : m_shape(shape)
                , m_elements(shape_size(m_shape))
            {
            }

            const Shape& get_shape() const { return m_shape; }
            const_iterator begin() const { return m_elements.begin(); }
            const_iterator end() const { return m_elements.end(); }
            vtype get_vector() { return m_elements; }
            const vtype get_vector() const { return m_elements; }
            operator const vtype() const { return m_elements; }
            operator vtype() { return m_elements; }
            void* data() { return m_elements.data(); }
            const void* data() const { return m_elements.data(); }
            bool operator==(const NDArrayBase<T>& other) const
            {
                return m_shape == other.m_shape && m_elements == other.m_elements;
            }

        protected:
            Shape m_shape;
            vtype m_elements;
        };

        /// An N-dimensional array of elements of type T, built from an N-level nested
        /// initializer list. The shape is derived from the initializer and the values
        /// are stored flattened in the order they are written.
        template <typename T, size_t N>
        class NDArray : public NDArrayBase<T>
        {
        public:
            /// The base class is constructed with the shape computed by
            /// init::get_shape; the initializer's values are then copied into the
            /// element storage by init::flatten.
            NDArray(const init::NestedInitializerList<T, N>& initial_value)
                : NDArrayBase<T>(init::get_shape<T, N>(initial_value))
            {
                using element_iterator = typename NDArrayBase<T>::iterator;
                // `this->` is needed because the members live in a dependent base.
                init::flatten<element_iterator, T, N>(
                    this->m_elements.begin(), this->m_shape, initial_value);
            }
        };
    }
}