//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// Based on the Matrix class in
// The C++ Programming Language
// Fourth edition
// Bjarne Stroustrup
// Addison-Wesley, Boston, 2013.

#pragma once

#include <algorithm>
#include <cstring>
#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include "ngraph/log.hpp"
#include "ngraph/shape.hpp"

Scott Cyphers's avatar
Scott Cyphers committed
34 35
namespace ngraph
{
36
    namespace test
Scott Cyphers's avatar
Scott Cyphers committed
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
    {
        namespace init
        {
            // Primary template: an N-deep nested initializer list of elements is an
            // initializer_list whose entries are (N-1)-deep nested lists. The
            // recursion is terminated by the depth-1 and depth-0 specializations.
            template <typename Element, size_t Depth>
            struct NestedInitializerListWrapper
            {
                using type = std::initializer_list<
                    typename NestedInitializerListWrapper<Element, Depth - 1>::type>;
            };

            // Depth 1: a flat std::initializer_list of elements ends the recursion.
            template <typename Element>
            struct NestedInitializerListWrapper<Element, 1>
            {
                using type = std::initializer_list<Element>;
            };

            // Depth 0 denotes a scalar: the "nested list" is the element itself.
            template <typename Element>
            struct NestedInitializerListWrapper<Element, 0>
            {
                using type = Element;
            };

            // Shorthand for the Depth-deep nested initializer list of Elements, e.g.
            // NestedInitializerList<float, 2> is
            // std::initializer_list<std::initializer_list<float>>.
            template <typename Element, size_t Depth>
            using NestedInitializerList =
                typename NestedInitializerListWrapper<Element, Depth>::type;

            // Fill in a shape from a nested initializer list
            // For a scalar, nothing to do.
            template <typename T, size_t N>
            typename std::enable_if<(N == 0), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
            }

            // Check that the inits match the shape
            template <typename T, size_t N>
            typename std::enable_if<(N == 0), void>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
79 80 81 82
                if (shape.size() != 0)
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
Scott Cyphers's avatar
Scott Cyphers committed
83 84 85 86 87 88 89 90 91 92 93 94 95 96
            }

            // A flat initializer list contributes exactly one dimension: its length.
            template <typename T, size_t N>
            typename std::enable_if<N == 1, void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                const size_t length = inits.size();
                shape.push_back(length);
            }

            template <typename T, size_t N>
            typename std::enable_if<(N == 1)>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
97 98 99 100
                if (shape.at(shape.size() - N) != inits.size())
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
Scott Cyphers's avatar
Scott Cyphers committed
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
            }

            // General case (N > 1): append this level's length, then recurse into the
            // first sub-list to discover the remaining N - 1 dimensions. Only the
            // first sub-list is inspected here; get_shape follows this with
            // check_shape, which verifies that every sibling agrees.
            //
            // An empty level has no first sub-list. Dereferencing begin() of an empty
            // std::initializer_list is undefined behavior. In that case the remaining
            // N - 1 dimensions are recorded as 0, so e.g. an empty 2-deep list yields
            // shape {0, 0} and describes zero elements.
            template <typename T, size_t N>
            typename std::enable_if<(N > 1), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                shape.push_back(inits.size());
                if (inits.size() == 0)
                {
                    for (size_t dim = 1; dim < N; ++dim)
                    {
                        shape.push_back(0);
                    }
                    return;
                }
                fill_shape<T, N - 1>(shape, *inits.begin());
            }

            template <typename T, size_t N>
            typename std::enable_if<(N > 1), void>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
116 117 118 119
                if (shape.at(shape.size() - N) != inits.size())
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
Scott Cyphers's avatar
Scott Cyphers committed
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
                for (auto it : inits)
                {
                    check_shape<T, N - 1>(shape, it);
                }
            }

            // Infer and validate the shape of an N-deep nested initializer list.
            // fill_shape reads one length per level; check_shape then verifies every
            // element against that shape, so a ragged list throws
            // std::invalid_argument.
            template <typename T, size_t N>
            Shape get_shape(const NestedInitializerList<T, N>& inits)
            {
                Shape result;
                fill_shape<T, N>(result, inits);
                check_shape<T, N>(result, inits);
                return result;
            }

            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N == 1), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
140 141 142 143
                if (inits.size() != shape.at(shape.size() - N))
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
Scott Cyphers's avatar
Scott Cyphers committed
144 145 146 147 148 149 150 151 152 153 154
                for (auto it1 : inits)
                {
                    *(it++) = it1;
                }
                return it;
            }

            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N > 1), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
155 156 157 158
                if (inits.size() != shape.at(shape.size() - N))
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
Scott Cyphers's avatar
Scott Cyphers committed
159 160 161 162 163 164 165 166 167 168 169
                for (auto it1 : inits)
                {
                    it = flatten<IT, T, N - 1>(it, shape, it1);
                }
                return it;
            }

            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N == 0), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, 0>& init)
            {
170 171 172 173
                if (shape.size() != 0)
                {
                    throw std::invalid_argument("Initializers do not match shape");
                }
Scott Cyphers's avatar
Scott Cyphers committed
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
                *(it++) = init;
                return it;
            }
        }

        template <typename T>
        class NDArrayBase
        {
            using vtype = std::vector<T>;

        public:
            using type = T;
            using iterator = typename vtype::iterator;
            using const_iterator = typename vtype::const_iterator;

            NDArrayBase(const Shape& shape)
                : m_shape(shape)
                , m_elements(shape_size(m_shape))
            {
            }

            const Shape& get_shape() const { return m_shape; }
            const_iterator begin() const { return m_elements.begin(); }
            const_iterator end() const { return m_elements.end(); }
            vtype get_vector() { return m_elements; }
            const vtype get_vector() const { return m_elements; }
200 201
            operator const vtype() const { return m_elements; }
            operator vtype() { return m_elements; }
Robert Kimball's avatar
Robert Kimball committed
202 203
            void* data() { return m_elements.data(); }
            const void* data() const { return m_elements.data(); }
Scott Cyphers's avatar
Scott Cyphers committed
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
            bool operator==(const NDArrayBase<T>& other) const
            {
                return m_shape == other.m_shape && m_elements == other.m_elements;
            }

        protected:
            Shape m_shape;
            vtype m_elements;
        };

        /// An N-dimensional array of elements of type T, built from an N-deep
        /// nested initializer list, e.g. NDArray<float, 2>({{1, 2}, {3, 4}}).
        /// The shape is inferred and validated from the nested lists (throwing
        /// std::invalid_argument for ragged input). The values are stored
        /// flattened in row-major order.
        template <typename T, size_t N>
        class NDArray : public NDArrayBase<T>
        {
            using base_type = NDArrayBase<T>;

        public:
            NDArray(const init::NestedInitializerList<T, N>& initial_value)
                : base_type(init::get_shape<T, N>(initial_value))
            {
                using out_iterator = typename std::vector<T>::iterator;
                init::flatten<out_iterator, T, N>(
                    this->m_elements.begin(), this->m_shape, initial_value);
            }
        };
    }
}