Commit 1bb52e61 authored by Jai Menon's avatar Jai Menon Committed by GitHub

Merge branch 'master' into jmenon/codegen

parents 9a19500f 6f38615d
...@@ -6,7 +6,7 @@ RUN apt-get update && apt-get install -y \ ...@@ -6,7 +6,7 @@ RUN apt-get update && apt-get install -y \
build-essential cmake \ build-essential cmake \
clang-3.9 clang-format-3.9 \ clang-3.9 clang-format-3.9 \
git \ git \
wget patch diffutils wget patch diffutils zlib1g-dev libtinfo-dev
RUN apt-get clean autoclean && \ RUN apt-get clean autoclean && \
apt-get autoremove -y apt-get autoremove -y
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
set (SRC set (SRC
descriptor/input.cpp descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
descriptor/output.cpp descriptor/output.cpp
descriptor/primary_tensor_view.cpp descriptor/primary_tensor_view.cpp
descriptor/tensor.cpp descriptor/tensor.cpp
descriptor/tensor_view.cpp
descriptor/tuple.cpp descriptor/tuple.cpp
function.cpp function.cpp
log.cpp log.cpp
...@@ -50,8 +52,13 @@ set (SRC ...@@ -50,8 +52,13 @@ set (SRC
pass/propagate_types.cpp pass/propagate_types.cpp
pass/topological_sort.cpp pass/topological_sort.cpp
pass/visualize_tree.cpp pass/visualize_tree.cpp
runtime/call_frame.cpp runtime/backend.cpp
runtime/external_function.cpp runtime/manager.cpp
runtime/ngvm/call_frame.cpp
runtime/ngvm/external_function.cpp
runtime/ngvm/ngvm_backend.cpp
runtime/ngvm/ngvm_manager.cpp
runtime/tensor_view.cpp
runtime/tuple.cpp runtime/tuple.cpp
runtime/utils.cpp runtime/utils.cpp
shape.cpp shape.cpp
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#pragma once #pragma once
#include <memory>
namespace ngraph namespace ngraph
{ {
namespace descriptor namespace descriptor
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace descriptor;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/except.hpp" #include "ngraph/except.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::descriptor::layout; using namespace ngraph::descriptor::layout;
using ngraph::Shape; using ngraph::Shape;
......
...@@ -14,16 +14,17 @@ ...@@ -14,16 +14,17 @@
#pragma once #pragma once
#include <cstddef>
#include <vector> #include <vector>
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace descriptor namespace descriptor
{ {
class TensorView;
namespace layout namespace layout
{ {
/// @brief The standard strided layout, used for row-major and column-major, their permutations and slices. /// @brief The standard strided layout, used for row-major and column-major, their permutations and slices.
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::descriptor::layout;
TensorViewLayout::TensorViewLayout(const ngraph::descriptor::TensorView& tensor_view)
: m_tensor_view_type(tensor_view.get_tensor_view_type())
{
}
const ngraph::element::Type& TensorViewLayout::get_element_type() const
{
return m_tensor_view_type->get_element_type();
}
const ngraph::Shape& TensorViewLayout::get_shape() const
{
return m_tensor_view_type->get_shape();
}
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -22,6 +23,11 @@ ...@@ -22,6 +23,11 @@
namespace ngraph namespace ngraph
{ {
namespace element
{
class Type;
}
namespace descriptor namespace descriptor
{ {
class TensorView; class TensorView;
...@@ -34,10 +40,7 @@ namespace ngraph ...@@ -34,10 +40,7 @@ namespace ngraph
class TensorViewLayout class TensorViewLayout
{ {
protected: protected:
TensorViewLayout(const ngraph::descriptor::TensorView& tensor_view) TensorViewLayout(const ngraph::descriptor::TensorView& tensor_view);
: m_tensor_view_type(tensor_view.get_tensor_view_type())
{
}
public: public:
virtual ~TensorViewLayout() {} virtual ~TensorViewLayout() {}
...@@ -51,11 +54,8 @@ namespace ngraph ...@@ -51,11 +54,8 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t. /// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0; virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const element::Type& get_element_type() const const element::Type& get_element_type() const;
{ const Shape& get_shape() const;
return m_tensor_view_type->get_element_type();
}
const Shape& get_shape() const { return m_tensor_view_type->get_shape(); }
/// Where this view is located in the buffer. /// Where this view is located in the buffer.
const BufferPos& get_buffer_pos() const { return m_buffer_pos; } const BufferPos& get_buffer_pos() const { return m_buffer_pos; }
BufferPos& get_buffer_pos() { return m_buffer_pos; } BufferPos& get_buffer_pos() { return m_buffer_pos; }
......
...@@ -12,11 +12,12 @@ ...@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/input.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace descriptor; using namespace ngraph::descriptor;
Output::Output(const std::shared_ptr<Node>& node, Output::Output(const std::shared_ptr<Node>& node,
size_t index, size_t index,
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <set> #include <set>
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/node.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -12,14 +12,12 @@ ...@@ -12,14 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#pragma once
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor;
namespace ngraph std::shared_ptr<const ngraph::ValueType> TensorView::get_value_type() const
{ {
namespace runtime return m_tensor_view_type;
{ }
using TensorViewIndex = unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t>;
}
}
\ No newline at end of file
...@@ -16,15 +16,13 @@ ...@@ -16,15 +16,13 @@
#include <memory> #include <memory>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/value.hpp" #include "ngraph/descriptor/value.hpp"
#include "ngraph/log.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
{ {
class Node; class Node;
class TensorViewType;
namespace descriptor namespace descriptor
{ {
...@@ -34,6 +32,9 @@ namespace ngraph ...@@ -34,6 +32,9 @@ namespace ngraph
class TensorViewLayout; class TensorViewLayout;
} }
class Tensor;
class TensorView;
/// @brief Compile-time descriptor of a first-class value that is a view of a tensor. /// @brief Compile-time descriptor of a first-class value that is a view of a tensor.
class TensorView : public Value class TensorView : public Value
{ {
...@@ -51,10 +52,7 @@ namespace ngraph ...@@ -51,10 +52,7 @@ namespace ngraph
virtual const Tensor& get_tensor() const = 0; virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0; virtual Tensor& get_tensor() = 0;
virtual std::shared_ptr<const ValueType> get_value_type() const override virtual std::shared_ptr<const ValueType> get_value_type() const override;
{
return m_tensor_view_type;
}
const std::string& get_name() const { return m_name; } const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const std::shared_ptr<const TensorViewType> get_tensor_view_type() const
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
{ {
class ValueType;
namespace descriptor namespace descriptor
{ {
class TensorView; class TensorView;
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp" #include "ngraph/ops/parameter.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/types/type.hpp" #include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
......
...@@ -80,9 +80,12 @@ ...@@ -80,9 +80,12 @@
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/external_function.hpp" #include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/manager.hpp"
#include "ngraph/runtime/ngvm/ngvm_backend.hpp"
#include "ngraph/runtime/ngvm/ngvm_manager.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp" #include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp" #include "ngraph/runtime/tuple.hpp"
......
...@@ -22,17 +22,13 @@ ...@@ -22,17 +22,13 @@
#include <iostream> #include <iostream>
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/types/type.hpp" #include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
{ {
namespace descriptor
{
class Input;
class Output;
class Tensor;
}
/// Nodes are the backbone of the graph of Value dataflow. Every node has /// Nodes are the backbone of the graph of Value dataflow. Every node has
/// zero or more nodes as arguments and one value, which is either a tensor /// zero or more nodes as arguments and one value, which is either a tensor
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/ops/op.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <memory> #include <memory>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ops/op.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/ops/op.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/ops/broadcast.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/concatenate.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/ops/constant.hpp"
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include "ngraph/node.hpp"
#include "ngraph/runtime/utils.hpp" #include "ngraph/runtime/utils.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/convert.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/dot.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/ops/function_call.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/function.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "ngraph/pass/manager.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
namespace ngraph namespace ngraph
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <algorithm> #include <algorithm>
#include <sstream> #include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/op.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <sstream> #include <sstream>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/parameter.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "ngraph/ngraph.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/function.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <memory> #include <memory>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ops/select.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/tuple.hpp"
using namespace std; using namespace std;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
......
...@@ -12,11 +12,8 @@ ...@@ -12,11 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <memory> #include "ngraph/ops/op.hpp"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph; using namespace ngraph;
using namespace ngraph::op; using namespace ngraph::op;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <memory> #include <memory>
#include "ngraph/ngraph.hpp" #include "ngraph/ops/op.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
......
...@@ -22,7 +22,6 @@ namespace ngraph ...@@ -22,7 +22,6 @@ namespace ngraph
{ {
class AssignTensors; class AssignTensors;
} }
class Node;
} }
class ngraph::pass::AssignTensors : public CallGraphPass class ngraph::pass::AssignTensors : public CallGraphPass
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <fstream> #include <fstream>
#include "ngraph/ngraph.hpp" #include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/pass/dump_sorted.hpp" #include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -24,7 +24,6 @@ namespace ngraph ...@@ -24,7 +24,6 @@ namespace ngraph
{ {
class DumpSorted; class DumpSorted;
} }
class Node;
} }
class ngraph::pass::DumpSorted : public ModulePass class ngraph::pass::DumpSorted : public ModulePass
......
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/node.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
......
...@@ -23,7 +23,6 @@ namespace ngraph ...@@ -23,7 +23,6 @@ namespace ngraph
{ {
class Liveness; class Liveness;
} }
class Node;
} }
class ngraph::pass::Liveness : public CallGraphPass class ngraph::pass::Liveness : public CallGraphPass
......
...@@ -28,9 +28,6 @@ namespace ngraph ...@@ -28,9 +28,6 @@ namespace ngraph
class Manager; class Manager;
class ManagerState; class ManagerState;
} }
class Node;
class Function;
} }
class ngraph::pass::Manager class ngraph::pass::Manager
......
...@@ -17,15 +17,14 @@ ...@@ -17,15 +17,14 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/function.hpp"
namespace ngraph namespace ngraph
{ {
namespace pass namespace pass
{ {
class ManagerState; class ManagerState;
} }
class Node;
class Function;
} }
class ngraph::pass::ManagerState class ngraph::pass::ManagerState
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
......
...@@ -28,7 +28,6 @@ namespace ngraph ...@@ -28,7 +28,6 @@ namespace ngraph
class MemoryNode; class MemoryNode;
class MemoryManager; class MemoryManager;
} }
class Node;
} }
class ngraph::pass::MemoryLayout : public CallGraphPass class ngraph::pass::MemoryLayout : public CallGraphPass
......
...@@ -26,7 +26,6 @@ namespace ngraph ...@@ -26,7 +26,6 @@ namespace ngraph
{ {
class MemoryVisualize; class MemoryVisualize;
} }
class Node;
} }
class ngraph::pass::MemoryVisualize : public ModulePass class ngraph::pass::MemoryVisualize : public ModulePass
......
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ngraph/function.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/pass/manager_state.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -30,9 +32,7 @@ namespace ngraph ...@@ -30,9 +32,7 @@ namespace ngraph
class NodePass; class NodePass;
class CallGraphPass; class CallGraphPass;
class Manager; class Manager;
class ManagerState;
} }
class Function;
} }
class ngraph::pass::PassBase class ngraph::pass::PassBase
......
...@@ -22,7 +22,6 @@ namespace ngraph ...@@ -22,7 +22,6 @@ namespace ngraph
{ {
class PropagateTypes; class PropagateTypes;
} }
class Node;
} }
class ngraph::pass::PropagateTypes : public CallGraphPass class ngraph::pass::PropagateTypes : public CallGraphPass
......
...@@ -25,7 +25,6 @@ namespace ngraph ...@@ -25,7 +25,6 @@ namespace ngraph
{ {
class TopologicalSort; class TopologicalSort;
} }
class Node;
} }
class ngraph::pass::TopologicalSort : public FunctionPass class ngraph::pass::TopologicalSort : public FunctionPass
......
...@@ -26,7 +26,6 @@ namespace ngraph ...@@ -26,7 +26,6 @@ namespace ngraph
{ {
class VisualizeTree; class VisualizeTree;
} }
class Node;
} }
class ngraph::pass::VisualizeTree : public ModulePass class ngraph::pass::VisualizeTree : public ModulePass
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::runtime;
std::shared_ptr<TensorView>
Backend::make_primary_tensor_view(const ngraph::element::Type& element_type, const Shape& shape)
{
return element_type.make_primary_tensor_view(shape);
}
std::shared_ptr<ngraph::runtime::Tuple>
Backend::make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements)
{
return std::make_shared<ngraph::runtime::Tuple>(elements);
}
...@@ -14,46 +14,57 @@ ...@@ -14,46 +14,57 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include <memory>
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/common.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace element
{
class Type;
}
namespace runtime namespace runtime
{ {
namespace eigen class ExternalFunction;
class CallFrame;
class TensorView;
class Tuple;
class Value;
template <typename ET>
class ParameterizedTensorView;
/// @brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
class Backend
{ {
public:
virtual ~Backend() {}
/// @brief Make a call frame that can support one concurrent call of an external function.
///
/// If more than one concurrent execution is needed, each execution will require its own call frame.
virtual std::shared_ptr<ngraph::runtime::CallFrame>
make_call_frame(const std::shared_ptr<ExternalFunction>& external_function) = 0;
/// @brief Return a handle for a tensor on the backend device.
virtual std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape);
template <typename ET> template <typename ET>
class ScalarTensorProductInstruction : public Instruction std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_parameterized_tensor_view(const Shape& shape)
{ {
public: return std::dynamic_pointer_cast<ngraph::runtime::ParameterizedTensorView<ET>>(
ScalarTensorProductInstruction(const TensorViewInfo& arg0, make_primary_tensor_view(ET::element_type(), shape));
const TensorViewInfo& arg1, }
const TensorViewInfo& out)
: m_arg0(arg0) /// @brief Construct a tuple handle from a sequence of values.
, m_arg1(arg1) virtual std::shared_ptr<ngraph::runtime::Tuple>
, m_out(out) make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
{ };
}
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit hacky: regardless of the tensor rank we
// pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them.
EigenVector<ET>(call_frame, m_out) =
call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] *
EigenVector<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
} }
} }
...@@ -25,51 +25,23 @@ namespace ngraph ...@@ -25,51 +25,23 @@ namespace ngraph
namespace runtime namespace runtime
{ {
class PrimaryTensorView; class PrimaryTensorView;
class Instruction; class Value;
// A VM for executing lightly-compiled graph functions. // A VM for executing lightly-compiled graph functions.
class CallFrame class CallFrame
{ {
public: public:
CallFrame( virtual ~CallFrame() {}
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
/// @brief Invoke the function with values matching the signature of the function. /// @brief Invoke the function with values matching the signature of the function.
/// ///
/// Tuples will be expanded into their tensor views to build the call frame. /// Tuples will be expanded into their tensor views to build the call frame.
void operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs, virtual void
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs); operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs) = 0;
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views. /// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs); virtual void tensor_call(const TensorViewPtrs& inputs,
const TensorViewPtrs& outputs) = 0;
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{
return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
}
template <typename ET>
typename ET::type* get_tensor_view_data(size_t i)
{
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
}
protected:
size_t m_n_inputs;
size_t m_n_outputs;
TensorViewPtrs m_tensor_views;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
size_t m_next_pc;
bool m_return;
}; };
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class DivideInstruction : public Instruction
{
public:
DivideInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0) / EigenArray1d<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class DotInstruction : public Instruction
{
public:
DotInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out)
<< EigenVector<ET>(call_frame, m_arg0)
.dot(EigenVector<ET>(call_frame, m_arg1));
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class MaximumInstruction : public Instruction
{
public:
MaximumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0)
.max(EigenArray1d<ET>(call_frame, m_arg1));
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
class CallFrame;
namespace eigen
{
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET>
using DynamicArray = Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
template <typename ET>
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicMatrix =
Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicVector = Eigen::Matrix<typename ET::type, Eigen::Dynamic, 1>;
template <typename ET>
using EigenVectorBase = Eigen::Map<DynamicVector<ET>, 0, VectorStrides>;
namespace fmt
{
/// @brief vector format for Eigen wrappers.
class V
{
public:
V(const TensorViewInfo& tensor_view_info)
: l0(tensor_view_info
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size())
{
}
public:
size_t l0;
size_t l1{1};
size_t s0{1};
size_t s1{1};
};
class M
{
M(const Shape& shape, const Strides& strides)
: l0(shape.at(0))
, l1(shape.at(1))
, s0(strides.at(0))
, s1(strides.at(1))
{
}
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const TensorViewInfo& tensor_view_info)
: M(tensor_view_info
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>())
{
}
public:
size_t l0;
size_t l1;
size_t s0;
size_t s1;
};
}
// ET element type
// FMT array format (fmt::V for vector, etc.)
// BASE select array/matrix
template <typename ET, typename FMT, typename BASE, typename STRIDES = DynamicStrides>
class EigenWrapper : public BASE
{
using base = BASE;
public:
EigenWrapper(typename ET::type* t, const FMT& fmt)
: base(t, fmt.l0, fmt.l1, STRIDES(fmt.s0, fmt.s1))
{
}
EigenWrapper(
typename ET::type* t,
const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: base(t, layout->get_size(), 1, DynamicStrides(1, 1))
{
}
EigenWrapper(CallFrame& call_frame, const TensorViewInfo& tensor_view_info)
: EigenWrapper(
call_frame.get_tensor_view_data<ET>(tensor_view_info.get_index()),
FMT(tensor_view_info))
{
}
template <typename U>
EigenWrapper& operator=(const U& other)
{
this->base::operator=(other);
return *this;
}
};
template <typename ET, typename FMT = fmt::V>
using EigenArray1d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenArray2d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenMatrix = EigenWrapper<ET, FMT, EigenMatrixBase<ET>>;
template <typename ET, typename FMT = fmt::V>
using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>;
}
}
}
...@@ -26,45 +26,29 @@ namespace ngraph ...@@ -26,45 +26,29 @@ namespace ngraph
{ {
namespace runtime namespace runtime
{ {
class CallFrame;
class ExternalFunction class ExternalFunction
{ {
using FunctionMap = protected:
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function, ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true); bool release_function = true)
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(); : m_function(function)
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(FunctionMap& function_map); , m_release_function(release_function)
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>> , m_is_compiled(false)
get_instructions()
{ {
return m_instructions;
} }
// Release original function's resources // Release original function's resources
void release_function() { m_function = nullptr; } void release_function() { m_function = nullptr; }
protected: public:
void compile(); virtual ~ExternalFunction() {}
void compile(FunctionMap& function_map); virtual std::shared_ptr<CallFrame> make_call_frame() = 0;
protected:
std::shared_ptr<ngraph::Function> m_function; std::shared_ptr<ngraph::Function> m_function;
bool m_release_function; bool m_release_function;
bool m_is_compiled; bool m_is_compiled;
size_t m_n_inputs;
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions;
ngraph::descriptor::TensorViewPtrs m_temp_views;
static OpMap& get_op_map();
}; };
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/manager.hpp"
using namespace ngraph::runtime;
Manager::FactoryMap& Manager::get_factory_map()
{
static FactoryMap factory_map;
return factory_map;
}
std::shared_ptr<Manager> Manager::get(const std::string& name)
{
return get_factory_map().at(name)(name);
}
Manager::Factory Manager::register_factory(std::string name, Factory factory)
{
get_factory_map()[name] = factory;
return factory;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
namespace ngraph
{
class Function;
namespace runtime
{
class Backend;
class ExternalFunction;
/// @brief Interface to a generic manager.
///
/// A manager provides access to compilation for a backend, and a means to obtain
/// a backed for execution and allocation.
class Manager
{
public:
virtual ~Manager() {}
/// @brief Allocate a backend for this transformer.
///
/// Specific transformers may provide addtional methods for allocating customized backends.
virtual std::shared_ptr<Backend> allocate_backend() = 0;
/// @brief Convert a function to a form that can be run on a backend.
virtual std::shared_ptr<ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) = 0;
using Factory = std::function<std::shared_ptr<Manager>(const std::string&)>;
using FactoryMap = std::map<std::string, Factory>;
static FactoryMap& get_factory_map();
static std::shared_ptr<Manager> get(const std::string& name);
static Factory register_factory(std::string name, Factory factory);
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
using namespace std;
using namespace ngraph::runtime::ngvm;
CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps,
size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensor_views(n_inputs + n_outputs + temps.size())
, m_initial_pc(initial_pc)
, m_instructions(instructions)
{
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
}
void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{
copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
m_next_pc = m_initial_pc;
m_return = false;
while (!m_return)
{
m_pc = m_next_pc;
m_next_pc = m_pc + 1;
m_instructions->at(m_pc)->execute(*this);
}
// Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
}
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results)
{
// TODO: Check types of args and result
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs;
for (auto argument : arguments)
{
argument->collect_tensor_views(inputs, argument);
}
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> outputs;
for (auto result : results)
{
result->collect_tensor_views(outputs, result);
}
tensor_call(inputs, outputs);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
class PrimaryTensorView;
namespace ngvm
{
class Instruction;
// A VM for executing lightly-compiled graph functions.
class CallFrame : public ngraph::runtime::CallFrame
{
public:
CallFrame(
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
/// @brief Invoke the function with values matching the signature of the function.
///
/// Tuples will be expanded into their tensor views to build the call frame.
void
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{
return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
}
template <typename ET>
typename ET::type* get_tensor_view_data(size_t i)
{
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
}
protected:
size_t m_n_inputs;
size_t m_n_outputs;
TensorViewPtrs m_tensor_views;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
size_t m_next_pc;
bool m_return;
};
}
}
}
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
...@@ -24,28 +24,31 @@ namespace ngraph ...@@ -24,28 +24,31 @@ namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ETI, typename ETO> namespace eigen
class ConvertInstruction : public Instruction
{ {
public: template <typename ET>
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out) class AbsInstruction : public Instruction
: m_arg(arg)
, m_out(out)
{ {
} public:
AbsInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ETO>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg).template cast<typename ETO::type>(); Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
} }
protected: protected:
TensorViewInfo m_arg; TensorViewInfo m_arg;
TensorViewInfo m_out; TensorViewInfo m_out;
}; };
}
} }
} }
} }
...@@ -14,41 +14,44 @@ ...@@ -14,41 +14,44 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class AddInstruction : public Instruction
{ {
public: template <typename ET>
AddInstruction(const TensorViewInfo& arg0, class AddInstruction : public Instruction
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{ {
} public:
AddInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) +
EigenArray1d<ET>(call_frame, m_arg0) + EigenArray1d<ET>(call_frame, m_arg1); EigenArray1d<ET>(call_frame, m_arg1);
} }
protected: protected:
TensorViewInfo m_arg0; TensorViewInfo m_arg0;
TensorViewInfo m_arg1; TensorViewInfo m_arg1;
TensorViewInfo m_out; TensorViewInfo m_out;
}; };
}
} }
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class BroadcastScalarInstruction : public Instruction
{
public:
BroadcastScalarInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit hacky: regardless of the tensor rank we
// pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them.
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg)(0, 0);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
...@@ -14,41 +14,42 @@ ...@@ -14,41 +14,42 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class EqualInstruction : public Instruction
{ {
public: template <typename ET>
EqualInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out) class BroadcastVectorColwiseInstruction : public Instruction
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{ {
} public:
BroadcastVectorColwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<element::Bool>(call_frame, m_out) = EigenMatrix<ET>(call_frame, m_out).colwise() =
(EigenArray1d<ET>(call_frame, m_arg0) == EigenVector<ET>(call_frame, m_arg);
EigenArray1d<ET>(call_frame, m_arg1)) }
.template cast<char>();
}
protected: protected:
TensorViewInfo m_arg0; TensorViewInfo m_arg;
TensorViewInfo m_arg1; TensorViewInfo m_out;
TensorViewInfo m_out; };
}; }
} }
} }
} }
...@@ -14,38 +14,41 @@ ...@@ -14,38 +14,41 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class BroadcastVectorRowwiseInstruction : public Instruction
{ {
public: template <typename ET>
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg, class BroadcastVectorRowwiseInstruction : public Instruction
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{ {
} public:
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenMatrix<ET>(call_frame, m_out).rowwise() = EigenMatrix<ET>(call_frame, m_out).rowwise() =
EigenVector<ET>(call_frame, m_arg).transpose(); EigenVector<ET>(call_frame, m_arg).transpose();
} }
protected: protected:
TensorViewInfo m_arg; TensorViewInfo m_arg;
TensorViewInfo m_out; TensorViewInfo m_out;
}; };
}
} }
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
class CallInstruction : public Instruction
{
public:
CallInstruction(std::shared_ptr<ExternalFunction> ef,
std::vector<TensorViewInfo> in,
std::vector<TensorViewInfo> out)
: m_external_function(ef)
, m_in(in)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
m_external_function->make_call_frame());
std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;
std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;
for (auto in : m_in)
{
inputs.push_back(call_frame.get_tensor_view(in.get_index()));
}
for (auto out : m_out)
{
outputs.push_back(call_frame.get_tensor_view(out.get_index()));
}
(*cf)(inputs, outputs);
}
protected:
std::shared_ptr<ExternalFunction> m_external_function;
std::vector<TensorViewInfo> m_in;
std::vector<TensorViewInfo> m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class ConcatMatrixInstruction : public Instruction
{
public:
ConcatMatrixInstruction(const std::vector<TensorViewInfo>& args,
size_t axis,
const TensorViewInfo& out)
: m_args(args)
, m_axis(axis)
, m_out(out)
{
size_t concat_pos[2]{0, 0};
for (auto arg : args)
{
auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
m_blocks.push_back(
{concat_pos[0], concat_pos[1], arg_shape.at(0), arg_shape.at(1)});
concat_pos[axis] += arg_shape.at(axis);
}
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET> out(call_frame, m_out);
for (size_t i = 0; i < m_args.size(); i++)
{
auto& b = m_blocks[i];
out.block(b[0], b[1], b[2], b[3])
<< EigenMatrix<ET>(call_frame, m_args.at(i));
}
}
protected:
std::vector<TensorViewInfo> m_args;
size_t m_axis;
TensorViewInfo m_out;
std::vector<std::vector<size_t>> m_blocks;
};
}
}
}
}
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include <vector>
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
...@@ -24,46 +26,45 @@ namespace ngraph ...@@ -24,46 +26,45 @@ namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class ConcatMatrixInstruction : public Instruction
{ {
public: // Would be better to just generate a sequence of copy into slice of output instructions
ConcatMatrixInstruction(const std::vector<TensorViewInfo>& args, template <typename ET>
size_t axis, class ConcatVectorInstruction : public Instruction
const TensorViewInfo& out)
: m_args(args)
, m_axis(axis)
, m_out(out)
{ {
size_t concat_pos[2]{0, 0}; public:
for (auto arg : args) ConcatVectorInstruction(const std::vector<TensorViewInfo>& args,
const TensorViewInfo& out)
: m_args(args)
, m_out(out)
{ {
auto& arg_shape = arg.get_tensor_view_layout()->get_shape(); for (auto arg : args)
m_blocks.push_back( {
{concat_pos[0], concat_pos[1], arg_shape.at(0), arg_shape.at(1)}); auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
concat_pos[axis] += arg_shape.at(axis); m_sizes.push_back(arg_shape.at(0));
}
} }
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET> out(call_frame, m_out);
for (size_t i = 0; i < m_args.size(); i++)
{ {
auto& b = m_blocks[i]; EigenVector<ET> out(call_frame, m_out);
out.block(b[0], b[1], b[2], b[3]) size_t concat_pos = 0;
<< EigenMatrix<ET>(call_frame, m_args.at(i)); for (size_t i = 0; i < m_args.size(); i++)
{
out.segment(concat_pos, m_sizes[i])
<< EigenVector<ET>(call_frame, m_args.at(i));
concat_pos += m_sizes[i];
}
} }
}
protected: protected:
std::vector<TensorViewInfo> m_args; std::vector<TensorViewInfo> m_args;
size_t m_axis; TensorViewInfo m_out;
TensorViewInfo m_out; std::vector<size_t> m_sizes;
std::vector<std::vector<size_t>> m_blocks; };
}; }
} }
} }
} }
...@@ -14,42 +14,42 @@ ...@@ -14,42 +14,42 @@
#pragma once #pragma once
#include <cassert> #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
/// @brief Copies a tensor from in to out. namespace eigen
template <typename ET>
class CopyInstruction : public Instruction
{ {
public: template <typename ET>
/// @param in Index of input tensor in call frame. class ConstantInstruction : public Instruction
/// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out)
: m_in(in)
, m_out(out)
{ {
} public:
ConstantInstruction(const std::vector<typename ET::type> value,
const TensorViewInfo& out)
: m_value(value)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.get_parameterized_tensor_view<ET>(m_out)->get_vector() = call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())
call_frame.get_parameterized_tensor_view<ET>(m_in)->get_vector(); ->get_vector() = m_value;
} }
protected: protected:
size_t m_in; const std::vector<typename ET::type> m_value;
size_t m_out; TensorViewInfo m_out;
}; };
}
} }
} }
} }
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
...@@ -24,29 +24,32 @@ namespace ngraph ...@@ -24,29 +24,32 @@ namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class ConstantInstruction : public Instruction
{ {
public: template <typename ETI, typename ETO>
ConstantInstruction(const std::vector<typename ET::type> value, class ConvertInstruction : public Instruction
const TensorViewInfo& out)
: m_value(value)
, m_out(out)
{ {
} public:
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())->get_vector() = EigenArray1d<ETO>(call_frame, m_out) =
m_value; EigenArray1d<ETI>(call_frame, m_arg)
} .template cast<typename ETO::type>();
}
protected: protected:
const std::vector<typename ET::type> m_value; TensorViewInfo m_arg;
TensorViewInfo m_out; TensorViewInfo m_out;
}; };
}
} }
} }
} }
...@@ -14,41 +14,45 @@ ...@@ -14,41 +14,45 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include <cassert>
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class LessEqInstruction : public Instruction
{ {
public: /// @brief Copies a tensor from in to out.
LessEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out) template <typename ET>
: m_arg0(arg0) class CopyInstruction : public Instruction
, m_arg1(arg1)
, m_out(out)
{ {
} public:
/// @param in Index of input tensor in call frame.
/// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out)
: m_in(in)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenArray1d<element::Bool>(call_frame, m_out) = call_frame.get_parameterized_tensor_view<ET>(m_out)->get_vector() =
(EigenArray1d<ET>(call_frame, m_arg0) <= call_frame.get_parameterized_tensor_view<ET>(m_in)->get_vector();
EigenArray1d<ET>(call_frame, m_arg1)) }
.template cast<char>();
}
protected: protected:
TensorViewInfo m_arg0; size_t m_in;
TensorViewInfo m_arg1; size_t m_out;
TensorViewInfo m_out; };
}; }
} }
} }
} }
...@@ -14,41 +14,44 @@ ...@@ -14,41 +14,44 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
template <typename ET> namespace eigen
class MatrixMultInstruction : public Instruction
{ {
public: template <typename ET>
MatrixMultInstruction(const TensorViewInfo& arg0, class DivideInstruction : public Instruction
{
public:
DivideInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1, const TensorViewInfo& arg1,
const TensorViewInfo& out) const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
{ {
} }
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
EigenMatrix<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) /
EigenMatrix<ET>(call_frame, m_arg0) * EigenMatrix<ET>(call_frame, m_arg1); EigenArray1d<ET>(call_frame, m_arg1);
} }
protected: protected:
TensorViewInfo m_arg0; TensorViewInfo m_arg0;
TensorViewInfo m_arg1; TensorViewInfo m_arg1;
TensorViewInfo m_out; TensorViewInfo m_out;
}; };
}
} }
} }
} }
...@@ -14,51 +14,45 @@ ...@@ -14,51 +14,45 @@
#pragma once #pragma once
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp" #include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp" #include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp" #include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
{ {
namespace eigen namespace ngvm
{ {
// Would be better to just generate a sequence of copy into slice of output instructions namespace eigen
template <typename ET>
class ConcatVectorInstruction : public Instruction
{ {
public: template <typename ET>
ConcatVectorInstruction(const std::vector<TensorViewInfo>& args, class DotInstruction : public Instruction
const TensorViewInfo& out)
: m_args(args)
, m_out(out)
{ {
for (auto arg : args) public:
DotInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{ {
auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
m_sizes.push_back(arg_shape.at(0));
} }
}
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{
EigenVector<ET> out(call_frame, m_out);
size_t concat_pos = 0;
for (size_t i = 0; i < m_args.size(); i++)
{ {
out.segment(concat_pos, m_sizes[i]) EigenArray1d<ET>(call_frame, m_out)
<< EigenVector<ET>(call_frame, m_args.at(i)); << EigenVector<ET>(call_frame, m_arg0)
concat_pos += m_sizes[i]; .dot(EigenVector<ET>(call_frame, m_arg1));
} }
}
protected: protected:
std::vector<TensorViewInfo> m_args; TensorViewInfo m_arg0;
TensorViewInfo m_out; TensorViewInfo m_arg1;
std::vector<size_t> m_sizes; TensorViewInfo m_out;
}; };
}
} }
} }
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment