Commit 1b3f0e07 authored by Scott Cyphers

Tensor initializers

parent 77b216aa
@@ -13,6 +13,7 @@
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=return-type")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=inconsistent-missing-override")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pedantic-errors")
# whitelist errors here
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Weverything")
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
// Based on the Matrix class in
// The C++ Programming Language
// Fourth edition
// Bjarne Stroustrup
// Addison-Wesley, Boston, 2013.
#pragma once

#include <algorithm>
#include <cassert>
#include <cstring>
#include <memory>
#include <type_traits>
#include <vector>

// Shape and shape_size, used below, are declared in ngraph/shape.hpp
#include "ngraph/shape.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace init
        {
            // Recursively define types for N-deep initializer lists
            template <typename T, size_t N>
            struct NestedInitializerListWrapper
            {
                using type =
                    std::initializer_list<typename NestedInitializerListWrapper<T, N - 1>::type>;
            };

            // 1-deep is a plain initializer_list
            template <typename T>
            struct NestedInitializerListWrapper<T, 1>
            {
                using type = std::initializer_list<T>;
            };

            // Scalar case is just the element type
            template <typename T>
            struct NestedInitializerListWrapper<T, 0>
            {
                using type = T;
            };

            // Convenience type name for N-deep initializer lists of Ts
            template <typename T, size_t N>
            using NestedInitializerList = typename NestedInitializerListWrapper<T, N>::type;
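
            // For example, NestedInitializerList<float, 2> unwraps to
            // std::initializer_list<std::initializer_list<float>>, so a value
            // of that type can be written {{1.0f, 2.0f}, {3.0f, 4.0f}}.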
            // Fill in a shape from a nested initializer list
            // For a scalar, nothing to do.
            template <typename T, size_t N>
            typename std::enable_if<(N == 0), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
            }

            // Check that the inits match the shape
            template <typename T, size_t N>
            typename std::enable_if<(N == 0), void>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(shape.size() == 0);
            }

            // For a plain initializer list, the shape is the length of the list.
            template <typename T, size_t N>
            typename std::enable_if<(N == 1), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                shape.push_back(inits.size());
            }

            template <typename T, size_t N>
            typename std::enable_if<(N == 1), void>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(shape.at(shape.size() - N) == inits.size());
            }
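
            // Note the indexing used by the checks: at recursion depth N there
            // are N dimensions left to verify, so the current dimension is at
            // index shape.size() - N. For Shape{2, 3} and N == 1, that is
            // shape.at(1), the row length 3.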
            // In the general case, we append our level's length and recurse.
            template <typename T, size_t N>
            typename std::enable_if<(N > 1), void>::type
                fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                shape.push_back(inits.size());
                fill_shape<T, N - 1>(shape, *inits.begin());
            }

            template <typename T, size_t N>
            typename std::enable_if<(N > 1), void>::type
                check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(shape.at(shape.size() - N) == inits.size());
                for (auto it : inits)
                {
                    check_shape<T, N - 1>(shape, it);
                }
            }
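
            // The recursive check is what rejects ragged input: given
            // {{1, 2, 3}, {4, 5}}, fill_shape records the first row's length 3
            // (it only inspects *inits.begin()), and check_shape then fails its
            // assert on the second row of length 2.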
            // Get the shape of inits.
            template <typename T, size_t N>
            Shape get_shape(const NestedInitializerList<T, N>& inits)
            {
                Shape shape;
                fill_shape<T, N>(shape, inits);
                check_shape<T, N>(shape, inits);
                return shape;
            }
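
            // For example, get_shape<float, 2>({{1, 2, 3}, {4, 5, 6}}) returns
            // Shape{2, 3}: the outer list length and the first row's length.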
            // Copy a nested initializer list into a flat row-major sequence
            // starting at iterator it; returns the iterator one past the last
            // element written.
            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N == 1), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(inits.size() == shape.at(shape.size() - N));
                for (auto it1 : inits)
                {
                    *(it++) = it1;
                }
                return it;
            }

            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N > 1), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
            {
                assert(inits.size() == shape.at(shape.size() - N));
                for (auto it1 : inits)
                {
                    it = flatten<IT, T, N - 1>(it, shape, it1);
                }
                return it;
            }

            // Scalar case: write the single element.
            template <typename IT, typename T, size_t N>
            typename std::enable_if<(N == 0), IT>::type
                flatten(IT it, const Shape& shape, const NestedInitializerList<T, 0>& init)
            {
                assert(shape.size() == 0);
                *(it++) = init;
                return it;
            }
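
            // For example, flattening {{1, 2}, {3, 4}} writes 1, 2, 3, 4:
            // each row is emitted in full before the next row begins.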
        }

        /// Holds a shape and the elements of an N-dimensional array;
        /// NDArray below adds construction from a nested initializer list.
        template <typename T>
        class NDArrayBase
        {
            using vtype = std::vector<T>;

        public:
            using type = T;
            using iterator = typename vtype::iterator;
            using const_iterator = typename vtype::const_iterator;

            NDArrayBase(const Shape& shape)
                : m_shape(shape)
                , m_elements(shape_size(m_shape))
            {
            }

            const Shape& get_shape() const { return m_shape; }
            const_iterator begin() const { return m_elements.begin(); }
            const_iterator end() const { return m_elements.end(); }

            // Access the flat element vector.
            vtype& get_vector() { return m_elements; }
            const vtype& get_vector() const { return m_elements; }

            bool operator==(const NDArrayBase<T>& other) const
            {
                return m_shape == other.m_shape && m_elements == other.m_elements;
            }

        protected:
            Shape m_shape;
            vtype m_elements;
        };

        /// An N dimensional array of elements of type T
        template <typename T, size_t N>
        class NDArray : public NDArrayBase<T>
        {
        public:
            NDArray(const init::NestedInitializerList<T, N>& initial_value)
                : NDArrayBase<T>(init::get_shape<T, N>(initial_value))
            {
                init::flatten<typename std::vector<T>::iterator, T, N>(
                    NDArrayBase<T>::m_elements.begin(), NDArrayBase<T>::m_shape, initial_value);
            }
        };
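
        // Illustrative use, mirroring the updated test below:
        //   runtime::NDArray<float, 2> a({{1, 2}, {3, 4}});
        //   a.get_shape();  // Shape{2, 2}
        //   a.get_vector(); // elements {1, 2, 3, 4} in row-major order
        // A ParameterizedTensorView can also be assigned from and compared
        // with an NDArray; see the parameterized_tensor_view.hpp changes below.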
    }
}
@@ -14,12 +14,16 @@
 #pragma once
 
 #include <algorithm>
 #include <cassert>
 #include <cstring>
 #include <memory>
 #include <type_traits>
 #include <vector>
 
+#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
+#include "ngraph/descriptor/primary_tensor_view.hpp"
+#include "ngraph/runtime/ndarray.hpp"
 #include "ngraph/runtime/tensor_view.hpp"
+#include "ngraph/shape.hpp"
 #include "ngraph/types/element_type.hpp"
@@ -60,8 +64,17 @@ namespace ngraph
                 return *this;
             }
 
+            template <typename T, size_t N>
+            ParameterizedTensorView<ET>& operator=(const NDArray<T, N>& ndarray)
+            {
+                assert(ndarray.get_shape() == get_shape());
+                std::copy(ndarray.begin(), ndarray.end(), m_vector.begin());
+                return *this;
+            }
+
             // For getting the data out
             storage_type& get_vector() { return m_vector; }
             const storage_type& get_vector() const { return m_vector; }
+
             virtual void write(const void* p, size_t tensor_offset, size_t n) override
             {
                 size_t elt_offset = tensor_offset / sizeof(typename ET::type);
@@ -104,6 +117,11 @@ namespace ngraph
                 std::memcpy(p, &m_vector[elt_offset], n);
             }
 
+            bool operator==(const NDArrayBase<typename ET::type>& ndarray) const
+            {
+                return get_shape() == ndarray.get_shape() && get_vector() == ndarray.get_vector();
+            }
+
         protected:
             storage_type m_vector;
         };
@@ -37,21 +37,21 @@ TEST(execute, abc)
     // Create some tensors for input/output
     auto a = backend->make_parameterized_tensor_view<element::Float32>(shape);
-    *a = vector<float>{1, 2, 3, 4};
+    *a = runtime::NDArray<float, 2>({{1, 2}, {3, 4}});
     auto b = backend->make_parameterized_tensor_view<element::Float32>(shape);
-    *b = vector<float>{5, 6, 7, 8};
+    *b = runtime::NDArray<float, 2>({{5, 6}, {7, 8}});
     auto c = backend->make_parameterized_tensor_view<element::Float32>(shape);
-    *c = vector<float>{9, 10, 11, 12};
+    *c = runtime::NDArray<float, 2>({{9, 10}, {11, 12}});
     auto result = backend->make_parameterized_tensor_view<element::Float32>(shape);
 
     (*cf)({a, b, c}, {result});
-    ASSERT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector());
+    ASSERT_EQ(*result, (runtime::NDArray<float, 2>({{54, 80}, {110, 144}})));
 
     (*cf)({b, a, c}, {result});
-    ASSERT_EQ((vector<float>{54, 80, 110, 144}), result->get_vector());
+    ASSERT_EQ(*result, (runtime::NDArray<float, 2>({{54, 80}, {110, 144}})));
 
     (*cf)({a, c, b}, {result});
-    ASSERT_EQ((vector<float>{50, 72, 98, 128}), result->get_vector());
+    ASSERT_EQ(*result, (runtime::NDArray<float, 2>({{50, 72}, {98, 128}})));
 }
 
 TEST(execute, abc_int64)