Commit aa3d8338 authored by Scott Cyphers, committed by GitHub

Transformer/Backend (#166)

Add a minimal Backend API and make the interpreter use it.
Read/write tensors (for frameworks).
parent 072fe644
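
For orientation, the sketch below shows how a framework might drive the Manager/Backend/CallFrame API this commit introduces: look up a manager, compile a Function, allocate a backend, build a call frame, and read/write tensors from the host. It is illustrative only; the registered manager name "NGVM" and the element::Float32 helpers are assumptions inferred from the rest of the diff, not guaranteed by this commit.

// Hypothetical usage sketch (not part of this commit).
// Assumed: the NGVM manager registers itself under the name "NGVM", and
// element::Float32 provides the element_type()/type members used elsewhere.
#include "ngraph/function.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/manager.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp"

void run_on_backend(const std::shared_ptr<ngraph::Function>& f)
{
    using namespace ngraph;

    // Compile the function for a named transformer.
    auto manager = runtime::Manager::get("NGVM"); // assumed factory name
    auto external = manager->compile(f);

    // One backend, one call frame per concurrent execution.
    auto backend = manager->allocate_backend();
    auto cf = backend->make_call_frame(external);

    // Allocate tensors through the backend and write the inputs from the host.
    auto a = backend->make_parameterized_tensor_view<element::Float32>(Shape{2, 2});
    auto b = backend->make_parameterized_tensor_view<element::Float32>(Shape{2, 2});
    auto result = backend->make_parameterized_tensor_view<element::Float32>(Shape{2, 2});
    a->get_vector() = {1.0f, 2.0f, 3.0f, 4.0f};
    b->get_vector() = {5.0f, 6.0f, 7.0f, 8.0f};

    // Invoke; tuple arguments, if any, are expanded into their tensor views.
    (*cf)({a, b}, {result});

    // Read the output back on the host.
    auto out = result->get_vector();
}
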
......@@ -14,9 +14,11 @@
set (SRC
descriptor/input.cpp
descriptor/layout/dense_tensor_view_layout.cpp
descriptor/layout/tensor_view_layout.cpp
descriptor/output.cpp
descriptor/primary_tensor_view.cpp
descriptor/tensor.cpp
descriptor/tensor_view.cpp
descriptor/tuple.cpp
function.cpp
log.cpp
......@@ -50,8 +52,13 @@ set (SRC
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/visualize_tree.cpp
runtime/call_frame.cpp
runtime/external_function.cpp
runtime/backend.cpp
runtime/manager.cpp
runtime/ngvm/call_frame.cpp
runtime/ngvm/external_function.cpp
runtime/ngvm/ngvm_backend.cpp
runtime/ngvm/ngvm_manager.cpp
runtime/tensor_view.cpp
runtime/tuple.cpp
runtime/utils.cpp
shape.cpp
......
......@@ -15,6 +15,7 @@
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::descriptor::layout;
using ngraph::Shape;
......
......@@ -17,14 +17,14 @@
#include <cstddef>
#include <vector>
#include "ngraph/descriptor/buffer.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
namespace ngraph
{
namespace descriptor
{
class TensorView;
namespace layout
{
/// @brief The standard strided layout, used for row-major and column-major, their permutations and slices.
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::descriptor::layout;
TensorViewLayout::TensorViewLayout(const ngraph::descriptor::TensorView& tensor_view)
: m_tensor_view_type(tensor_view.get_tensor_view_type())
{
}
const ngraph::element::Type& TensorViewLayout::get_element_type() const
{
return m_tensor_view_type->get_element_type();
}
const ngraph::Shape& TensorViewLayout::get_shape() const
{
return m_tensor_view_type->get_shape();
}
......@@ -23,6 +23,11 @@
namespace ngraph
{
namespace element
{
class Type;
}
namespace descriptor
{
class TensorView;
......@@ -35,10 +40,7 @@ namespace ngraph
class TensorViewLayout
{
protected:
TensorViewLayout(const ngraph::descriptor::TensorView& tensor_view)
: m_tensor_view_type(tensor_view.get_tensor_view_type())
{
}
TensorViewLayout(const ngraph::descriptor::TensorView& tensor_view);
public:
virtual ~TensorViewLayout() {}
......@@ -52,11 +54,8 @@ namespace ngraph
/// With non-linear buffers, this will need to be something other than size_t.
virtual size_t get_index_offset(const std::vector<size_t>& indices) = 0;
const element::Type& get_element_type() const
{
return m_tensor_view_type->get_element_type();
}
const Shape& get_shape() const { return m_tensor_view_type->get_shape(); }
const element::Type& get_element_type() const;
const Shape& get_shape() const;
/// Where this view is located in the buffer.
const BufferPos& get_buffer_pos() const { return m_buffer_pos; }
BufferPos& get_buffer_pos() { return m_buffer_pos; }
......
......@@ -12,14 +12,12 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor;
namespace ngraph
std::shared_ptr<const ngraph::ValueType> TensorView::get_value_type() const
{
namespace runtime
{
using TensorViewIndex = unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t>;
}
}
\ No newline at end of file
return m_tensor_view_type;
}
......@@ -16,15 +16,13 @@
#include <memory>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/descriptor/value.hpp"
#include "ngraph/log.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph
{
class Node;
class TensorViewType;
namespace descriptor
{
......@@ -34,6 +32,9 @@ namespace ngraph
class TensorViewLayout;
}
class Tensor;
class TensorView;
/// @brief Compile-time descriptor of a first-class value that is a view of a tensor.
class TensorView : public Value
{
......@@ -51,10 +52,7 @@ namespace ngraph
virtual const Tensor& get_tensor() const = 0;
virtual Tensor& get_tensor() = 0;
virtual std::shared_ptr<const ValueType> get_value_type() const override
{
return m_tensor_view_type;
}
virtual std::shared_ptr<const ValueType> get_value_type() const override;
const std::string& get_name() const { return m_name; }
std::shared_ptr<const TensorViewType> get_tensor_view_type() const
......
......@@ -15,11 +15,12 @@
#pragma once
#include <memory>
#include "ngraph/types/type.hpp"
#include <vector>
namespace ngraph
{
class ValueType;
namespace descriptor
{
class TensorView;
......
......@@ -25,7 +25,6 @@
#include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph
......
......@@ -80,9 +80,12 @@
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/manager.hpp"
#include "ngraph/runtime/ngvm/ngvm_backend.hpp"
#include "ngraph/runtime/ngvm/ngvm_manager.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tuple.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::runtime;
std::shared_ptr<TensorView>
Backend::make_primary_tensor_view(const ngraph::element::Type& element_type, const Shape& shape)
{
return element_type.make_primary_tensor_view(shape);
}
std::shared_ptr<ngraph::runtime::Tuple>
Backend::make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements)
{
return std::make_shared<ngraph::runtime::Tuple>(elements);
}
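
A hedged sketch of the tuple path above: values are grouped into a runtime::Tuple so a single handle can stand for several tensor views. The variables backend, a, and b are assumed from the usage sketch near the top of this commit.

// Illustrative only; assumes `backend`, `a`, and `b` as in the earlier sketch.
std::vector<std::shared_ptr<ngraph::runtime::Value>> elements{a, b};
auto pair = backend->make_tuple(elements);
// `pair` is a single Value; CallFrame::operator() expands it back into the
// underlying tensor views when the compiled function expects a tuple.
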
......@@ -14,46 +14,57 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include <memory>
#include "ngraph/common.hpp"
namespace ngraph
{
namespace element
{
class Type;
}
namespace runtime
{
namespace eigen
class ExternalFunction;
class CallFrame;
class TensorView;
class Tuple;
class Value;
template <typename ET>
class ParameterizedTensorView;
/// @brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
class Backend
{
public:
virtual ~Backend() {}
/// @brief Make a call frame that can support one concurrent call of an external function.
///
/// If more than one concurrent execution is needed, each execution will require its own call frame.
virtual std::shared_ptr<ngraph::runtime::CallFrame>
make_call_frame(const std::shared_ptr<ExternalFunction>& external_function) = 0;
/// @brief Return a handle for a tensor on the backend device.
virtual std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const ngraph::element::Type& element_type,
const Shape& shape);
template <typename ET>
class ScalarTensorProductInstruction : public Instruction
std::shared_ptr<ngraph::runtime::ParameterizedTensorView<ET>>
make_parameterized_tensor_view(const Shape& shape)
{
public:
ScalarTensorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit hacky: regardless of the tensor rank we
// pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them.
EigenVector<ET>(call_frame, m_out) =
call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] *
EigenVector<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
return std::dynamic_pointer_cast<ngraph::runtime::ParameterizedTensorView<ET>>(
make_primary_tensor_view(ET::element_type(), shape));
}
/// @brief Construct a tuple handle from a sequence of values.
virtual std::shared_ptr<ngraph::runtime::Tuple>
make_tuple(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& elements);
};
}
}
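
The templated helper above is a thin convenience over make_primary_tensor_view. A short sketch of the two paths follows; the element-type and variable names are assumptions, not taken from this commit.

// Illustrative only; `backend` is assumed from the earlier sketch.
// Type-erased: the element type is chosen at run time.
std::shared_ptr<ngraph::runtime::TensorView> tv = backend->make_primary_tensor_view(
    ngraph::element::Float32::element_type(), ngraph::Shape{3});

// Typed: the same allocation, returned as ParameterizedTensorView<Float32> so the
// host can read and write its backing std::vector directly.
auto ptv = backend->make_parameterized_tensor_view<ngraph::element::Float32>(ngraph::Shape{3});
ptv->get_vector() = {1.0f, 2.0f, 3.0f};
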
......@@ -25,51 +25,23 @@ namespace ngraph
namespace runtime
{
class PrimaryTensorView;
class Instruction;
class Value;
// A VM for executing lightly-compiled graph functions.
class CallFrame
{
public:
CallFrame(
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
virtual ~CallFrame() {}
/// @brief Invoke the function with values matching the signature of the function.
///
/// Tuples will be expanded into their tensor views to build the call frame.
void operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
virtual void
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs) = 0;
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{
return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
}
template <typename ET>
typename ET::type* get_tensor_view_data(size_t i)
{
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
}
protected:
size_t m_n_inputs;
size_t m_n_outputs;
TensorViewPtrs m_tensor_views;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
size_t m_next_pc;
bool m_return;
virtual void tensor_call(const TensorViewPtrs& inputs,
const TensorViewPtrs& outputs) = 0;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class DivideInstruction : public Instruction
{
public:
DivideInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0) / EigenArray1d<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class DotInstruction : public Instruction
{
public:
DotInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out)
<< EigenVector<ET>(call_frame, m_arg0)
.dot(EigenVector<ET>(call_frame, m_arg1));
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
{
template <typename ET>
class MaximumInstruction : public Instruction
{
public:
MaximumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0)
.max(EigenArray1d<ET>(call_frame, m_arg1));
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
class CallFrame;
namespace eigen
{
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET>
using DynamicArray = Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
template <typename ET>
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicMatrix =
Eigen::Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicVector = Eigen::Matrix<typename ET::type, Eigen::Dynamic, 1>;
template <typename ET>
using EigenVectorBase = Eigen::Map<DynamicVector<ET>, 0, VectorStrides>;
namespace fmt
{
/// @brief vector format for Eigen wrappers.
class V
{
public:
V(const TensorViewInfo& tensor_view_info)
: l0(tensor_view_info
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size())
{
}
public:
size_t l0;
size_t l1{1};
size_t s0{1};
size_t s1{1};
};
class M
{
M(const Shape& shape, const Strides& strides)
: l0(shape.at(0))
, l1(shape.at(1))
, s0(strides.at(0))
, s1(strides.at(1))
{
}
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const TensorViewInfo& tensor_view_info)
: M(tensor_view_info
.get_layout<ngraph::descriptor::layout::DenseTensorViewLayout>())
{
}
public:
size_t l0;
size_t l1;
size_t s0;
size_t s1;
};
}
// ET element type
// FMT array format (fmt::V for vector, etc.)
// BASE select array/matrix
template <typename ET, typename FMT, typename BASE, typename STRIDES = DynamicStrides>
class EigenWrapper : public BASE
{
using base = BASE;
public:
EigenWrapper(typename ET::type* t, const FMT& fmt)
: base(t, fmt.l0, fmt.l1, STRIDES(fmt.s0, fmt.s1))
{
}
EigenWrapper(
typename ET::type* t,
const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: base(t, layout->get_size(), 1, DynamicStrides(1, 1))
{
}
EigenWrapper(CallFrame& call_frame, const TensorViewInfo& tensor_view_info)
: EigenWrapper(
call_frame.get_tensor_view_data<ET>(tensor_view_info.get_index()),
FMT(tensor_view_info))
{
}
template <typename U>
EigenWrapper& operator=(const U& other)
{
this->base::operator=(other);
return *this;
}
};
template <typename ET, typename FMT = fmt::V>
using EigenArray1d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenArray2d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenMatrix = EigenWrapper<ET, FMT, EigenMatrixBase<ET>>;
template <typename ET, typename FMT = fmt::V>
using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>;
}
}
}
......@@ -26,45 +26,29 @@ namespace ngraph
{
namespace runtime
{
class CallFrame;
class ExternalFunction
{
using FunctionMap =
std::unordered_map<std::shared_ptr<Function>, std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
public:
protected:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(FunctionMap& function_map);
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
get_instructions()
bool release_function = true)
: m_function(function)
, m_release_function(release_function)
, m_is_compiled(false)
{
return m_instructions;
}
// Release original function's resources
void release_function() { m_function = nullptr; }
protected:
void compile();
void compile(FunctionMap& function_map);
public:
virtual ~ExternalFunction() {}
virtual std::shared_ptr<CallFrame> make_call_frame() = 0;
protected:
std::shared_ptr<ngraph::Function> m_function;
bool m_release_function;
bool m_is_compiled;
size_t m_n_inputs;
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions;
ngraph::descriptor::TensorViewPtrs m_temp_views;
static OpMap& get_op_map();
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/manager.hpp"
using namespace ngraph::runtime;
Manager::FactoryMap& Manager::get_factory_map()
{
static FactoryMap factory_map;
return factory_map;
}
std::shared_ptr<Manager> Manager::get(const std::string& name)
{
return get_factory_map().at(name)(name);
}
Manager::Factory Manager::register_factory(std::string name, Factory factory)
{
get_factory_map()[name] = factory;
return factory;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
namespace ngraph
{
class Function;
namespace runtime
{
class Backend;
class ExternalFunction;
/// @brief Interface to a generic manager.
///
/// A manager provides access to compilation for a backend, and a means to obtain
/// a backend for execution and allocation.
class Manager
{
public:
virtual ~Manager() {}
/// @brief Allocate a backend for this transformer.
///
/// Specific transformers may provide additional methods for allocating customized backends.
virtual std::shared_ptr<Backend> allocate_backend() = 0;
/// @brief Convert a function to a form that can be run on a backend.
virtual std::shared_ptr<ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) = 0;
using Factory = std::function<std::shared_ptr<Manager>(const std::string&)>;
using FactoryMap = std::map<std::string, Factory>;
static FactoryMap& get_factory_map();
static std::shared_ptr<Manager> get(const std::string& name);
static Factory register_factory(std::string name, Factory factory);
};
}
}
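
Concrete transformers plug into the registry above by calling register_factory at static-initialization time, the same pattern the new NGVMManager presumably uses. A hedged sketch with a made-up manager follows; MyManager, its name string, and its no-op bodies are illustrative, not from this commit.

// Illustrative only: a hypothetical manager registering itself by name.
#include "ngraph/runtime/manager.hpp"

class MyManager : public ngraph::runtime::Manager
{
public:
    // A real manager would hand back its backend and compiled function here.
    std::shared_ptr<ngraph::runtime::Backend> allocate_backend() override { return nullptr; }
    std::shared_ptr<ngraph::runtime::ExternalFunction>
        compile(const std::shared_ptr<ngraph::Function>&) override { return nullptr; }

    static ngraph::runtime::Manager::Factory factory;
};

// Runs once at program start; Manager::get("MyManager") then builds instances on demand.
ngraph::runtime::Manager::Factory MyManager::factory = ngraph::runtime::Manager::register_factory(
    "MyManager",
    [](const std::string& /*name*/) { return std::make_shared<MyManager>(); });
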
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
using namespace std;
using namespace ngraph::runtime::ngvm;
CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps,
size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensor_views(n_inputs + n_outputs + temps.size())
, m_initial_pc(initial_pc)
, m_instructions(instructions)
{
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
}
void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{
copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
m_next_pc = m_initial_pc;
m_return = false;
while (!m_return)
{
m_pc = m_next_pc;
m_next_pc = m_pc + 1;
m_instructions->at(m_pc)->execute(*this);
}
// Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
}
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results)
{
// TODO: Check types of args and result
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs;
for (auto argument : arguments)
{
argument->collect_tensor_views(inputs, argument);
}
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> outputs;
for (auto result : results)
{
result->collect_tensor_views(outputs, result);
}
tensor_call(inputs, outputs);
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
class PrimaryTensorView;
namespace ngvm
{
class Instruction;
// A VM for executing lightly-compiled graph functions.
class CallFrame : public ngraph::runtime::CallFrame
{
public:
CallFrame(
size_t n_inputs,
size_t n_outputs,
const TensorViewPtrs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
/// @brief Invoke the function with values matching the signature of the function.
///
/// Tuples will be expanded into their tensor views to build the call frame.
void
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{
return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
}
template <typename ET>
typename ET::type* get_tensor_view_data(size_t i)
{
return &get_parameterized_tensor_view<ET>(i)->get_vector()[0];
}
protected:
size_t m_n_inputs;
size_t m_n_outputs;
TensorViewPtrs m_tensor_views;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
size_t m_next_pc;
bool m_return;
};
}
}
}
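
The doc comments above describe the NGVM execution model: the call frame holds a flat vector of tensor views (inputs, then outputs, then temporaries) and steps a program counter through a shared instruction list until set_return() fires. Below is a minimal sketch of an instruction written against that interface; the Instruction base's execute hook is assumed from the instruction classes elsewhere in this diff.

// Illustrative only; assumes ngraph::runtime::ngvm::Instruction declares
// virtual void execute(CallFrame&) const = 0, as the instructions in this diff suggest.
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"

class NopInstruction : public ngraph::runtime::ngvm::Instruction
{
public:
    void execute(ngraph::runtime::ngvm::CallFrame& call_frame) const override
    {
        // A real instruction reads and writes tensors through call_frame, e.g.
        // call_frame.get_parameterized_tensor_view<ET>(i)->get_vector(); the last
        // instruction of a compiled function calls call_frame.set_return().
        (void)call_frame;
    }
};
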
......@@ -14,9 +14,9 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
......@@ -24,28 +24,31 @@ namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction
namespace eigen
{
public:
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
template <typename ET>
class AbsInstruction : public Instruction
{
}
public:
AbsInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ETO>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg).template cast<typename ETO::type>();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,44 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class AddInstruction : public Instruction
namespace eigen
{
public:
AddInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class AddInstruction : public Instruction
{
}
public:
AddInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0) + EigenArray1d<ET>(call_frame, m_arg1);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) +
EigenArray1d<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class BroadcastScalarInstruction : public Instruction
{
public:
BroadcastScalarInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit hacky: regardless of the tensor rank we
// pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them.
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg)(0, 0);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,42 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class EqualInstruction : public Instruction
namespace eigen
{
public:
EqualInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class BroadcastVectorColwiseInstruction : public Instruction
{
}
public:
BroadcastVectorColwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) ==
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).colwise() =
EigenVector<ET>(call_frame, m_arg);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,38 +14,41 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class BroadcastVectorRowwiseInstruction : public Instruction
namespace eigen
{
public:
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
template <typename ET>
class BroadcastVectorRowwiseInstruction : public Instruction
{
}
public:
BroadcastVectorRowwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).rowwise() =
EigenVector<ET>(call_frame, m_arg).transpose();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).rowwise() =
EigenVector<ET>(call_frame, m_arg).transpose();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
class CallInstruction : public Instruction
{
public:
CallInstruction(std::shared_ptr<ExternalFunction> ef,
std::vector<TensorViewInfo> in,
std::vector<TensorViewInfo> out)
: m_external_function(ef)
, m_in(in)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
m_external_function->make_call_frame());
std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;
std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;
for (auto in : m_in)
{
inputs.push_back(call_frame.get_tensor_view(in.get_index()));
}
for (auto out : m_out)
{
outputs.push_back(call_frame.get_tensor_view(out.get_index()));
}
(*cf)(inputs, outputs);
}
protected:
std::shared_ptr<ExternalFunction> m_external_function;
std::vector<TensorViewInfo> m_in;
std::vector<TensorViewInfo> m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class ConcatMatrixInstruction : public Instruction
{
public:
ConcatMatrixInstruction(const std::vector<TensorViewInfo>& args,
size_t axis,
const TensorViewInfo& out)
: m_args(args)
, m_axis(axis)
, m_out(out)
{
size_t concat_pos[2]{0, 0};
for (auto arg : args)
{
auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
m_blocks.push_back(
{concat_pos[0], concat_pos[1], arg_shape.at(0), arg_shape.at(1)});
concat_pos[axis] += arg_shape.at(axis);
}
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET> out(call_frame, m_out);
for (size_t i = 0; i < m_args.size(); i++)
{
auto& b = m_blocks[i];
out.block(b[0], b[1], b[2], b[3])
<< EigenMatrix<ET>(call_frame, m_args.at(i));
}
}
protected:
std::vector<TensorViewInfo> m_args;
size_t m_axis;
TensorViewInfo m_out;
std::vector<std::vector<size_t>> m_blocks;
};
}
}
}
}
......@@ -14,9 +14,11 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include <vector>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
......@@ -24,46 +26,45 @@ namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class ConcatMatrixInstruction : public Instruction
namespace eigen
{
public:
ConcatMatrixInstruction(const std::vector<TensorViewInfo>& args,
size_t axis,
const TensorViewInfo& out)
: m_args(args)
, m_axis(axis)
, m_out(out)
// Would be better to just generate a sequence of copy into slice of output instructions
template <typename ET>
class ConcatVectorInstruction : public Instruction
{
size_t concat_pos[2]{0, 0};
for (auto arg : args)
public:
ConcatVectorInstruction(const std::vector<TensorViewInfo>& args,
const TensorViewInfo& out)
: m_args(args)
, m_out(out)
{
auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
m_blocks.push_back(
{concat_pos[0], concat_pos[1], arg_shape.at(0), arg_shape.at(1)});
concat_pos[axis] += arg_shape.at(axis);
for (auto arg : args)
{
auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
m_sizes.push_back(arg_shape.at(0));
}
}
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET> out(call_frame, m_out);
for (size_t i = 0; i < m_args.size(); i++)
virtual void execute(CallFrame& call_frame) const override
{
auto& b = m_blocks[i];
out.block(b[0], b[1], b[2], b[3])
<< EigenMatrix<ET>(call_frame, m_args.at(i));
EigenVector<ET> out(call_frame, m_out);
size_t concat_pos = 0;
for (size_t i = 0; i < m_args.size(); i++)
{
out.segment(concat_pos, m_sizes[i])
<< EigenVector<ET>(call_frame, m_args.at(i));
concat_pos += m_sizes[i];
}
}
}
protected:
std::vector<TensorViewInfo> m_args;
size_t m_axis;
TensorViewInfo m_out;
std::vector<std::vector<size_t>> m_blocks;
};
protected:
std::vector<TensorViewInfo> m_args;
TensorViewInfo m_out;
std::vector<size_t> m_sizes;
};
}
}
}
}
......@@ -14,42 +14,42 @@
#pragma once
#include <cassert>
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
/// @brief Copies a tensor from in to out.
template <typename ET>
class CopyInstruction : public Instruction
namespace eigen
{
public:
/// @param in Index of input tensor in call frame.
/// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out)
: m_in(in)
, m_out(out)
template <typename ET>
class ConstantInstruction : public Instruction
{
}
public:
ConstantInstruction(const std::vector<typename ET::type> value,
const TensorViewInfo& out)
: m_value(value)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.get_parameterized_tensor_view<ET>(m_out)->get_vector() =
call_frame.get_parameterized_tensor_view<ET>(m_in)->get_vector();
}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())
->get_vector() = m_value;
}
protected:
size_t m_in;
size_t m_out;
};
protected:
const std::vector<typename ET::type> m_value;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,9 +14,9 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
......@@ -24,29 +24,32 @@ namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class ConstantInstruction : public Instruction
namespace eigen
{
public:
ConstantInstruction(const std::vector<typename ET::type> value,
const TensorViewInfo& out)
: m_value(value)
, m_out(out)
template <typename ETI, typename ETO>
class ConvertInstruction : public Instruction
{
}
public:
ConvertInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.get_parameterized_tensor_view<ET>(m_out.get_index())->get_vector() =
m_value;
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ETO>(call_frame, m_out) =
EigenArray1d<ETI>(call_frame, m_arg)
.template cast<typename ETO::type>();
}
protected:
const std::vector<typename ET::type> m_value;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,45 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include <cassert>
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class LessEqInstruction : public Instruction
namespace eigen
{
public:
LessEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
/// @brief Copies a tensor from in to out.
template <typename ET>
class CopyInstruction : public Instruction
{
}
public:
/// @param in Index of input tensor in call frame.
/// @param out Index of output tensor in call frame.
CopyInstruction(size_t in, size_t out)
: m_in(in)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.get_parameterized_tensor_view<ET>(m_out)->get_vector() =
call_frame.get_parameterized_tensor_view<ET>(m_in)->get_vector();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
size_t m_in;
size_t m_out;
};
}
}
}
}
......@@ -14,41 +14,44 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class MatrixMultInstruction : public Instruction
namespace eigen
{
public:
MatrixMultInstruction(const TensorViewInfo& arg0,
template <typename ET>
class DivideInstruction : public Instruction
{
public:
DivideInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out) =
EigenMatrix<ET>(call_frame, m_arg0) * EigenMatrix<ET>(call_frame, m_arg1);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) /
EigenArray1d<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,51 +14,45 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
// Would be better to just generate a sequence of copy into slice of output instructions
template <typename ET>
class ConcatVectorInstruction : public Instruction
namespace eigen
{
public:
ConcatVectorInstruction(const std::vector<TensorViewInfo>& args,
const TensorViewInfo& out)
: m_args(args)
, m_out(out)
template <typename ET>
class DotInstruction : public Instruction
{
for (auto arg : args)
public:
DotInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
auto& arg_shape = arg.get_tensor_view_layout()->get_shape();
m_sizes.push_back(arg_shape.at(0));
}
}
virtual void execute(CallFrame& call_frame) const override
{
EigenVector<ET> out(call_frame, m_out);
size_t concat_pos = 0;
for (size_t i = 0; i < m_args.size(); i++)
virtual void execute(CallFrame& call_frame) const override
{
out.segment(concat_pos, m_sizes[i])
<< EigenVector<ET>(call_frame, m_args.at(i));
concat_pos += m_sizes[i];
EigenArray1d<ET>(call_frame, m_out)
<< EigenVector<ET>(call_frame, m_arg0)
.dot(EigenVector<ET>(call_frame, m_arg1));
}
}
protected:
std::vector<TensorViewInfo> m_args;
TensorViewInfo m_out;
std::vector<size_t> m_sizes;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,46 +14,44 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class SelectInstruction : public Instruction
namespace eigen
{
public:
SelectInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo arg2,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_arg2(arg2)
, m_out(out)
template <typename ET>
class EqualInstruction : public Instruction
{
}
public:
EqualInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<element::Bool>(call_frame, m_arg0)
.select(EigenArray1d<ET>(call_frame, m_arg1),
EigenArray1d<ET>(call_frame, m_arg2));
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) ==
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_arg2;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,53 +14,54 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
class CallInstruction : public Instruction
namespace eigen
{
public:
CallInstruction(std::shared_ptr<ExternalFunction> ef,
std::vector<TensorViewInfo> in,
std::vector<TensorViewInfo> out)
: m_external_function(ef)
, m_in(in)
, m_out(out)
template <typename TI, typename TO>
void greater_eq(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map_array(&*arg0) >= get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char);
}
virtual void execute(CallFrame& call_frame) const override
template <typename ET>
class GreaterEqInstruction : public Instruction
{
std::shared_ptr<CallFrame> cf = m_external_function->make_call_frame();
std::vector<std::shared_ptr<ngraph::runtime::Value>> inputs;
std::vector<std::shared_ptr<ngraph::runtime::Value>> outputs;
for (auto in : m_in)
public:
GreaterEqInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
inputs.push_back(call_frame.get_tensor_view(in.get_index()));
}
for (auto out : m_out)
virtual void execute(CallFrame& call_frame) const override
{
outputs.push_back(call_frame.get_tensor_view(out.get_index()));
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
(*cf)(inputs, outputs);
}
protected:
std::shared_ptr<ExternalFunction> m_external_function;
std::vector<TensorViewInfo> m_in;
std::vector<TensorViewInfo> m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,38 +14,46 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class BroadcastVectorColwiseInstruction : public Instruction
namespace eigen
{
public:
BroadcastVectorColwiseInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
template <typename ET>
class GreaterThanInstruction : public Instruction
{
}
public:
GreaterThanInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out).colwise() =
EigenVector<ET>(call_frame, m_arg);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,40 +14,44 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class BroadcastScalarInstruction : public Instruction
namespace eigen
{
public:
BroadcastScalarInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
template <typename ET>
class LessEqInstruction : public Instruction
{
}
public:
LessEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit hacky: regardless of the tensor rank we
// pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them.
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg)(0, 0);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,46 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class LessThanInstruction : public Instruction
namespace eigen
{
public:
LessThanInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class LessThanInstruction : public Instruction
{
}
public:
LessThanInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) <
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,37 +14,40 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class LogInstruction : public Instruction
namespace eigen
{
public:
LogInstruction(TensorViewInfo arg, TensorViewInfo out)
: m_arg(arg)
, m_out(out)
template <typename ET>
class LogInstruction : public Instruction
{
}
public:
LogInstruction(TensorViewInfo arg, TensorViewInfo out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg));
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET, fmt::V>(call_frame, m_out) =
Eigen::log(EigenArray1d<ET, fmt::V>(call_frame, m_arg));
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class MatrixMultInstruction : public Instruction
{
public:
MatrixMultInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenMatrix<ET>(call_frame, m_out) = EigenMatrix<ET>(call_frame, m_arg0) *
EigenMatrix<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,44 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class MatrixVectorProductInstruction : public Instruction
namespace eigen
{
public:
MatrixVectorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class MatrixVectorProductInstruction : public Instruction
{
}
public:
MatrixVectorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenVector<ET>(call_frame, m_out) =
EigenMatrix<ET>(call_frame, m_arg0) * EigenVector<ET>(call_frame, m_arg1);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenVector<ET>(call_frame, m_out) = EigenMatrix<ET>(call_frame, m_arg0) *
EigenVector<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,43 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class GreaterThanInstruction : public Instruction
namespace eigen
{
public:
GreaterThanInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class MaximumInstruction : public Instruction
{
}
public:
MaximumInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0)
.max(EigenArray1d<ET>(call_frame, m_arg1));
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,38 +14,43 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class MultiplyInstruction : public Instruction
namespace eigen
{
public:
MultiplyInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class MultiplyInstruction : public Instruction
{
}
public:
MultiplyInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0) * EigenArray1d<ET>(call_frame, m_arg1);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) *
EigenArray1d<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,38 +14,39 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class AbsInstruction : public Instruction
namespace eigen
{
public:
AbsInstruction(const TensorViewInfo& arg, const TensorViewInfo& out)
: m_arg(arg)
, m_out(out)
template <typename ET>
class NegateInstruction : public Instruction
{
}
public:
NegateInstruction(TensorViewInfo arg, TensorViewInfo out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
Eigen::abs(EigenArray1d<ET>(call_frame, m_arg));
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = -EigenArray1d<ET>(call_frame, m_arg);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,41 +14,46 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class NotEqualInstruction : public Instruction
namespace eigen
{
public:
NotEqualInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class NotEqualInstruction : public Instruction
{
}
public:
NotEqualInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) !=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) !=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
class ReturnInstruction : public Instruction
{
public:
ReturnInstruction() {}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.set_return();
}
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class ScalarTensorProductInstruction : public Instruction
{
public:
ScalarTensorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
// This is a bit hacky: regardless of the tensor rank we
// pull it out as a vector. This works because of the way
// fmt::V computes sizes---it lumps together any higher
// dimensions---while fmt::M ignores them.
EigenVector<ET>(call_frame, m_out) =
call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0] *
EigenVector<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,49 +14,49 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename TI, typename TO>
void greater_eq(TI arg0, TI arg1, TO out)
namespace eigen
{
auto result_as_float = get_map_array(&*arg0) <= get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map_array(&*out, result_as_char);
}
template <typename ET>
class GreaterEqInstruction : public Instruction
{
public:
GreaterEqInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class SelectInstruction : public Instruction
{
}
public:
SelectInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo arg2,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_arg2(arg2)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<element::Bool>(call_frame, m_out) =
(EigenArray1d<ET>(call_frame, m_arg0) >=
EigenArray1d<ET>(call_frame, m_arg1))
.template cast<char>();
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<element::Bool>(call_frame, m_arg0)
.select(EigenArray1d<ET>(call_frame, m_arg1),
EigenArray1d<ET>(call_frame, m_arg2));
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_arg2;
TensorViewInfo m_out;
};
}
}
}
}
......@@ -14,39 +14,44 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class SubtractInstruction : public Instruction
namespace eigen
{
public:
SubtractInstruction(TensorViewInfo arg0, TensorViewInfo arg1, TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
template <typename ET>
class SubtractInstruction : public Instruction
{
}
public:
SubtractInstruction(TensorViewInfo arg0,
TensorViewInfo arg1,
TensorViewInfo out)
: m_arg0(arg0)
, m_arg1(arg1)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) =
EigenArray1d<ET>(call_frame, m_arg0) - EigenArray1d<ET>(call_frame, m_arg1);
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg0) -
EigenArray1d<ET>(call_frame, m_arg1);
}
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
protected:
TensorViewInfo m_arg0;
TensorViewInfo m_arg1;
TensorViewInfo m_out;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <Eigen/Dense>
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
class TensorViewInfo;
namespace ngvm
{
class CallFrame;
namespace eigen
{
using DynamicStrides = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
using VectorStrides = Eigen::Stride<Eigen::Dynamic, 1>;
template <typename ET>
using DynamicArray =
Eigen::Array<typename ET::type, Eigen::Dynamic, Eigen::Dynamic>;
template <typename ET>
using EigenArrayBase = Eigen::Map<DynamicArray<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicMatrix = Eigen::
Matrix<typename ET::type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
template <typename ET>
using EigenMatrixBase = Eigen::Map<DynamicMatrix<ET>, 0, DynamicStrides>;
template <typename ET>
using DynamicVector = Eigen::Matrix<typename ET::type, Eigen::Dynamic, 1>;
template <typename ET>
using EigenVectorBase = Eigen::Map<DynamicVector<ET>, 0, VectorStrides>;
namespace fmt
{
/// @brief vector format for Eigen wrappers.
class V
{
public:
V(const TensorViewInfo& tensor_view_info)
: l0(tensor_view_info
.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>()
->get_size())
{
}
public:
size_t l0;
size_t l1{1};
size_t s0{1};
size_t s1{1};
};
class M
{
M(const Shape& shape, const Strides& strides)
: l0(shape.at(0))
, l1(shape.at(1))
, s0(strides.at(0))
, s1(strides.at(1))
{
}
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const TensorViewInfo& tensor_view_info)
: M(tensor_view_info.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>())
{
}
public:
size_t l0;
size_t l1;
size_t s0;
size_t s1;
};
}
// ET element type
// FMT array format (fmt::V for vector, etc.)
// BASE select array/matrix
template <typename ET,
typename FMT,
typename BASE,
typename STRIDES = DynamicStrides>
class EigenWrapper : public BASE
{
using base = BASE;
public:
EigenWrapper(typename ET::type* t, const FMT& fmt)
: base(t, fmt.l0, fmt.l1, STRIDES(fmt.s0, fmt.s1))
{
}
EigenWrapper(
typename ET::type* t,
const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: base(t, layout->get_size(), 1, DynamicStrides(1, 1))
{
}
EigenWrapper(CallFrame& call_frame, const TensorViewInfo& tensor_view_info)
: EigenWrapper(
call_frame.get_tensor_view_data<ET>(tensor_view_info.get_index()),
FMT(tensor_view_info))
{
}
template <typename U>
EigenWrapper& operator=(const U& other)
{
this->base::operator=(other);
return *this;
}
};
template <typename ET, typename FMT = fmt::V>
using EigenArray1d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenArray2d = EigenWrapper<ET, FMT, EigenArrayBase<ET>>;
template <typename ET, typename FMT = fmt::M>
using EigenMatrix = EigenWrapper<ET, FMT, EigenMatrixBase<ET>>;
template <typename ET, typename FMT = fmt::V>
using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>;
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/function.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
class Instruction;
class ExternalFunction : public ngraph::runtime::ExternalFunction
{
using FunctionMap = std::unordered_map<std::shared_ptr<Function>,
std::shared_ptr<ExternalFunction>>;
using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame();
virtual std::shared_ptr<ngraph::runtime::CallFrame>
make_call_frame(FunctionMap& function_map);
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> get_instructions()
{
return m_instructions;
}
// Release original function's resources
void release_function() { m_function = nullptr; }
protected:
void compile();
void compile(FunctionMap& function_map);
size_t m_n_inputs;
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
ngraph::descriptor::TensorViewPtrs m_temp_views;
static OpMap& get_op_map();
};
}
}
}
......@@ -14,35 +14,26 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/eigen/utils.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include <memory>
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
template <typename ET>
class NegateInstruction : public Instruction
class CallFrame;
/// @brief An interpreter for an Op
///
/// The call_frame has a vector of instructions and calls execute on each instruction, passing it the call_frame.
/// Instructions get argument, result, and intermediate tensor views from the call frame. Instructions may also
/// set a flag in the call_frame to end execution, or adjust execution by modifying the position in the instruction vector.
class Instruction
{
public:
NegateInstruction(TensorViewInfo arg, TensorViewInfo out)
: m_arg(arg)
, m_out(out)
{
}
virtual void execute(CallFrame& call_frame) const override
{
EigenArray1d<ET>(call_frame, m_out) = -EigenArray1d<ET>(call_frame, m_arg);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
virtual ~Instruction() {}
virtual void execute(CallFrame& call_frame) const = 0;
};
}
}
......
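To make the Instruction contract described above concrete, here is a minimal sketch of what a new ngvm instruction could look like; the CopyInstruction name is hypothetical and only mirrors the pattern of the Eigen instructions in this commit.
// Sketch only: a hypothetical element-wise copy in the ngvm instruction style.
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
namespace ngraph { namespace runtime { namespace ngvm { namespace eigen {
    template <typename ET>
    class CopyInstruction : public Instruction
    {
    public:
        CopyInstruction(TensorViewInfo arg, TensorViewInfo out)
            : m_arg(arg)
            , m_out(out)
        {
        }
        // The call frame owns the tensor storage; the instruction only names
        // which views it reads and writes.
        virtual void execute(CallFrame& call_frame) const override
        {
            EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_arg);
        }
    protected:
        TensorViewInfo m_arg;
        TensorViewInfo m_out;
    };
}}}}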
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/ngvm/ngvm_backend.hpp"
#include "ngraph/runtime/external_function.hpp"
using namespace ngraph::runtime::ngvm;
std::shared_ptr<ngraph::runtime::CallFrame>
NGVMBackend::make_call_frame(const std::shared_ptr<ExternalFunction>& external_function)
{
return external_function->make_call_frame();
}
......@@ -14,23 +14,20 @@
#pragma once
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/instruction.hpp"
#include "ngraph/runtime/backend.hpp"
namespace ngraph
{
namespace runtime
{
namespace eigen
namespace ngvm
{
class ReturnInstruction : public Instruction
/// @brief Transformer for the interpreted backend
class NGVMBackend : public Backend
{
public:
ReturnInstruction() {}
virtual void execute(CallFrame& call_frame) const override
{
call_frame.set_return();
}
virtual std::shared_ptr<ngraph::runtime::CallFrame> make_call_frame(
const std::shared_ptr<ngraph::runtime::ExternalFunction>& external_function);
};
}
}
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/ngvm/ngvm_backend.hpp"
#include "ngraph/runtime/ngvm/ngvm_manager.hpp"
using namespace ngraph::runtime::ngvm;
std::shared_ptr<ngraph::runtime::Backend> NGVMManager::allocate_backend()
{
return std::make_shared<NGVMBackend>();
}
std::shared_ptr<ngraph::runtime::ExternalFunction>
NGVMManager::compile(const std::shared_ptr<ngraph::Function>& fun)
{
return std::make_shared<ExternalFunction>(fun);
}
ngraph::runtime::Manager::Factory NGVMManager::factory = ngraph::runtime::Manager::register_factory(
"NGVM", [](const std::string& name) -> std::shared_ptr<ngraph::runtime::Manager> {
return std::make_shared<NGVMManager>();
});
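As a rough usage sketch (assuming a previously built ngraph::Function f; the final call-frame invocation is omitted), a framework reaches the interpreter through the factory registered above:
// Sketch, not part of this commit: end-to-end use of the Manager/Backend API.
auto manager = ngraph::runtime::Manager::get("NGVM");   // resolves the registered "NGVM" factory
auto external = manager->compile(f);                    // f is an assumed std::shared_ptr<ngraph::Function>
auto backend = manager->allocate_backend();             // NGVMBackend
auto cf = backend->make_call_frame(external);           // interpreter call frame for f
// Frameworks move data through the byte-level TensorView interface.
auto input = backend->make_primary_tensor_view(ngraph::element::Float32::element_type(),
                                               ngraph::Shape{4});
std::vector<float> data{1, 2, 3, 4};
input->write(data.data(), 0, data.size() * sizeof(float));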
......@@ -16,24 +16,29 @@
#include <memory>
#include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/manager.hpp"
namespace ngraph
{
class Function;
namespace runtime
{
class CallFrame;
/// @brief An interpreter for an Op
///
/// The call_frame has a vector of instructions and calls execute on each instruction, passing it the call_frame.
/// Instructions get argument, result, and intermediate tensor views from the call frame. Instructions may also
/// set a flag in the call_frame to end execution, or adjust execution by modifying the position in the instruction vector.
class Instruction
class ExternalFunction;
namespace ngvm
{
public:
virtual ~Instruction() {}
virtual void execute(CallFrame& call_frame) const = 0;
/// @brief Transformer for the interpreted backend
class NGVMManager : public Manager
{
public:
virtual std::shared_ptr<Backend> allocate_backend() override;
virtual std::shared_ptr<ngraph::runtime::ExternalFunction>
compile(const std::shared_ptr<ngraph::Function>& fun) override;
static Factory factory;
};
};
}
}
......@@ -14,6 +14,7 @@
#pragma once
#include <cstring>
#include <memory>
#include <vector>
......@@ -61,6 +62,48 @@ namespace ngraph
// For getting the data out
storage_type& get_vector() { return m_vector; }
virtual void write(const void* p, size_t tensor_offset, size_t n) override
{
size_t elt_offset = tensor_offset / sizeof(typename ET::type);
if (elt_offset * sizeof(typename ET::type) != tensor_offset)
{
throw ngraph_error("Attempt to write to an address not aligned on an element");
}
size_t elt_n = n / sizeof(typename ET::type);
if (elt_n * sizeof(typename ET::type) != n)
{
throw ngraph_error("Attemmpt to write a partial element");
}
size_t tensor_byte_size = sizeof(typename ET::type) * m_vector.size();
if (tensor_offset + n > tensor_byte_size)
{
throw ngraph_error("Attempt to write beyond the tensor");
}
std::memcpy(&m_vector[elt_offset], p, n);
}
virtual void read(void* p, size_t tensor_offset, size_t n) const override
{
size_t elt_offset = tensor_offset / sizeof(typename ET::type);
if (elt_offset * sizeof(typename ET::type) != tensor_offset)
{
throw ngraph_error("Attempt to read from an address not aligned on an element");
}
size_t elt_n = n / sizeof(typename ET::type);
if (elt_n * sizeof(typename ET::type) != n)
{
throw ngraph_error("Attemmpt to read a partial element");
}
size_t tensor_byte_size = sizeof(typename ET::type) * m_vector.size();
if (tensor_offset + n > tensor_byte_size)
{
throw ngraph_error("Attempt to read beyond the tensor");
}
std::memcpy(p, &m_vector[elt_offset], n);
}
protected:
storage_type m_vector;
};
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/common.hpp"
#include "ngraph/types/element_type.hpp"
using namespace ngraph::runtime;
std::shared_ptr<const ngraph::descriptor::TensorView> TensorView::get_tensor_view_descriptor() const
{
return m_descriptor;
}
std::shared_ptr<ngraph::descriptor::Value> TensorView::get_descriptor() const
{
return m_descriptor;
}
void TensorView::collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
const std::shared_ptr<Value>& value) const
{
views.push_back(std::static_pointer_cast<TensorView>(value));
}
const ngraph::Shape& TensorView::get_shape() const
{
return m_descriptor->get_tensor_view_type()->get_shape();
}
......@@ -20,10 +20,14 @@
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/runtime/value.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
namespace ngraph
{
namespace descriptor
{
class Value;
}
namespace runtime
{
template <typename ET>
......@@ -46,23 +50,28 @@ namespace ngraph
return dynamic_cast<ParameterizedTensorView<ET>*>(this);
}
std::shared_ptr<const ngraph::descriptor::TensorView> get_tensor_view_descriptor() const
{
return m_descriptor;
}
std::shared_ptr<const ngraph::descriptor::TensorView>
get_tensor_view_descriptor() const;
virtual std::shared_ptr<ngraph::descriptor::Value> get_descriptor() const override
{
return m_descriptor;
}
virtual std::shared_ptr<ngraph::descriptor::Value> get_descriptor() const override;
virtual void collect_tensor_views(std::vector<std::shared_ptr<TensorView>>& views,
const std::shared_ptr<Value>& value) const override
{
views.push_back(std::static_pointer_cast<TensorView>(value));
}
const std::shared_ptr<Value>& value) const override;
const ngraph::Shape& get_shape() const;
/// @brief Write bytes directly into the tensor
/// @param p Pointer to source of data
/// @param tensor_offset Offset into tensor storage to begin writing. Must be element-aligned.
/// @param n Number of bytes to write, must be integral number of elements.
virtual void write(const void* p, size_t tensor_offset, size_t n) = 0;
/// @brief Read bytes directly from the tensor
/// @param p Pointer to destination for data
/// @param tensor_offset Offset into tensor storage to begin reading. Must be element-aligned.
/// @param n Number of bytes to read, must be integral number of elements.
virtual void read(void* p, size_t tensor_offset, size_t n) const = 0;
const Shape& get_shape() { return m_descriptor->get_tensor_view_type()->get_shape(); }
protected:
std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor;
};
......
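A small sketch of the byte-level contract documented above, assuming tv is a shared_ptr to a Float32 TensorView with at least three elements:
// Offsets and sizes are in bytes and must cover whole, element-aligned ranges.
float v = 2.5f;
tv->write(&v, 2 * sizeof(float), sizeof(float));    // overwrite element 2
float back = 0.0f;
tv->read(&back, 2 * sizeof(float), sizeof(float));  // read it back
// A misaligned offset or a partial element size throws ngraph_error.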
......@@ -21,6 +21,11 @@
namespace ngraph
{
namespace descriptor
{
class Value;
}
namespace runtime
{
class TensorView;
......
......@@ -21,8 +21,6 @@
using namespace ngraph;
std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
ngraph::element::Type::Type(size_t bitwidth,
bool is_float,
bool is_signed,
......
......@@ -19,10 +19,14 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include "ngraph/common.hpp"
#include "ngraph/except.hpp"
#include "ngraph/runtime/parameterized_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
......@@ -34,6 +38,7 @@ namespace ngraph
Type& operator=(const Type&) = delete;
public:
virtual ~Type() {}
Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
const std::string& c_type_string() const;
......@@ -44,6 +49,9 @@ namespace ngraph
return h(m_cname);
}
virtual std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const Shape& shape) const = 0;
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream&, const Type&);
......@@ -102,6 +110,12 @@ namespace ngraph
static TraitedType<T> t;
return t;
}
virtual std::shared_ptr<ngraph::runtime::TensorView>
make_primary_tensor_view(const ngraph::Shape& shape) const override
{
return std::make_shared<runtime::ParameterizedTensorView<TraitedType<T>>>(shape);
}
};
NGRAPH_DEFINE_TRAITED_TYPE_NAME(char)
......
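For illustration (a sketch, not code from this commit), the new virtual lets a primary tensor view be created directly from an element type:
// TraitedType<T>::make_primary_tensor_view returns a ParameterizedTensorView for that type.
const ngraph::element::Type& et = ngraph::element::Float32::element_type();
auto view = et.make_primary_tensor_view(ngraph::Shape{2, 3});
// view is a runtime::TensorView; frameworks use its write()/read() byte interface,
// while tests can downcast via get_parameterized_tensor_view<element::Float32>().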
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <memory>
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
......@@ -83,3 +83,37 @@ TEST(tensor, size)
EXPECT_EQ(1 * 4, output.size());
}
}
template <typename ET>
void test_read_write(const std::vector<typename ET::type>& x)
{
using T = typename ET::type;
auto manager = ngraph::runtime::Manager::get("NGVM");
auto backend = manager->allocate_backend();
auto a = backend->make_primary_tensor_view(ET::element_type(), Shape{2, x.size()});
auto af = a->template get_parameterized_tensor_view<ET>();
std::vector<T> result(2 * x.size());
a->write(&x[0], 0, x.size() * sizeof(T));
std::copy(x.begin(), x.end(), result.begin());
a->write(&x[0], x.size() * sizeof(T), x.size() * sizeof(T));
std::copy(x.begin(), x.end(), result.begin() + x.size());
auto& af_vector = af->get_vector();
ASSERT_EQ(af_vector, result);
std::vector<T> result1(x.size());
std::vector<T> result2(x.size());
std::copy(result.begin() + 1, result.begin() + 1 + x.size(), result1.begin());
a->read(&result2[0], sizeof(T), sizeof(T) * x.size());
ASSERT_EQ(result1, result2);
}
TEST(tensor, read_write)
{
test_read_write<element::Float32>({1.0, 3.0, 5.0});
test_read_write<element::Int64>({-1, 2, 4});
}