Unverified Commit 03cb31f5 authored by Scott Cyphers, committed by GitHub

Do checking, tensor view spreading and type propagation during op construction (#232)

* Do checking, tensor view spreading and type propagation during op construction
Better names for builtin classes
Replace set_value_type with assert_value_type, which checks that the type is as expected

* Review comments

* Review comments
parent fff318bd
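For context: before this change, a freshly constructed graph had untyped nodes, and the PropagateTypes and AssignTensors passes filled in value types and tensor views afterwards; after it, op constructors do both, so those two passes are deleted (see the pass-manager hunks near the end of this diff). A hedged sketch of the caller-visible difference, using op::Add as an example:

// Before: construction deferred checking to compile-time passes.
//   auto bad = make_shared<op::Add>(float_node, int_node); // no error yet
//   pass_manager.run_passes(f);                            // error surfaced here
//
// After: constructors check immediately.
//   auto bad = make_shared<op::Add>(float_node, int_node); // throws ngraph_error now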
......@@ -26,8 +26,8 @@ set (SRC
node.cpp
ops/add.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/binary_elementwise.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/constant.cpp
......@@ -52,8 +52,7 @@ set (SRC
ops/sum.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
ops/unary_elementwise.cpp
pass/collect_functions.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
......@@ -62,7 +61,6 @@ set (SRC
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/visualize_tree.cpp
runtime/backend.cpp
......
......@@ -14,12 +14,12 @@
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
using namespace ngraph;
using namespace descriptor;
Input::Input(
const std::shared_ptr<Node>& node, size_t index, size_t argno, size_t arg_index, Output& output)
Input::Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
......@@ -31,7 +31,7 @@ Input::Input(
std::shared_ptr<Node> Input::get_node()
{
return m_node.lock();
return m_node->shared_from_this();
}
const Tensor& Input::get_tensor() const
......@@ -43,3 +43,18 @@ Tensor& Input::get_tensor()
{
return m_output.get_tensor();
}
std::shared_ptr<const TensorView> Input::get_tensor_view() const
{
return m_output.get_tensor_view();
}
std::shared_ptr<TensorView> Input::get_tensor_view()
{
return m_output.get_tensor_view();
}
std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{
return m_output.get_tensor_view()->get_tensor_view_type();
}
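A note on the m_node change in this file: Inputs are now created inside the Node constructor, where no shared_ptr to the node exists yet, so a weak_ptr cannot be formed. Storing a raw Node* and recovering the shared_ptr on demand via shared_from_this() sidesteps that. A minimal standalone sketch of the constraint (Node in this codebase derives from std::enable_shared_from_this):

#include <memory>

struct Widget : std::enable_shared_from_this<Widget>
{
    Widget()
    {
        // shared_from_this() here would throw std::bad_weak_ptr:
        // ownership is not established during construction, so the
        // raw `this` pointer is the only handle available.
    }
};

int main()
{
    auto w = std::make_shared<Widget>();
    auto again = w->shared_from_this(); // fine once a shared_ptr owns w
}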
......@@ -17,6 +17,7 @@
#include <memory>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph
{
......@@ -37,26 +38,41 @@ namespace ngraph
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(const std::shared_ptr<Node>& node,
size_t index,
size_t argno,
size_t arg_index,
Output& output);
Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& output);
/// @return the node that this is an input of
std::shared_ptr<Node> get_node();
/// @return the position of the node argument that uses this input
size_t get_argno() const { return m_argno; }
/// @return the position within the node argument of this tensor
size_t get_arg_index() const { return m_arg_index; }
/// @return the position within all supplied tensors of this input
size_t get_index() const { return m_index; }
// @return the connected output
const Output& get_output() const { return m_output; }
// @return the connected output
Output& get_output() { return m_output; }
// @return the tensor of the connected output
const Tensor& get_tensor() const;
// @return the tensor of the connected output
Tensor& get_tensor();
/// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const;
/// @return the tensor view for the connected output
std::shared_ptr<TensorView> get_tensor_view();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
protected:
std::weak_ptr<Node> m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
Output& m_output;
private:
......
......@@ -16,6 +16,7 @@
#include "ngraph/except.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor::layout;
using ngraph::Shape;
......
......@@ -15,6 +15,7 @@
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor::layout;
......
......@@ -27,9 +27,19 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<const Valu
, m_is_output(false)
{
// Add this node as a user of each argument.
for (auto node : m_arguments)
size_t i = 0;
size_t argno = 0;
for (auto arg : m_arguments)
{
node->m_users.insert(this);
arg->assign_tensors();
arg->m_users.insert(this);
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(this, i, argno, arg_index++, output);
i++;
}
argno++;
}
}
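The loop above keeps three counters: i runs over all inputs of the node, argno over arguments, and arg_index over the outputs of each (possibly multi-output) argument. A hypothetical numbering for a node with arguments (arg0, arg1), where arg0 has two outputs and arg1 has one:

// input   i   argno   arg_index   connected output
// -----  ---  -----   ---------   ----------------
//  #0     0     0         0       arg0, output 0
//  #1     1     0         1       arg0, output 1
//  #2     2     1         0       arg1, output 0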
......@@ -47,6 +57,14 @@ Node::~Node()
{
}
void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) const
{
if (*m_value_type != *value_type)
{
throw ngraph_error("Setting value type to a different ValueType");
}
}
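Since value types are now computed during construction, callers that previously set a type can only verify one. A small usage sketch (a and b are hypothetical float32 2x3 parameters):

// Throws ngraph_error if the constructed type differs from the expectation:
auto sum = std::make_shared<op::Add>(a, b);
sum->assert_value_type(element::Float32::element_type(), Shape{2, 3});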
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{
if (nullptr == m_value_type)
......@@ -64,26 +82,45 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
std::shared_ptr<const ValueType> Node::get_value_type()
{
if (nullptr == m_value_type)
{
propagate_types();
}
return m_value_type;
}
const std::shared_ptr<const ValueType> Node::get_value_type() const
{
if (nullptr == m_value_type)
if (!m_outputs_valid)
{
const_cast<Node*>(this)->propagate_types();
const_cast<Node*>(this)->assign_tensors();
}
return m_value_type;
}
std::deque<descriptor::Output>& Node::get_outputs()
{
if (!m_outputs_valid)
{
assign_tensors();
}
return m_outputs;
}
const std::deque<descriptor::Output>& Node::get_outputs() const
{
if (!m_outputs_valid)
{
const_cast<Node*>(this)->assign_tensors();
}
return m_outputs;
}
void Node::assign_tensors()
{
if (m_outputs_valid)
{
return;
}
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
m_value_type->collect_tensor_views(tensor_view_types);
std::shared_ptr<Node> shared_this = shared_from_this();
size_t i = 0;
for (auto tvt : tensor_view_types)
......@@ -96,19 +133,7 @@ void Node::assign_tensors()
m_outputs.emplace_back(shared_this, i, tensor_view_descriptor);
i++;
}
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(shared_this, i, argno, arg_index++, output);
i++;
}
argno++;
}
m_outputs_valid = true;
}
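The m_outputs_valid guard makes tensor assignment lazy and idempotent: the first call to get_outputs() (const or not) triggers assign_tensors(), and every later call is a no-op. A standalone sketch of the idiom, with the real tensor setup stubbed out:

#include <deque>

class LazyOutputs
{
public:
    std::deque<int>& get_outputs()
    {
        if (!m_valid)
        {
            m_outputs.assign({1, 2, 3}); // stand-in for real output construction
            m_valid = true;
        }
        return m_outputs; // safe to call repeatedly; setup runs once
    }

private:
    std::deque<int> m_outputs;
    bool m_valid{false};
};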
bool Node::is_parameter() const
......
......@@ -15,6 +15,7 @@
#pragma once
#include <atomic>
#include <deque>
#include <memory>
#include <set>
#include <string>
......@@ -54,19 +55,16 @@ namespace ngraph
{
}
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
public:
/// The class name, must not contain spaces
virtual std::string description() const = 0;
std::string get_name() const;
void set_name(const std::string& name);
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; }
......@@ -83,14 +81,10 @@ namespace ngraph
std::shared_ptr<const ValueType> get_value_type();
const std::shared_ptr<const ValueType> get_value_type() const;
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<const ValueType>& value_type)
void assert_value_type(const std::shared_ptr<const ValueType>& value_type) const;
void assert_value_type(const element::Type& element_type, const Shape& shape) const
{
m_value_type = value_type;
assert_value_type(std::make_shared<TensorViewType>(element_type, shape));
}
// Set the value type if it has not already been set; otherwise, ensure that
......@@ -108,8 +102,8 @@ namespace ngraph
std::deque<descriptor::Input>& get_inputs() { return m_inputs; }
const std::deque<descriptor::Input>& get_inputs() const { return m_inputs; }
std::deque<descriptor::Output>& get_outputs() { return m_outputs; }
const std::deque<descriptor::Output>& get_outputs() const { return m_outputs; }
std::deque<descriptor::Output>& get_outputs();
const std::deque<descriptor::Output>& get_outputs() const;
std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
......@@ -132,6 +126,7 @@ namespace ngraph
static std::atomic<size_t> m_next_instance_id;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
bool m_outputs_valid{false};
bool m_is_output;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
};
......
......@@ -19,29 +19,22 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
op::BinaryElementwise::BinaryElementwise(
std::function<const element::Type&(const element::Type&, const element::Type&)>
element_type_function,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(Nodes{arg0, arg1})
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape())
{
throw ngraph_error("Arguments must have the same tensor view shape");
}
const element::Type& result_element_type = propagate_element_types(
const element::Type& result_element_type = element_type_function(
arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type());
set_value_type_checked(
......
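This hunk replaces the virtual propagate_element_types hook with a std::function handed to the base constructor, so the element-type rule runs while the derived constructor is still executing (a virtual call made from a base-class constructor would not dispatch to the derived override). A reduced standalone sketch of the pattern, with element::Type stood in by std::string:

#include <functional>
#include <stdexcept>
#include <string>

// The base runs a caller-supplied rule at construction time, as
// BinaryElementwise does with element_type_function.
struct CheckedBase
{
    CheckedBase(std::function<std::string(const std::string&, const std::string&)> rule,
                const std::string& t0,
                const std::string& t1)
        : result_type(rule(t0, t1)) // rule executes here, inside construction
    {
    }
    std::string result_type;
};

struct SameTypeOp : CheckedBase
{
    SameTypeOp(const std::string& t0, const std::string& t1)
        : CheckedBase(
              [](const std::string& a, const std::string& b) -> std::string {
                  if (a != b)
                  {
                      throw std::runtime_error("type mismatch");
                  }
                  return a;
              },
              t0,
              t1)
    {
    }
};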
......@@ -16,20 +16,26 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
return arg0_element_type;
},
arg0,
arg1)
{
}
......@@ -16,15 +16,21 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
using namespace ngraph;
const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
{
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
return element::Bool::element_type();
return element::Bool::element_type();
},
arg0,
arg1)
{
}
......@@ -16,25 +16,16 @@
#include "ngraph/ops/sum.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Broadcast::propagate_types()
op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: RequiresTensorViewArgs({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type();
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
......@@ -48,8 +39,8 @@ void Broadcast::propagate_types()
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
}
void ngraph::op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
......
......@@ -56,7 +56,7 @@ namespace ngraph
/// | ------- | ----------------------------------------------- |
/// | NGVM | Implemented for scalars, matrices, and vectors. |
class Broadcast : public Builtin
class Broadcast : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a broadcast operation.
......@@ -67,12 +67,7 @@ namespace ngraph
/// remaining axes in shape must be the same as the shape of arg.
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Builtin({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
const AxisSet& broadcast_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -83,8 +78,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override;
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
const Shape& get_broadcast_shape() const { return m_shape; }
......
......@@ -17,27 +17,18 @@
#include "ngraph/ops/concatenate.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Concat::propagate_types()
op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
: RequiresTensorViewArgs(args)
, m_concatenation_axis(concatenation_axis)
{
if (m_arguments.size() < 1)
{
throw ngraph_error("At least one argument required");
}
auto arg0_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg0_type)
{
throw ngraph_error("Argument to concat is missing type.");
}
auto arg0_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg0_type);
if (nullptr == arg0_tensor_view_type)
{
throw ngraph_error("Argument to concat is not a tensor view");
}
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg0_shape = arg0_tensor_view_type->get_shape();
if (m_concatenation_axis >= arg0_shape.size())
{
......@@ -47,20 +38,9 @@ void Concat::propagate_types()
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
for (auto i = 1; i < m_arguments.size(); i++)
for (auto i = 1; i < get_inputs().size(); i++)
{
auto argi_type = m_arguments.at(i)->get_value_type();
if (nullptr == argi_type)
{
throw ngraph_error("Argument to concat is missing type.");
}
auto argi_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(argi_type);
if (nullptr == argi_tensor_view_type)
{
throw ngraph_error("Argument to concat is not a tensor view");
}
auto argi_tensor_view_type = get_inputs().at(i).get_tensor_view_type();
auto argi_shape = argi_tensor_view_type->get_shape();
if (argi_shape.size() != arg0_shape.size())
{
......@@ -85,7 +65,6 @@ void Concat::propagate_types()
}
}
}
vector<size_t> concatenated_shape = arg0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
......
......@@ -63,18 +63,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------------------------- |
/// | NGVM | Implemented for vectors and matrices. |
class Concat : public Builtin
class Concat : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const Nodes& args, size_t concatenation_axis)
: Builtin(args)
, m_concatenation_axis(concatenation_axis)
{
}
Concat(const Nodes& args, size_t concatenation_axis);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -82,9 +78,7 @@ namespace ngraph
return std::make_shared<Concat>(new_args, m_concatenation_axis);
}
virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override;
virtual std::string description() const override { return "Concat"; }
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected:
......
......@@ -14,24 +14,44 @@
#include "ngraph/ops/constant.hpp"
using namespace ngraph::op;
using namespace ngraph;
void ConstantBase::propagate_types()
namespace
{
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
{
auto result = ET::read(value_strings);
}
}
op::Constant::Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{
check_args();
}
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
op::Constant::Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
auto result = ET::read(value_strings);
check_args();
}
void Constant::propagate_types()
void op::Constant::check_args()
{
// No actual type propagation is done here; however, we check the number of value strings and
// We check the number of value strings and
// also call check_value_strings just to make sure the result will be parseable at compile
// time. (It will throw an exception if not.)
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(get_value_type());
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(m_value_type);
if (nullptr == tvt)
{
throw ngraph_error("Constant does not have tensor view type");
......
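check_value_strings exists only to force ET::read to parse the literals, so malformed values fail at graph construction rather than at execution. A standalone sketch of the same idea with parsing stubbed by std::stod (ET::read is the real parser in this codebase):

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

void check_value_strings(const std::vector<std::string>& value_strings)
{
    for (const auto& s : value_strings)
    {
        std::size_t pos = 0;
        (void)std::stod(s, &pos); // throws std::invalid_argument on garbage
        if (pos != s.size())
        {
            throw std::invalid_argument("trailing characters in literal: " + s);
        }
    }
}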
......@@ -39,8 +39,6 @@ namespace ngraph
: Node({}, type)
{
}
virtual void propagate_types() override;
};
/// \brief Class for constants whose element types are known at C++ compile-time.
......@@ -162,22 +160,14 @@ namespace ngraph
/// \param value_strings A list of literals for initializing the tensor constant. There must be one literal for each element of the tensor; i.e., `value_strings.size()` must equal `ngraph::shape_size(shape)`.
Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{
}
const std::vector<std::string>& value_strings);
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
}
Constant(const element::Type& et, const Shape& shape, const std::string& value_string);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -197,9 +187,9 @@ namespace ngraph
/// \return The initialization literals for the tensor constant.
const std::vector<std::string>& get_value_strings() const { return m_value_strings; }
virtual void propagate_types() override;
protected:
void check_args();
const std::vector<std::string> m_value_strings;
};
}
......
......@@ -18,9 +18,13 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwise(
[&](const ngraph::element::Type& ignored) -> const ngraph::element::Type& {
return element_type;
},
arg)
, m_element_type(element_type)
{
return m_element_type;
}
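One subtlety above: the lambda captures the element_type parameter by reference, which is safe only because UnaryElementwise invokes the function inside its own constructor, while the Convert constructor's frame is still alive; the ignored argument is the input's element type, which Convert deliberately discards. A reduced sketch of why the capture is sound under that assumption:

#include <functional>

struct Base
{
    Base(std::function<int(int)> rule)
        : value(rule(0)) // called immediately, before any reference can dangle
    {
    }
    int value;
};

struct Fixed : Base
{
    Fixed(int t)
        : Base([&](int /*ignored*/) { return t; }) // `t` is still in scope here
    {
    }
};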
......@@ -48,18 +48,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Convert : public UnaryElementwiseBuiltin
class Convert : public UnaryElementwise
{
public:
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param element_type Element type for the output tensor.
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwiseBuiltin({arg})
, m_element_type(element_type)
{
}
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -69,8 +65,6 @@ namespace ngraph
return std::make_shared<Convert>(new_args.at(0), m_element_type);
}
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
const element::Type& get_convert_element_type() const { return m_element_type; }
virtual std::string description() const override { return "Convert"; }
protected:
......
......@@ -23,18 +23,14 @@
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Dot::propagate_types()
op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs({arg0, arg1})
{
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
......@@ -108,8 +104,7 @@ ngraph::AxisVector range<ngraph::AxisVector>(size_t n)
return result;
}
void ngraph::op::Dot::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto y = m_arguments[1];
......
......@@ -102,17 +102,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ---------------------------------------------- |
/// | NGVM | Implemented for `arg1` with rank of exactly 2. |
class Dot : public Builtin
class Dot : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a dot product operation.
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -123,8 +120,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -16,9 +16,12 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void FunctionCall::propagate_types()
op::FunctionCall::FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Node(args)
, m_function(function)
{
auto& function_params = m_function->get_parameters();
......
......@@ -14,8 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -46,7 +45,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class FunctionCall : public Builtin
class FunctionCall : public ngraph::Node
{
public:
/// \brief Constructs a function call operation.
......@@ -54,11 +53,7 @@ namespace ngraph
/// \param function The function to be called.
/// \param args The arguments for the function call.
FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Builtin(args)
, m_function(function)
{
}
const std::vector<std::shared_ptr<Node>>& args);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -67,8 +62,6 @@ namespace ngraph
}
virtual std::string description() const override { return "FunctionCall"; }
virtual void propagate_types() override;
/// \return The function to be called.
std::shared_ptr<Function> get_function() const { return m_function; }
protected:
......
......@@ -17,15 +17,12 @@
#include "ngraph/ops/get_tuple_element.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void GetTupleElement::propagate_types()
op::GetTupleElement::GetTupleElement(const std::shared_ptr<Node>& arg, size_t n)
: Node({arg})
, m_n{n}
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tuple_type =
dynamic_pointer_cast<const TupleType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg0_tuple_type)
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -47,18 +47,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class GetTupleElement : public Builtin
class GetTupleElement : public ngraph::Node
{
public:
/// \brief Constructs a get-tuple-element operation.
///
/// \param arg The input tuple.
/// \param n The index of the tuple element to get.
GetTupleElement(const std::shared_ptr<Node>& arg, size_t n)
: Builtin({arg})
, m_n{n}
{
}
GetTupleElement(const std::shared_ptr<Node>& arg, size_t n);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -68,7 +64,6 @@ namespace ngraph
return std::make_shared<GetTupleElement>(new_args.at(0), m_n);
}
virtual void propagate_types() override;
virtual std::string description() const override { return "GetTupleElement"; }
/// \return The index of the tuple element to get.
size_t get_n() const { return m_n; }
......
......@@ -13,9 +13,24 @@
// ----------------------------------------------------------------------------
#include <algorithm>
#include <memory>
#include <sstream>
#include "ngraph/except.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
using namespace std;
op::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::vector<std::shared_ptr<Node>>& args)
: Node(args)
{
for (auto arg : args)
{
if (nullptr == std::dynamic_pointer_cast<const TensorViewType>(arg->get_value_type()))
{
throw ngraph_error("Arguments must be tensor views");
}
}
}
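RequiresTensorViewArgs centralizes the "arguments must be tensor views" check that each op previously repeated in its own propagate_types. Derived ops inherit the check by delegating to this constructor, as the ops throughout this commit do, e.g.:

// Pattern used by Dot, Broadcast, Sum, and others in this diff:
op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
    : RequiresTensorViewArgs({arg0, arg1}) // throws unless both args are tensor views
{
    // op-specific shape and element-type checks follow
}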
......@@ -29,10 +29,6 @@ Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& sha
{
}
void Parameter::propagate_types()
{
}
void Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
}
......@@ -71,7 +71,6 @@ namespace ngraph
}
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
};
}
}
......@@ -16,37 +16,19 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
void Reduce::propagate_types()
using namespace ngraph;
op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes)
: RequiresTensorViewArgs({arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_reductee_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_reductee_type)
{
throw ngraph_error("Argument to reduce is missing type.");
}
auto arg_reductee_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(arg_reductee_type);
if (nullptr == arg_reductee_tensor_view_type)
{
throw ngraph_error("Argument to reduce is not a tensor view");
}
auto arg_reductee_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_init_type = m_arguments.at(1)->get_value_type();
if (nullptr == arg_init_type)
{
throw ngraph_error("Argument for initial value is missing type.");
}
auto arg_init_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_init_type);
if (nullptr == arg_init_tensor_view_type)
{
throw ngraph_error("Argument for initial value is not a tensor view");
}
auto arg_init_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
if (arg_init_tensor_view_type->get_shape().size() != 0)
{
throw ngraph_error("Argument for initial value is not a scalar");
......@@ -85,18 +67,18 @@ void Reduce::propagate_types()
throw ngraph_error("Reduction function has wrong number of parameters (should be two)");
}
if (*(f_params.at(0)->get_value_type()) != *(arg_init_type))
if (*(f_params.at(0)->get_value_type()) != *(arg_init->get_value_type()))
{
throw ngraph_error("Argument 0 of reduction function has wrong type");
}
if (*(f_params.at(1)->get_value_type()) != *(arg_init_type))
if (*(f_params.at(1)->get_value_type()) != *(arg_init->get_value_type()))
{
throw ngraph_error("Argument 1 of reduction function has wrong type");
}
auto f_result_type = m_reduction_function->get_result_type();
if (*(f_result_type) != *(arg_init_type))
if (*(f_result_type) != *(arg_init->get_value_type()))
{
throw ngraph_error("Return type from reduction function does not match expected");
}
......
......@@ -87,7 +87,7 @@ namespace ngraph
/// | ------- | ----------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. |
class Reduce : public Builtin
class Reduce : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reduction operation.
......@@ -99,12 +99,7 @@ namespace ngraph
Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes)
: Builtin({arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
{
}
const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -116,8 +111,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Reduce"; }
virtual void propagate_types() override;
/// \return The function to use for reduction.
std::shared_ptr<Function> get_reduction_function() const
{
......
......@@ -18,26 +18,16 @@
#include <algorithm>
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Reshape::propagate_types()
op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: RequiresTensorViewArgs({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to reshape is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to reshape is not a tensor view");
}
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
......@@ -79,8 +69,8 @@ void Reshape::propagate_types()
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
}
void ngraph::op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto x_type = x->get_value_type();
......
......@@ -59,7 +59,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. Implemented for other shapes only when there is no reordering of the input axes, i.e. `input_order` is \f$(0,\dots,n-1)\f$. |
class Reshape : public Builtin
class Reshape : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reshape operation.
......@@ -71,12 +71,7 @@ namespace ngraph
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: Builtin({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
}
const Shape& output_shape);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -87,8 +82,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Reshape"; }
virtual void propagate_types() override;
/// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; }
/// \return The shape of the output tensor.
......
......@@ -19,25 +19,16 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void Select::propagate_types()
op::Select::Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: RequiresTensorViewArgs(Nodes{arg0, arg1, arg2})
{
if (m_arguments.size() != 3)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
auto arg2_tensor_type = get_inputs().at(2).get_tensor_view_type();
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
auto arg2_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(2)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type || nullptr == arg2_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (arg0_tensor_type->get_element_type() != element::Bool::element_type())
{
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
......
......@@ -41,7 +41,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Select : public Builtin
class Select : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a selection operation.
......@@ -51,10 +51,7 @@ namespace ngraph
/// \param arg2 Node that produces the third input tensor.
Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: Builtin(Nodes{arg0, arg1, arg2})
{
}
const std::shared_ptr<Node>& arg2);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -65,7 +62,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Select"; }
virtual void propagate_types() override;
};
}
}
......@@ -15,25 +15,34 @@
#include "ngraph/ops/slice.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Slice::propagate_types()
op::Slice::Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: RequiresTensorViewArgs({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
check_args();
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to slice is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to slice is not a tensor view");
}
op::Slice::Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: RequiresTensorViewArgs({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
{
check_args();
}
void op::Slice::check_args()
{
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_shape = arg_tensor_view_type->get_shape();
if (m_lower_bounds.size() != arg_shape.size())
......
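Factoring the validation into check_args lets both Slice constructors, with and without an explicit step, share one set of bounds checks; the unit-step overload just synthesizes a step of all ones before delegating. A reduced standalone sketch of that shape (std::vector stands in for Coordinate/Shape):

#include <cstddef>
#include <stdexcept>
#include <vector>

struct SliceLike
{
    SliceLike(std::vector<std::size_t> lower,
              std::vector<std::size_t> upper,
              std::vector<std::size_t> step)
        : m_lower(std::move(lower))
        , m_upper(std::move(upper))
        , m_step(std::move(step))
    {
        check_args(); // shared validation, run by every constructor
    }

    // Unit-step overload: synthesize step = (1, ..., 1), then delegate.
    SliceLike(std::vector<std::size_t> lower, std::vector<std::size_t> upper)
        : SliceLike(lower, upper, std::vector<std::size_t>(lower.size(), 1))
    {
    }

private:
    void check_args() const
    {
        if (m_lower.size() != m_upper.size() || m_lower.size() != m_step.size())
        {
            throw std::invalid_argument("rank mismatch among bounds/step");
        }
    }

    std::vector<std::size_t> m_lower;
    std::vector<std::size_t> m_upper;
    std::vector<std::size_t> m_step;
};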
......@@ -52,7 +52,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ----------------------------------------------- |
/// | NGVM | Implemented for scalars, matrices, and vectors. |
class Slice : public Builtin
class Slice : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a tensor slice operation.
......@@ -65,13 +65,7 @@ namespace ngraph
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
}
const Shape& step);
/// \brief Constructs a tensor slice operation with unit step; i.e., every element inside the bounding box will be copied to the output slice.
///
......@@ -80,13 +74,7 @@ namespace ngraph
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
{
}
const Coordinate& upper_bounds);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -98,8 +86,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Slice"; }
virtual void propagate_types() override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
/// \return The exclusive upper-bound coordinates.
......@@ -107,6 +93,8 @@ namespace ngraph
/// \return The slicing step.
const Shape& get_step() const { return m_step; }
protected:
void check_args();
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Shape m_step;
......
......@@ -16,26 +16,13 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Sum::propagate_types()
op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: RequiresTensorViewArgs({arg})
, m_reduction_axes(reduction_axes)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to sum is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to sum is not a tensor view");
}
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
if (arg_element_type == element::Bool::element_type())
{
......
......@@ -80,18 +80,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ----------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. |
class Sum : public Builtin
class Sum : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a summation operation.
///
/// \param arg The tensor view to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: Builtin({arg})
, m_reduction_axes(reduction_axes)
{
}
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -102,8 +98,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Sum"; }
virtual void propagate_types() override;
/// \return The axis positions (0-based) to be eliminated through summation.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected:
......
......@@ -18,9 +18,10 @@
#include "ngraph/ops/tuple.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Tuple::propagate_types()
op::Tuple::Tuple(const Nodes& args)
: Node(args)
{
vector<shared_ptr<const ValueType>> element_types;
for (auto argument : m_arguments)
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -39,16 +39,13 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Tuple : public Builtin
class Tuple : public ngraph::Node
{
public:
/// \brief Constructs a tuple construction operation.
///
/// \param args The nodes that produce the elements of the constructed tuple.
Tuple(const Nodes& args)
: Builtin(args)
{
}
Tuple(const Nodes& args);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -57,7 +54,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override;
};
}
}
......@@ -18,24 +18,15 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
op::UnaryElementwise::UnaryElementwise(
std::function<const element::Type&(const element::Type&)> element_type_function,
const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(Nodes{arg})
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
auto arg_tensor_type = get_inputs().at(0).get_tensor_view_type();
const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type());
element_type_function(arg_tensor_type->get_element_type());
set_value_type_checked(
make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
......
......@@ -15,15 +15,19 @@
#include "ngraph/ops/op.hpp"
using namespace ngraph;
using namespace ngraph::op;
const element::Type&
UnaryElementwiseArithmetic::propagate_element_types(const element::Type& arg_element_type) const
{
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
}
op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: UnaryElementwise(
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error(
"Operands for arithmetic operators must have numeric element "
"type");
}
return arg_element_type;
return arg_element_type;
},
arg)
{
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/assign_tensors.hpp"
#include <exception>
#include <sstream>
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
using namespace std;
using namespace ngraph;
bool pass::AssignTensors::run_on_call_graph(list<std::shared_ptr<Node>>& nodes)
{
for (shared_ptr<Node> node : nodes)
{
try
{
// We need to set the node's is_output state prior to calling assign_tensors
// so that the output state can be passed to the constructed tensors.
if (node == get_state().get_functions().at(0)->get_result())
{
node->set_is_output();
}
node->assign_tensors();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class AssignTensors;
}
}
class ngraph::pass::AssignTensors : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override;
private:
};
......@@ -20,7 +20,6 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/propagate_types.hpp"
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_graph(list<shared_ptr<Node>>& nodes)
{
for (shared_ptr<Node> node : nodes)
{
try
{
node->propagate_types();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class PropagateTypes;
}
}
class ngraph::pass::PropagateTypes : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
private:
};
......@@ -66,9 +66,7 @@
#include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/emitter.hpp"
......@@ -159,8 +157,6 @@ void ExternalFunction::compile(FunctionMap& function_map)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
// For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.run_passes(m_function);
......
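With checking done in constructors, the compile pipeline shrinks to sorting plus layout assignment. The registration sequence after this change, as the hunk above shows:

pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
// For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.run_passes(m_function);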
......@@ -62,9 +62,7 @@
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/ngvm/eigen/abs.hpp"
#include "ngraph/runtime/ngvm/eigen/acos.hpp"
......@@ -1023,8 +1021,6 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Get the ordered list of ops in execution order
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.run_passes(m_function);
// Turn this into a pass
......
......@@ -20,6 +20,7 @@
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
......
......@@ -109,7 +109,7 @@ TEST(build_graph, tensor)
auto float_tensor_type =
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3});
ASSERT_EQ(*float0->get_value_type(), *float_tensor_type);
auto d = make_shared<op::Dot>(float0, float0);
auto d = make_shared<op::Add>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
......
......@@ -112,7 +112,7 @@ TEST(copy, concat)
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
size_t axis = 1;
size_t axis = 0;
auto node = make_shared<op::Concat>(Nodes{arg0, arg1}, axis);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Concat>(new_node);
......@@ -219,9 +219,11 @@ TEST(copy, FunctionCall)
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto node = make_shared<op::FunctionCall>(f, Nodes{arg0, arg1});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto node = make_shared<op::FunctionCall>(f, Nodes{arg0, arg1, arg2});
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto new_node = node->copy_with_new_args(new_args);
......
......@@ -25,8 +25,6 @@ TEST(input_output, param_tensor)
// Params have no arguments, so we can check that the value becomes a tensor output
auto tv_tp = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param = make_shared<op::Parameter>(tv_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 1);
for (size_t i = 0; i < param->get_outputs().size(); i++)
......@@ -46,8 +44,6 @@ TEST(input_output, param_tuple)
auto tv_tp_1 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4, 6});
auto tp_tp = make_shared<TupleType>(ValueTypes{tv_tp_0, tv_tp_1});
auto param = make_shared<op::Parameter>(tp_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 2);
for (size_t i = 0; i < param->get_outputs().size(); i++)
......@@ -74,19 +70,8 @@ TEST(input_output, simple_output)
nodes.push_back(param_1);
nodes.push_back(add);
// Type info
for (auto node : nodes)
{
node->propagate_types();
}
// Add inputs/outputs
for (auto node : nodes)
{
node->assign_tensors();
}
// At this point, the add should have each input associated with the output of the appropriate parameter
ASSERT_EQ(1, add->get_outputs().size());
auto& inputs = add->get_inputs();
ASSERT_EQ(2, inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
......
......@@ -21,12 +21,10 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
......@@ -44,8 +42,6 @@ TEST(pass, liveness)
pass_manager.register_pass<pass::VisualizeTree>(image);
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::DumpSorted>(dump_file);
......
......@@ -20,9 +20,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "util/test_tools.hpp"
......@@ -35,8 +33,6 @@ TEST(pass_manager, add)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
auto graph = make_test_graph();
size_t node_count = 0;
......
......@@ -20,13 +20,11 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "util/test_tools.hpp"
......@@ -210,8 +208,6 @@ TEST(memory_layout, basic)
string dump_file = "memory_layout.txt";
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>();
pass_manager.register_pass<pass::DumpSorted>(dump_file);
......
......@@ -21,10 +21,8 @@
#include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "util/test_tools.hpp"
......@@ -37,8 +35,6 @@ TEST(tensor, size)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass<pass::Liveness>();
{
......
......@@ -21,11 +21,9 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
......@@ -204,8 +202,6 @@ TEST(topological_sort, unused_function_arg)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
// pass_manager.register_pass<pass::DumpSorted>("sorted.txt");
pass_manager.run_passes(f);
list<shared_ptr<Node>> ops = f->get_ordered_ops();
......