// ---------------------------------------------------------------------------- // Copyright 2017 Nervana Systems Inc. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // ---------------------------------------------------------------------------- #pragma once #include <Eigen/Dense> #include <vector> #include "shape.hpp" #include "runtime/tensor_view.hpp" #include "descriptor/tensor_view.hpp" namespace ngraph { namespace runtime { namespace eigen { std::shared_ptr<ngraph::runtime::PrimaryTensorView> make_tensor_view(std::shared_ptr<ngraph::descriptor::TensorView>); template <typename ET> class PrimaryTensorView : public ngraph::runtime::PrimaryTensorView { public: // Standard definitions from vector using value_type = typename ET::type; using storage_type = std::vector<value_type>; using size_type = typename storage_type::size_type; using difference_type = typename storage_type::difference_type; using reference = typename storage_type::reference; using const_reference = typename storage_type::const_reference; using pointer = typename storage_type::pointer; using const_pointer = typename storage_type::const_pointer; using iterator = typename storage_type::iterator; using const_iterator = typename storage_type::const_iterator; using reverse_iterator = typename storage_type::reverse_iterator; using const_reverse_iterator = typename storage_type::const_reverse_iterator; // Mapping vector to eigen using eigen_type = Eigen::Array<value_type, Eigen::Dynamic, 1>; using eigen_map = Eigen::Map<eigen_type>; PrimaryTensorView(const ngraph::Shape& shape) : m_shape(shape) , m_size(ngraph::shape_size(shape)) , m_strides(ngraph::row_major_strides(m_shape)) , m_vector(m_size, 0) , m_map(&m_vector[0], m_size, 1) { } template <typename T> PrimaryTensorView& operator=(const T& value) { m_vector = value; return *this; } // For getting the data out const storage_type& get_vector() { return m_vector; } eigen_map& get_map() { return m_map; } const eigen_map& get_map() const { return m_map; } const Shape& get_shape() const { return m_shape; } protected: ngraph::Shape m_shape; size_t m_size; ngraph::Strides m_strides; storage_type m_vector; eigen_map m_map; }; template <typename ET> void add(const PrimaryTensorView<ET>& arg0, const PrimaryTensorView<ET>& arg1, PrimaryTensorView<ET>& out) { out.get_map() = arg0.get_map() + arg1.get_map(); } template <typename ET> void multiply(const PrimaryTensorView<ET>& arg0, const PrimaryTensorView<ET>& arg1, PrimaryTensorView<ET>& out) { out.get_map() = arg0.get_map() * arg1.get_map(); } } } }