Unverified Commit 03cb31f5 authored by Scott Cyphers, committed by GitHub

Do checking, tensor view spreading and type propagation during op construction (#232)

* Do checking, tensor view spreading and type propagation during op construction
Better names for builtin classes
Replace set_value_type with assert_value_type, which checks that the type is as expected

* Review comments

* Review comments
parent fff318bd
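A minimal sketch of the new flow (illustrative, not part of this commit; op::Parameter, op::Add, and element::Float32 are taken from the existing repository):

// Checking, tensor view spreading, and type propagation now run inside the op
// constructor, so errors surface as soon as the node is built.
auto a = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
auto b = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
auto sum = make_shared<op::Add>(a, b); // shape and element-type checks happen here

// assert_value_type replaces set_value_type: it verifies the type computed at
// construction and throws ngraph_error on a mismatch.
sum->assert_value_type(element::Float32::element_type(), Shape{2, 3});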
......@@ -26,8 +26,8 @@ set (SRC
node.cpp
ops/add.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_builtin.cpp
ops/binary_elementwise_comparison.cpp
ops/binary_elementwise.cpp
ops/broadcast.cpp
ops/concatenate.cpp
ops/constant.cpp
......@@ -52,8 +52,7 @@ set (SRC
ops/sum.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
pass/assign_tensors.cpp
ops/unary_elementwise.cpp
pass/collect_functions.cpp
pass/dump_sorted.cpp
pass/liveness.cpp
......@@ -62,7 +61,6 @@ set (SRC
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/propagate_types.cpp
pass/topological_sort.cpp
pass/visualize_tree.cpp
runtime/backend.cpp
......
......@@ -14,12 +14,12 @@
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/node.hpp"
using namespace ngraph;
using namespace descriptor;
Input::Input(
const std::shared_ptr<Node>& node, size_t index, size_t argno, size_t arg_index, Output& output)
Input::Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& output)
: m_node(node)
, m_index(index)
, m_argno(argno)
......@@ -31,7 +31,7 @@ Input::Input(
std::shared_ptr<Node> Input::get_node()
{
return m_node.lock();
return m_node->shared_from_this();
}
const Tensor& Input::get_tensor() const
......@@ -43,3 +43,18 @@ Tensor& Input::get_tensor()
{
return m_output.get_tensor();
}
std::shared_ptr<const TensorView> Input::get_tensor_view() const
{
return m_output.get_tensor_view();
}
std::shared_ptr<TensorView> Input::get_tensor_view()
{
return m_output.get_tensor_view();
}
std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{
return m_output.get_tensor_view()->get_tensor_view_type();
}
......@@ -17,6 +17,7 @@
#include <memory>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph
{
......@@ -37,23 +38,38 @@ namespace ngraph
/// @param argno The position of the argument with this tensor
/// @param arg_index The position of the tensor within the argument's tensors
/// @param output The output that supplies a value for this input
Input(const std::shared_ptr<Node>& node,
size_t index,
size_t argno,
size_t arg_index,
Output& output);
Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& output);
/// @return the node that this is an input of
std::shared_ptr<Node> get_node();
/// @return the position of the node argument that uses this input
size_t get_argno() const { return m_argno; }
/// @return the position within the node argument of this tensor
size_t get_arg_index() const { return m_arg_index; }
/// @return the position within all supplied tensors of this input
size_t get_index() const { return m_index; }
// @return the connected output
const Output& get_output() const { return m_output; }
// @return the connected output
Output& get_output() { return m_output; }
// @return the tensor of the connected output
const Tensor& get_tensor() const;
// @return the tensor of the connected output
Tensor& get_tensor();
/// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const;
/// @return the tensor view for the connected output
std::shared_ptr<TensorView> get_tensor_view();
/// @return the tensor view type for the connected output
std::shared_ptr<const TensorViewType> get_tensor_view_type() const;
protected:
std::weak_ptr<Node> m_node; // The node we are an input for
Node* m_node; // The node we are an input for
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
......
......@@ -16,6 +16,7 @@
#include "ngraph/except.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor::layout;
using ngraph::Shape;
......
......@@ -15,6 +15,7 @@
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/types/element_type.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph::descriptor::layout;
......
......@@ -27,9 +27,19 @@ Node::Node(const std::vector<shared_ptr<Node>>& arguments, shared_ptr<const Valu
, m_is_output(false)
{
// Add this node as a user of each argument.
for (auto node : m_arguments)
size_t i = 0;
size_t argno = 0;
for (auto arg : m_arguments)
{
arg->assign_tensors();
arg->m_users.insert(this);
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
node->m_users.insert(this);
m_inputs.emplace_back(this, i, argno, arg_index++, output);
i++;
}
argno++;
}
}
......@@ -47,6 +57,14 @@ Node::~Node()
{
}
void Node::assert_value_type(const shared_ptr<const ValueType>& value_type) const
{
if (*m_value_type != *value_type)
{
throw ngraph_error("Setting value type to a different ValueType");
}
}
void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
{
if (nullptr == m_value_type)
......@@ -64,26 +82,45 @@ void Node::set_value_type_checked(const shared_ptr<const ValueType>& value_type)
std::shared_ptr<const ValueType> Node::get_value_type()
{
if (nullptr == m_value_type)
{
propagate_types();
}
return m_value_type;
}
const std::shared_ptr<const ValueType> Node::get_value_type() const
{
if (nullptr == m_value_type)
if (!m_outputs_valid)
{
const_cast<Node*>(this)->propagate_types();
const_cast<Node*>(this)->assign_tensors();
}
return m_value_type;
}
std::deque<descriptor::Output>& Node::get_outputs()
{
if (!m_outputs_valid)
{
assign_tensors();
}
return m_outputs;
}
const std::deque<descriptor::Output>& Node::get_outputs() const
{
if (!m_outputs_valid)
{
const_cast<Node*>(this)->assign_tensors();
}
return m_outputs;
}
void Node::assign_tensors()
{
if (m_outputs_valid)
{
return;
}
vector<std::shared_ptr<const TensorViewType>> tensor_view_types;
get_value_type()->collect_tensor_views(tensor_view_types);
m_value_type->collect_tensor_views(tensor_view_types);
std::shared_ptr<Node> shared_this = shared_from_this();
size_t i = 0;
for (auto tvt : tensor_view_types)
......@@ -96,19 +133,7 @@ void Node::assign_tensors()
m_outputs.emplace_back(shared_this, i, tensor_view_descriptor);
i++;
}
i = 0;
size_t argno = 0;
for (auto arg : get_arguments())
{
size_t arg_index = 0;
for (descriptor::Output& output : arg->get_outputs())
{
m_inputs.emplace_back(shared_this, i, argno, arg_index++, output);
i++;
}
argno++;
}
m_outputs_valid = true;
}
bool Node::is_parameter() const
......
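An illustrative consequence of the m_outputs_valid guard above (sketch, under the same assumptions as before):

// Output descriptors and tensors are now created on first demand rather than
// by a separate pass: get_outputs() calls assign_tensors() when needed, and
// the Node constructor forces it on each argument so the new node's inputs
// can be wired to the arguments' outputs.
auto p = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto n = make_shared<op::Add>(p, p);  // p's outputs are assigned here
auto& outputs = n->get_outputs();     // lazily triggers n->assign_tensors()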
......@@ -15,6 +15,7 @@
#pragma once
#include <atomic>
#include <deque>
#include <memory>
#include <set>
#include <string>
......@@ -54,19 +55,16 @@ namespace ngraph
{
}
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
public:
/// The class name, must not contain spaces
virtual std::string description() const = 0;
std::string get_name() const;
void set_name(const std::string& name);
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
/// Assign Input and Output vectors
// This might later need to be virtual.
void assign_tensors();
const Nodes& get_arguments() const { return m_arguments; }
void clear_arguments() { m_arguments.clear(); }
const std::multiset<Node*>& users() const { return m_users; }
......@@ -83,14 +81,10 @@ namespace ngraph
std::shared_ptr<const ValueType> get_value_type();
const std::shared_ptr<const ValueType> get_value_type() const;
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
void set_value_type(const std::shared_ptr<const ValueType>& value_type)
void assert_value_type(const std::shared_ptr<const ValueType>& value_type) const;
void assert_value_type(const element::Type& element_type, const Shape& shape) const
{
m_value_type = value_type;
assert_value_type(std::make_shared<TensorViewType>(element_type, shape));
}
// Set the value type if it has not already been set; otherwise, ensure that
......@@ -108,8 +102,8 @@ namespace ngraph
std::deque<descriptor::Input>& get_inputs() { return m_inputs; }
const std::deque<descriptor::Input>& get_inputs() const { return m_inputs; }
std::deque<descriptor::Output>& get_outputs() { return m_outputs; }
const std::deque<descriptor::Output>& get_outputs() const { return m_outputs; }
std::deque<descriptor::Output>& get_outputs();
const std::deque<descriptor::Output>& get_outputs() const;
std::unordered_set<descriptor::Tensor*> liveness_live_list;
std::unordered_set<descriptor::Tensor*> liveness_new_list;
std::unordered_set<descriptor::Tensor*> liveness_free_list;
......@@ -132,6 +126,7 @@ namespace ngraph
static std::atomic<size_t> m_next_instance_id;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
bool m_outputs_valid{false};
bool m_is_output;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
};
......
......@@ -19,29 +19,22 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void BinaryElementwiseBuiltin::propagate_types()
op::BinaryElementwise::BinaryElementwise(
std::function<const element::Type&(const element::Type&, const element::Type&)>
element_type_function,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(Nodes{arg0, arg1})
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_shape() != arg1_tensor_type->get_shape())
{
throw ngraph_error("Arguments must have the same tensor view shape");
}
const element::Type& result_element_type = propagate_element_types(
const element::Type& result_element_type = element_type_function(
arg0_tensor_type->get_element_type(), arg1_tensor_type->get_element_type());
set_value_type_checked(
......
......@@ -16,11 +16,12 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
{
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
......@@ -28,8 +29,13 @@ const element::Type& BinaryElementwiseArithmetic::propagate_element_types(
if (arg0_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
throw ngraph_error(
"Operands for arithmetic operators must have numeric element type");
}
return arg0_element_type;
},
arg0,
arg1)
{
}
......@@ -16,15 +16,21 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
using namespace ngraph;
const element::Type& BinaryElementwiseComparison::propagate_element_types(
const element::Type& arg0_element_type, const element::Type& arg1_element_type) const
{
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(
[](const element::Type& arg0_element_type,
const element::Type& arg1_element_type) -> const element::Type& {
if (arg0_element_type != arg1_element_type)
{
throw ngraph_error("Arguments must have the same tensor view element type");
}
return element::Bool::element_type();
},
arg0,
arg1)
{
}
......@@ -16,25 +16,16 @@
#include "ngraph/ops/sum.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Broadcast::propagate_types()
op::Broadcast::Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: RequiresTensorViewArgs({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
auto arg_tensor_view_type = m_inputs.at(0).get_tensor_view_type();
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
......@@ -48,7 +39,7 @@ void Broadcast::propagate_types()
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
}
void ngraph::op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
......
......@@ -56,7 +56,7 @@ namespace ngraph
/// | ------- | ----------------------------------------------- |
/// | NGVM | Implemented for scalars, matrices, and vectors. |
class Broadcast : public Builtin
class Broadcast : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a conversion operation.
......@@ -67,12 +67,7 @@ namespace ngraph
/// remaining axes in shape must be the same as the shape of arg.
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Builtin({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
const AxisSet& broadcast_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -83,8 +78,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Broadcast"; }
virtual void propagate_types() override;
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
const Shape& get_broadcast_shape() const { return m_shape; }
......
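A hypothetical Broadcast construction under the relocated checks (illustrative only; Float32 assumed as before):

// Broadcasting a vector of shape {3} to shape {2, 3} along new axis 0; the
// axis and shape validation now runs in the constructor itself.
auto v = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto bc = make_shared<op::Broadcast>(v, Shape{2, 3}, AxisSet{0});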
......@@ -17,27 +17,18 @@
#include "ngraph/ops/concatenate.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Concat::propagate_types()
op::Concat::Concat(const Nodes& args, size_t concatenation_axis)
: RequiresTensorViewArgs(args)
, m_concatenation_axis(concatenation_axis)
{
if (m_arguments.size() < 1)
{
throw ngraph_error("At least one argument required");
}
auto arg0_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg0_type)
{
throw ngraph_error("Argument to concat is missing type.");
}
auto arg0_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg0_type);
if (nullptr == arg0_tensor_view_type)
{
throw ngraph_error("Argument to concat is not a tensor view");
}
auto arg0_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg0_shape = arg0_tensor_view_type->get_shape();
if (m_concatenation_axis >= arg0_shape.size())
{
......@@ -47,20 +38,9 @@ void Concat::propagate_types()
size_t concatenation_axis_length = arg0_shape.at(m_concatenation_axis);
auto& arg0_element_type = arg0_tensor_view_type->get_element_type();
for (auto i = 1; i < m_arguments.size(); i++)
{
auto argi_type = m_arguments.at(i)->get_value_type();
if (nullptr == argi_type)
for (auto i = 1; i < get_inputs().size(); i++)
{
throw ngraph_error("Argument to concat is missing type.");
}
auto argi_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(argi_type);
if (nullptr == argi_tensor_view_type)
{
throw ngraph_error("Argument to concat is not a tensor view");
}
auto argi_tensor_view_type = get_inputs().at(i).get_tensor_view_type();
auto argi_shape = argi_tensor_view_type->get_shape();
if (argi_shape.size() != arg0_shape.size())
{
......@@ -85,7 +65,6 @@ void Concat::propagate_types()
}
}
}
vector<size_t> concatenated_shape = arg0_shape;
concatenated_shape.at(m_concatenation_axis) = concatenation_axis_length;
......
......@@ -63,18 +63,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------------------------- |
/// | NGVM | Implemented for vectors and matrices. |
class Concat : public Builtin
class Concat : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const Nodes& args, size_t concatenation_axis)
: Builtin(args)
, m_concatenation_axis(concatenation_axis)
{
}
Concat(const Nodes& args, size_t concatenation_axis);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -82,9 +78,7 @@ namespace ngraph
return std::make_shared<Concat>(new_args, m_concatenation_axis);
}
virtual std::string description() const override { return "Concatenate"; }
virtual void propagate_types() override;
virtual std::string description() const override { return "Concat"; }
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
protected:
......
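Similarly for Concat, a hypothetical construction (illustrative only):

// Concatenating two {2, 2} tensors along axis 1 yields shape {2, 4}; the
// rank, shape, and element-type agreement checks now run at construction.
auto c0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2});
auto c1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2});
auto concat = make_shared<op::Concat>(Nodes{c0, c1}, 1);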
......@@ -14,24 +14,44 @@
#include "ngraph/ops/constant.hpp"
using namespace ngraph::op;
using namespace ngraph;
void ConstantBase::propagate_types()
namespace
{
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
{
auto result = ET::read(value_strings);
}
}
template <typename ET>
void check_value_strings(const std::vector<std::string>& value_strings)
op::Constant::Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{
auto result = ET::read(value_strings);
check_args();
}
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
op::Constant::Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
check_args();
}
void Constant::propagate_types()
void op::Constant::check_args()
{
// No actual type propagation is done here; however, we check the number of value strings and
// We check the number of value strings and
// also call check_value_strings just to make sure the result will be parseable at compile
// time. (It will throw an exception if not.)
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(get_value_type());
auto tvt = std::dynamic_pointer_cast<const TensorViewType>(m_value_type);
if (nullptr == tvt)
{
throw ngraph_error("Constant does not have tensor view type");
......
......@@ -39,8 +39,6 @@ namespace ngraph
: Node({}, type)
{
}
virtual void propagate_types() override;
};
/// \brief Class for constants whose element types are known at C++ compile-time.
......@@ -162,22 +160,14 @@ namespace ngraph
/// \param value_strings A list of literals for initializing the tensor constant. There must be one literal for each element of the tensor; i.e., `value_strings.size()` must equal `ngraph::shape_size(shape)`.
Constant(const element::Type& et,
const Shape& shape,
const std::vector<std::string>& value_strings)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(value_strings)
{
}
const std::vector<std::string>& value_strings);
/// \brief Constructs a tensor constant with the same initialization value copied across the tensor.
///
/// \param et The element type of the tensor constant.
/// \param shape The shape of the tensor constant.
/// \param value_string A literal for initializing each tensor constant.
Constant(const element::Type& et, const Shape& shape, const std::string& value_string)
: ConstantBase(std::make_shared<TensorViewType>(et, shape))
, m_value_strings(ngraph::shape_size(shape), value_string)
{
}
Constant(const element::Type& et, const Shape& shape, const std::string& value_string);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -197,9 +187,9 @@ namespace ngraph
/// \return The initialization literals for the tensor constant.
const std::vector<std::string>& get_value_strings() const { return m_value_strings; }
virtual void propagate_types() override;
protected:
void check_args();
const std::vector<std::string> m_value_strings;
};
}
......
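A hypothetical Constant construction under the new check_args() (illustrative only):

// Both constructors call check_args(), so a wrong number of literals or an
// unparseable literal now throws at construction time.
auto c = make_shared<op::Constant>(element::Float32::element_type(),
                                   Shape{2},
                                   std::vector<std::string>{"1.5", "2.5"});
auto zeros = make_shared<op::Constant>(element::Float32::element_type(), Shape{2, 2}, "0");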
......@@ -18,9 +18,13 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
const element::Type& Convert::propagate_element_types(const element::Type& arg_element_type) const
op::Convert::Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwise(
[&](const ngraph::element::Type& ignored) -> const ngraph::element::Type& {
return element_type;
},
arg)
, m_element_type(element_type)
{
return m_element_type;
}
......@@ -48,18 +48,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Convert : public UnaryElementwiseBuiltin
class Convert : public UnaryElementwise
{
public:
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param element_type Element type for the output tensor.
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: UnaryElementwiseBuiltin({arg})
, m_element_type(element_type)
{
}
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -69,8 +65,6 @@ namespace ngraph
return std::make_shared<Convert>(new_args.at(0), m_element_type);
}
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const override;
const element::Type& get_convert_element_type() const { return m_element_type; }
virtual std::string description() const override { return "Convert"; }
protected:
......
......@@ -23,18 +23,14 @@
#include "ngraph/shape.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Dot::propagate_types()
op::Dot::Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs({arg0, arg1})
{
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
......@@ -108,8 +104,7 @@ ngraph::AxisVector range<ngraph::AxisVector>(size_t n)
return result;
}
void ngraph::op::Dot::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
void op::Dot::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
auto y = m_arguments[1];
......
......@@ -102,17 +102,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ---------------------------------------------- |
/// | NGVM | Implemented for `arg1` with rank of exactly 2. |
class Dot : public Builtin
class Dot : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a dot product operation.
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -123,8 +120,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Dot"; }
virtual void propagate_types() override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
......
......@@ -16,9 +16,12 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void FunctionCall::propagate_types()
op::FunctionCall::FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Node(args)
, m_function(function)
{
auto& function_params = m_function->get_parameters();
......
......@@ -14,8 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -46,7 +45,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class FunctionCall : public Builtin
class FunctionCall : public ngraph::Node
{
public:
/// \brief Constructs a function call operation.
......@@ -54,11 +53,7 @@ namespace ngraph
/// \param function The function to be called.
/// \param args The arguments for the function call.
FunctionCall(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<Node>>& args)
: Builtin(args)
, m_function(function)
{
}
const std::vector<std::shared_ptr<Node>>& args);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -67,8 +62,6 @@ namespace ngraph
}
virtual std::string description() const override { return "FunctionCall"; }
virtual void propagate_types() override;
/// \return The function to be called.
std::shared_ptr<Function> get_function() const { return m_function; }
protected:
......
......@@ -17,15 +17,12 @@
#include "ngraph/ops/get_tuple_element.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void GetTupleElement::propagate_types()
op::GetTupleElement::GetTupleElement(const std::shared_ptr<Node>& arg, size_t n)
: Node({arg})
, m_n{n}
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tuple_type =
dynamic_pointer_cast<const TupleType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg0_tuple_type)
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -47,18 +47,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class GetTupleElement : public Builtin
class GetTupleElement : public ngraph::Node
{
public:
/// \brief Constructs a get-tuple-element operation.
///
/// \param arg The input tuple.
/// \param n The index of the tuple element to get.
GetTupleElement(const std::shared_ptr<Node>& arg, size_t n)
: Builtin({arg})
, m_n{n}
{
}
GetTupleElement(const std::shared_ptr<Node>& arg, size_t n);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -68,7 +64,6 @@ namespace ngraph
return std::make_shared<GetTupleElement>(new_args.at(0), m_n);
}
virtual void propagate_types() override;
virtual std::string description() const override { return "GetTupleElement"; }
/// \return The index of the tuple element to get.
size_t get_n() const { return m_n; }
......
......@@ -13,9 +13,24 @@
// ----------------------------------------------------------------------------
#include <algorithm>
#include <memory>
#include <sstream>
#include "ngraph/except.hpp"
#include "ngraph/ops/op.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
using namespace std;
op::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::vector<std::shared_ptr<Node>>& args)
: Node(args)
{
for (auto arg : args)
{
if (nullptr == std::dynamic_pointer_cast<const TensorViewType>(arg->get_value_type()))
{
throw ngraph_error("Arguments must be tensor views");
}
}
}
......@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include <memory>
#include "ngraph/node.hpp"
......@@ -25,19 +26,14 @@ namespace ngraph
// TODO: These class definitions are to be moved into separate files in the op directory
namespace op
{
/// \brief Abstract base class for built-in (primitive) operations.
class Builtin : public Node
/// \brief Abstract base class for ops on tensor views.
class RequiresTensorViewArgs : public Node
{
public:
virtual std::string description() const override { return "Builtin"; }
protected:
/// \brief Constructs a builtin operation.
/// \brief Constructs an operation on tensor view arguments.
///
/// \param args The nodes producing this node's input tensors.
Builtin(const std::vector<std::shared_ptr<Node>>& args)
: Node(args)
{
}
RequiresTensorViewArgs(const std::vector<std::shared_ptr<Node>>& args);
};
/// \brief Abstract base class for elementwise unary operations, i.e., operations where the same
......@@ -50,38 +46,22 @@ namespace ngraph
///
/// | | Type | Description |
/// | ----- | --------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E\f$ (see UnaryElementwiseBuiltin::propagate_element_types). |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E\f$. |
///
/// ## Output
///
/// | Type | Description |
/// | ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E'[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensor, but subclasses must determine the element type \f$E'\f$ (see UnaryElementwiseBuiltin::propagate_element_types). |
class UnaryElementwiseBuiltin : public Builtin
/// | \f$E'[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensor, but subclasses must determine the element type \f$E'\f$. |
class UnaryElementwise : public RequiresTensorViewArgs
{
protected:
/// \brief Constructs a unary elementwise builtin operation.
/// \brief Constructs a unary elementwise tensor operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwiseBuiltin(const std::shared_ptr<Node>& arg)
: Builtin(Nodes{arg})
{
}
/// \brief Propagate element type from input to output.
///
/// Subclasses must override this method to both:
///
/// 1. Verify that `arg_element_type` is valid, throwing an ngraph_error message if it is not.
/// 2. Infer and return the element type for the return tensor.
///
/// \param arg_element_type The element type of the input tensor.
/// \return The inferred element type for the output tensor.
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const = 0;
public:
virtual void propagate_types() override;
UnaryElementwise(
std::function<const element::Type&(const element::Type&)> element_type_function,
const std::shared_ptr<Node>& arg);
};
/// \brief Abstract base class for elementwise unary arithmetic operations, i.e., operations where the same
......@@ -101,26 +81,13 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensor. |
class UnaryElementwiseArithmetic : public UnaryElementwiseBuiltin
class UnaryElementwiseArithmetic : public UnaryElementwise
{
protected:
/// \brief Constructs a unary elementwise builtin arithmetic operation.
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: UnaryElementwiseBuiltin({arg})
{
}
/// \brief Propagate element type from input to output.
///
/// If the input type is numeric, returns the input type.
/// If the input type is not numeric, throws ngraph_error.
///
/// \param arg_element_type The element type of the input tensor.
/// \return The inferred element type (same as arg_element_type).
virtual const element::Type&
propagate_element_types(const element::Type& arg_element_type) const final override;
UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg);
};
/// \brief Abstract base class for elementwise binary operations, i.e., operations where the same
......@@ -134,43 +101,26 @@ namespace ngraph
///
/// | | Type | Description |
/// | ------ | ----------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg0` | \f$E_0[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E_0\f$ (see BinaryElementwiseBuiltin::propagate_element_types). |
/// | `arg1` | \f$E_1[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape as `arg0`. Subclasses may impose restrictions on the element type \f$E_1\f$ (see BinaryElementwiseBuiltin::propagate_element_types). |
/// | `arg0` | \f$E_0[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E_0\f$. |
/// | `arg1` | \f$E_1[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape as `arg0`. Subclasses may impose restrictions on the element type \f$E_1\f$. |
///
/// ## Output
///
/// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E_2[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, but subclasses must determine the element type \f$E_2\f$ (see BinaryElementwiseBuiltin::propagate_element_types). |
class BinaryElementwiseBuiltin : public Builtin
/// | \f$E_2[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, but subclasses must determine the element type \f$E_2\f$. |
class BinaryElementwise : public RequiresTensorViewArgs
{
protected:
/// \brief Constructs a binary elementwise builtin operation.
/// \brief Constructs a binary elementwise operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseBuiltin(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: Builtin(Nodes{arg0, arg1})
{
}
/// \brief Propagate element type from inputs to output.
///
/// Subclasses must override this method to:
///
/// 1. Verify that `arg0_element_type` and `arg1_element_type` are valid, throwing an ngraph_error message if not; and
/// 2. Infer and return the element type for the return tensor.
///
/// \param arg0_element_type The element type of the first input tensor.
/// \param arg1_element_type The element type of the second input tensor.
/// \return The inferred element type for the output tensor.
virtual const element::Type&
propagate_element_types(const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const = 0;
public:
virtual void propagate_types() override;
BinaryElementwise(
std::function<const element::Type&(const element::Type&, const element::Type&)>
element_type_function,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
/// \brief Abstract base class for elementwise binary comparison operations, i.e., operations where the same
......@@ -192,36 +142,15 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseComparison : public BinaryElementwiseBuiltin
class BinaryElementwiseComparison : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise builtin comparison operation.
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseComparison(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string description() const override
{
return "BinaryElementwiseComparison";
}
/// \brief Propagate element type from inputs to output.
///
/// If the input types are the same, returns
/// element::Bool::element_type(). If the input types are not the
/// same, throws ngraph_error.
///
/// \param arg0_element_type The element type of the first input tensor.
/// \param arg1_element_type The element type of the second input tensor.
/// \return The inferred element type for the output tensor, which is always element::Bool::element_type().
virtual const element::Type&
propagate_element_types(const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const override;
const std::shared_ptr<Node>& arg1);
};
/// \brief Abstract base class for elementwise binary arithmetic operations, i.e., operations where the same
......@@ -243,36 +172,15 @@ namespace ngraph
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors. |
class BinaryElementwiseArithmetic : public BinaryElementwiseBuiltin
class BinaryElementwiseArithmetic : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise builtin arithmetic operation.
/// \brief Constructs a binary elementwise arithmetic operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseArithmetic(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwiseBuiltin(arg0, arg1)
{
}
virtual std::string description() const override
{
return "BinaryElementwiseArithmetic";
}
/// \brief Propagate element type from inputs to output.
///
/// If the input types are the same type, and that type is numeric,
/// returns the input type. If the input types are not the same or
/// are not numeric, throws ngraph_error.
///
/// \param arg0_element_type The element type of the first input tensor.
/// \param arg1_element_type The element type of the second input tensor.
/// \return The inferred element type for the output tensor, which is always the same as that of the input tensors.
virtual const element::Type& propagate_element_types(
const element::Type& arg0_element_type,
const element::Type& arg1_element_type) const final override;
const std::shared_ptr<Node>& arg1);
};
}
}
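A hypothetical subclass under the renamed hierarchy (sketch, not part of this diff; a Negative-style op in namespace ngraph::op is assumed):

// Concrete arithmetic ops no longer override propagate_element_types; they
// simply forward to UnaryElementwiseArithmetic, whose constructor installs
// the numeric element-type check and sets the value type.
class Negative : public UnaryElementwiseArithmetic
{
public:
    Negative(const std::shared_ptr<Node>& arg)
        : UnaryElementwiseArithmetic(arg)
    {
    }

    virtual std::shared_ptr<Node> copy_with_new_args(
        const std::vector<std::shared_ptr<Node>>& new_args) const override
    {
        return std::make_shared<Negative>(new_args.at(0));
    }

    virtual std::string description() const override { return "Negative"; }
};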
......@@ -29,10 +29,6 @@ Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& sha
{
}
void Parameter::propagate_types()
{
}
void Parameter::generate_adjoints(autodiff::Adjoints& adjoints, const std::shared_ptr<Node>& delta)
{
}
......@@ -71,7 +71,6 @@ namespace ngraph
}
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
};
}
}
......@@ -16,37 +16,19 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
void Reduce::propagate_types()
using namespace ngraph;
op::Reduce::Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes)
: RequiresTensorViewArgs({arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
{
if (m_arguments.size() != 2)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_reductee_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_reductee_type)
{
throw ngraph_error("Argument to reduce is missing type.");
}
auto arg_reductee_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(arg_reductee_type);
if (nullptr == arg_reductee_tensor_view_type)
{
throw ngraph_error("Argument to reduce is not a tensor view");
}
auto arg_reductee_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_init_type = m_arguments.at(1)->get_value_type();
if (nullptr == arg_init_type)
{
throw ngraph_error("Argument for initial value is missing type.");
}
auto arg_init_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_init_type);
if (nullptr == arg_init_tensor_view_type)
{
throw ngraph_error("Argument for initial value is not a tensor view");
}
auto arg_init_tensor_view_type = get_inputs().at(1).get_tensor_view_type();
if (arg_init_tensor_view_type->get_shape().size() != 0)
{
throw ngraph_error("Argument for initial value is not a scalar");
......@@ -85,18 +67,18 @@ void Reduce::propagate_types()
throw ngraph_error("Reduction function has wrong number of parameters (should be two)");
}
if (*(f_params.at(0)->get_value_type()) != *(arg_init_type))
if (*(f_params.at(0)->get_value_type()) != *(arg_init->get_value_type()))
{
throw ngraph_error("Argument 0 of reduction function has wrong type");
}
if (*(f_params.at(1)->get_value_type()) != *(arg_init_type))
if (*(f_params.at(1)->get_value_type()) != *(arg_init->get_value_type()))
{
throw ngraph_error("Argument 1 of reduction function has wrong type");
}
auto f_result_type = m_reduction_function->get_result_type();
if (*(f_result_type) != *(arg_init_type))
if (*(f_result_type) != *(arg_init->get_value_type()))
{
throw ngraph_error("Return type from reduction function does not match expected");
}
......
......@@ -87,7 +87,7 @@ namespace ngraph
/// | ------- | ----------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. |
class Reduce : public Builtin
class Reduce : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reduction operation.
......@@ -99,12 +99,7 @@ namespace ngraph
Reduce(const std::shared_ptr<Node>& arg_reductee,
const std::shared_ptr<Node>& arg_init,
const std::shared_ptr<Function>& reduction_function,
const AxisSet& reduction_axes)
: Builtin({arg_reductee, arg_init})
, m_reduction_function(reduction_function)
, m_reduction_axes(reduction_axes)
{
}
const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -116,8 +111,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Reduce"; }
virtual void propagate_types() override;
/// \return The function to use for reduction.
std::shared_ptr<Function> get_reduction_function() const
{
......
......@@ -18,26 +18,16 @@
#include <algorithm>
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Reshape::propagate_types()
op::Reshape::Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: RequiresTensorViewArgs({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to reshape is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to reshape is not a tensor view");
}
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
......@@ -79,7 +69,7 @@ void Reshape::propagate_types()
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_output_shape));
}
void ngraph::op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
void op::Reshape::generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta)
{
auto x = m_arguments[0];
......
......@@ -59,7 +59,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. Implemented for other shapes only when there is no reordering of the input axes, i.e. `input_order` is \f$(0,\dots,n-1)\f$. |
class Reshape : public Builtin
class Reshape : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a reshape operation.
......@@ -71,12 +71,7 @@ namespace ngraph
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape(const std::shared_ptr<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: Builtin({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
}
const Shape& output_shape);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -87,8 +82,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Reshape"; }
virtual void propagate_types() override;
/// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; }
/// \return The shape of the output tensor.
......
......@@ -19,25 +19,16 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void Select::propagate_types()
op::Select::Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: RequiresTensorViewArgs(Nodes{arg0, arg1, arg2})
{
if (m_arguments.size() != 3)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg0_tensor_type = get_inputs().at(0).get_tensor_view_type();
auto arg1_tensor_type = get_inputs().at(1).get_tensor_view_type();
auto arg2_tensor_type = get_inputs().at(2).get_tensor_view_type();
auto arg0_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(1)->get_value_type());
auto arg2_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(2)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type || nullptr == arg2_tensor_type)
{
throw ngraph_error("Arguments must be tensor views");
}
if (arg0_tensor_type->get_element_type() != element::Bool::element_type())
{
throw ngraph_error("Argument 0 for arithmetic operators must have boolean element type");
......
......@@ -41,7 +41,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Select : public Builtin
class Select : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a selection operation.
......@@ -51,10 +51,7 @@ namespace ngraph
/// \param arg2 Node that produces the third input tensor.
Select(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
const std::shared_ptr<Node>& arg2)
: Builtin(Nodes{arg0, arg1, arg2})
{
}
const std::shared_ptr<Node>& arg2);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -65,7 +62,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Select"; }
virtual void propagate_types() override;
};
}
}
......@@ -15,25 +15,34 @@
#include "ngraph/ops/slice.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Slice::propagate_types()
op::Slice::Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: RequiresTensorViewArgs({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
check_args();
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to slice is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to slice is not a tensor view");
}
op::Slice::Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: RequiresTensorViewArgs({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
{
check_args();
}
void op::Slice::check_args()
{
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_shape = arg_tensor_view_type->get_shape();
if (m_lower_bounds.size() != arg_shape.size())
......
......@@ -52,7 +52,7 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ----------------------------------------------- |
/// | NGVM | Implemented for scalars, matrices, and vectors. |
class Slice : public Builtin
class Slice : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a tensor slice operation.
......@@ -65,13 +65,7 @@ namespace ngraph
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
}
const Shape& step);
/// \brief Constructs a tensor slice operation with unit step; i.e., every element inside the bounding box will be copied to the output slice.
///
......@@ -80,13 +74,7 @@ namespace ngraph
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
{
}
const Coordinate& upper_bounds);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -98,8 +86,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Slice"; }
virtual void propagate_types() override;
/// \return The inclusive lower-bound coordinates.
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
/// \return The exclusive upper-bound coordinates.
......@@ -107,6 +93,8 @@ namespace ngraph
/// \return The slicing step.
const Shape& get_step() const { return m_step; }
protected:
void check_args();
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Shape m_step;
......
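A hypothetical Slice construction (illustrative only):

// Rows [0, 2) and columns [1, 3) of a {3, 4} matrix; the three-argument form
// fills in a unit step, and both constructors share check_args().
auto m = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3, 4});
auto s = make_shared<op::Slice>(m, Coordinate{0, 1}, Coordinate{2, 3});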
......@@ -16,26 +16,13 @@
#include "ngraph/function.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Sum::propagate_types()
op::Sum::Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: RequiresTensorViewArgs({arg})
, m_reduction_axes(reduction_axes)
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to sum is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to sum is not a tensor view");
}
auto arg_tensor_view_type = get_inputs().at(0).get_tensor_view_type();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
if (arg_element_type == element::Bool::element_type())
{
......
......@@ -80,18 +80,14 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ----------------------------------------------------- |
/// | NGVM | Fully implemented for scalars, vectors, and matrices. |
class Sum : public Builtin
class Sum : public RequiresTensorViewArgs
{
public:
/// \brief Constructs a summation operation.
///
/// \param arg The tensor view to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes)
: Builtin({arg})
, m_reduction_axes(reduction_axes)
{
}
Sum(const std::shared_ptr<Node>& arg, const AxisSet& reduction_axes);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -102,8 +98,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Sum"; }
virtual void propagate_types() override;
/// \return The axis positions (0-based) to be eliminated through summation.
const AxisSet& get_reduction_axes() const { return m_reduction_axes; }
protected:
......
......@@ -18,9 +18,10 @@
#include "ngraph/ops/tuple.hpp"
using namespace std;
using namespace ngraph::op;
using namespace ngraph;
void Tuple::propagate_types()
op::Tuple::Tuple(const Nodes& args)
: Node(args)
{
vector<shared_ptr<const ValueType>> element_types;
for (auto argument : m_arguments)
......
......@@ -14,7 +14,7 @@
#pragma once
#include "ngraph/ops/op.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
......@@ -39,16 +39,13 @@ namespace ngraph
/// | Backend | Status |
/// | ------- | ------------------ |
/// | NGVM | Fully implemented. |
class Tuple : public Builtin
class Tuple : public ngraph::Node
{
public:
/// \brief Constructs a tuple construction operation.
///
/// \param args The nodes that produce the elements of the constructed tuple.
Tuple(const Nodes& args)
: Builtin(args)
{
}
Tuple(const Nodes& args);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
......@@ -57,7 +54,6 @@ namespace ngraph
}
virtual std::string description() const override { return "Tuple"; }
virtual void propagate_types() override;
};
}
}
......@@ -18,24 +18,15 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void UnaryElementwiseBuiltin::propagate_types()
op::UnaryElementwise::UnaryElementwise(
std::function<const element::Type&(const element::Type&)> element_type_function,
const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(Nodes{arg})
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_tensor_type =
dynamic_pointer_cast<const TensorViewType>(m_arguments.at(0)->get_value_type());
if (nullptr == arg_tensor_type)
{
throw ngraph_error("Argument must be tensor view");
}
auto arg_tensor_type = get_inputs().at(0).get_tensor_view_type();
const element::Type& result_element_type =
propagate_element_types(arg_tensor_type->get_element_type());
element_type_function(arg_tensor_type->get_element_type());
set_value_type_checked(
make_shared<TensorViewType>(result_element_type, arg_tensor_type->get_shape()));
......
......@@ -15,15 +15,19 @@
#include "ngraph/ops/op.hpp"
using namespace ngraph;
using namespace ngraph::op;
const element::Type&
UnaryElementwiseArithmetic::propagate_element_types(const element::Type& arg_element_type) const
{
op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::shared_ptr<Node>& arg)
: UnaryElementwise(
[](const ngraph::element::Type& arg_element_type) -> const ngraph::element::Type& {
if (arg_element_type == element::Bool::element_type())
{
throw ngraph_error("Operands for arithmetic operators must have numeric element type");
throw ngraph_error(
"Operands for arithmetic operators must have numeric element "
"type");
}
return arg_element_type;
},
arg)
{
}
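To show how the lambda-based base class is consumed, here is a hypothetical unary op (MyAbs is illustrative only, not part of this commit) that inherits the numeric-element-type check simply by deriving from UnaryElementwiseArithmetic:
class MyAbs : public op::UnaryElementwiseArithmetic
{
public:
    // The base constructor runs the element-type lambda and calls
    // set_value_type_checked, so construction alone performs all checking.
    MyAbs(const std::shared_ptr<Node>& arg)
        : UnaryElementwiseArithmetic(arg)
    {
    }

    virtual std::shared_ptr<Node> copy_with_new_args(
        const std::vector<std::shared_ptr<Node>>& new_args) const override
    {
        if (new_args.size() != 1)
        {
            throw ngraph_error("Incorrect number of new arguments");
        }
        return std::make_shared<MyAbs>(new_args.at(0));
    }

    virtual std::string description() const override { return "MyAbs"; }
};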
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "ngraph/pass/assign_tensors.hpp"
#include <exception>
#include <sstream>
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
using namespace std;
using namespace ngraph;
bool pass::AssignTensors::run_on_call_graph(list<std::shared_ptr<Node>>& nodes)
{
for (shared_ptr<Node> node : nodes)
{
try
{
// We need to set the node's is_output state prior to calling assign_tensors
// so that the output state can be passed to the constructed tensors.
if (node == get_state().get_functions().at(0)->get_result())
{
node->set_is_output();
}
node->assign_tensors();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
}
return false;
}
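AssignTensors (deleted by this commit) shows the standard CallGraphPass skeleton: walk the sorted nodes, do per-node work, rethrow failures with node context, and return false since the graph is left unmodified. A hypothetical pass reusing that skeleton (CountNodes is illustrative only):
namespace ngraph
{
    namespace pass
    {
        class CountNodes;  // hypothetical example, not part of this commit
    }
}
class ngraph::pass::CountNodes : public CallGraphPass
{
public:
    virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override
    {
        size_t count = 0;
        for (std::shared_ptr<Node> node : nodes)
        {
            try
            {
                // Per-node work goes here; this example only counts nodes.
                count++;
            }
            catch (std::exception& e)
            {
                std::stringstream ss;
                ss << "Error with node " << *node << ": " << e.what();
                throw std::invalid_argument(ss.str());
            }
        }
        m_count = count;
        return false; // graph not modified, matching the passes above
    }
private:
    size_t m_count = 0;
};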
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class AssignTensors;
}
}
class ngraph::pass::AssignTensors : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override;
private:
};
......@@ -20,7 +20,6 @@
#include "ngraph/descriptor/output.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <sstream>
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/propagate_types.hpp"
using namespace std;
using namespace ngraph;
bool pass::PropagateTypes::run_on_call_graph(list<shared_ptr<Node>>& nodes)
{
for (shared_ptr<Node> node : nodes)
{
try
{
node->propagate_types();
}
catch (exception& e)
{
stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw invalid_argument(ss.str());
}
}
return false;
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class PropagateTypes;
}
}
class ngraph::pass::PropagateTypes : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>&) override;
private:
};
......@@ -66,9 +66,7 @@
#include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/emitter.hpp"
......@@ -159,8 +157,6 @@ void ExternalFunction::compile(FunctionMap& function_map)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
// For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
pass_manager.run_passes(m_function);
......
......@@ -62,9 +62,7 @@
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/ngvm/eigen/abs.hpp"
#include "ngraph/runtime/ngvm/eigen/acos.hpp"
......@@ -1023,8 +1021,6 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Get the ordered list of ops in execution order
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.run_passes(m_function);
// Turn this into a pass
......
......@@ -20,6 +20,7 @@
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
......
......@@ -109,7 +109,7 @@ TEST(build_graph, tensor)
auto float_tensor_type =
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3});
ASSERT_EQ(*float0->get_value_type(), *float_tensor_type);
auto d = make_shared<op::Dot>(float0, float0);
auto d = make_shared<op::Add>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
......
......@@ -112,7 +112,7 @@ TEST(copy, concat)
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
size_t axis = 1;
size_t axis = 0;
auto node = make_shared<op::Concat>(Nodes{arg0, arg1}, axis);
auto new_node = node->copy_with_new_args(new_args);
auto node_cast = dynamic_pointer_cast<op::Concat>(new_node);
......@@ -219,9 +219,11 @@ TEST(copy, FunctionCall)
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto node = make_shared<op::FunctionCall>(f, Nodes{arg0, arg1});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), shape);
auto node = make_shared<op::FunctionCall>(f, Nodes{arg0, arg1, arg2});
std::vector<std::shared_ptr<Node>> new_args{
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape),
make_shared<op::Parameter>(element::Float32::element_type(), shape)};
auto new_node = node->copy_with_new_args(new_args);
......
......@@ -25,8 +25,6 @@ TEST(input_output, param_tensor)
// Params have no arguments, so we can check that the value becomes a tensor output
auto tv_tp = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4});
auto param = make_shared<op::Parameter>(tv_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 1);
for (size_t i = 0; i < param->get_outputs().size(); i++)
......@@ -46,8 +44,6 @@ TEST(input_output, param_tuple)
auto tv_tp_1 = make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4, 6});
auto tp_tp = make_shared<TupleType>(ValueTypes{tv_tp_0, tv_tp_1});
auto param = make_shared<op::Parameter>(tp_tp);
param->propagate_types();
param->assign_tensors();
ASSERT_EQ(param->get_outputs().size(), 2);
for (size_t i = 0; i < param->get_outputs().size(); i++)
......@@ -74,19 +70,8 @@ TEST(input_output, simple_output)
nodes.push_back(param_1);
nodes.push_back(add);
// Type info
for (auto node : nodes)
{
node->propagate_types();
}
// Add inputs/outputs
for (auto node : nodes)
{
node->assign_tensors();
}
// At this point, the add should have each input associated with the output of the appropriate parameter
ASSERT_EQ(1, add->get_outputs().size());
auto& inputs = add->get_inputs();
ASSERT_EQ(2, inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
......
......@@ -21,12 +21,10 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
......@@ -44,8 +42,6 @@ TEST(pass, liveness)
pass_manager.register_pass<pass::VisualizeTree>(image);
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::DumpSorted>(dump_file);
......
......@@ -20,9 +20,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/util.hpp"
#include "util/test_tools.hpp"
......@@ -35,8 +33,6 @@ TEST(pass_manager, add)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
auto graph = make_test_graph();
size_t node_count = 0;
......
......@@ -20,13 +20,11 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "util/test_tools.hpp"
......@@ -210,8 +208,6 @@ TEST(memory_layout, basic)
string dump_file = "memory_layout.txt";
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>();
pass_manager.register_pass<pass::DumpSorted>(dump_file);
......
......@@ -21,10 +21,8 @@
#include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "util/test_tools.hpp"
......@@ -37,8 +35,6 @@ TEST(tensor, size)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
pass_manager.register_pass<pass::Liveness>();
{
......
......@@ -21,11 +21,9 @@
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/collect_functions.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp"
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
......@@ -204,8 +202,6 @@ TEST(topological_sort, unused_function_arg)
pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>();
// pass_manager.register_pass<pass::DumpSorted>("sorted.txt");
pass_manager.run_passes(f);
list<shared_ptr<Node>> ops = f->get_ordered_ops();
......
......@@ -20,11 +20,6 @@
using namespace std;
using namespace ngraph;
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node);
void test_binary_bad_arguments_views(const shared_ptr<Node>& node);
void test_binary_good_arguments(const shared_ptr<Node>& node);
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y));
//
// Tests for broadcast.
//
......@@ -33,7 +28,6 @@ TEST(type_prop, broadcast_deduce)
// Deduce type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
......@@ -43,9 +37,6 @@ TEST(type_prop, broadcast_deduce_correct)
// Check deduced type against correctly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 3, 4}, AxisSet{1});
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 4}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 4}));
}
......@@ -54,12 +45,11 @@ TEST(type_prop, broadcast_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto bc = make_shared<op::Broadcast>(param, Shape{2, 4, 3}, AxisSet{1});
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 3, 4}));
try
{
bc->propagate_types();
auto bc = make_shared<op::Broadcast>(param, Shape{2, 4, 3}, AxisSet{1});
bc->assert_value_type(element::Float32::element_type(), Shape{2, 3, 4});
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
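The rewritten test relies on assert_value_type, which compares the constructor-deduced type against an expected element type and shape and throws on disagreement. For contrast, a sketch of the success path (assumed from the call above, not part of the diff):
auto good_param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 4});
auto good_bc = make_shared<op::Broadcast>(good_param, Shape{2, 3, 4}, AxisSet{1});
// Deduction and assertion agree, so this returns without throwing.
good_bc->assert_value_type(element::Float32::element_type(), Shape{2, 3, 4});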
......@@ -75,18 +65,17 @@ TEST(type_prop, broadcast_deduce_incorrect)
TEST(type_prop, broadcast_bad_arguments)
{
try
{
// Check for bad arguments
auto param = make_shared<op::Parameter>(make_shared<TupleType>());
auto bc = make_shared<op::Broadcast>(param, Shape{2, 4, 3}, AxisSet{1});
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument to broadcast not detected.";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Argument to broadcast is not a tensor view"));
EXPECT_EQ(error.what(), std::string("Arguments must be tensor views"));
}
catch (...)
{
......@@ -101,7 +90,6 @@ TEST(type_prop, concat_deduce)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2, 4});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Float32::element_type(), Shape{2, 12, 4}));
}
......@@ -112,12 +100,10 @@ TEST(type_prop, concat_deduce_incorrect)
auto param0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2, 4});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
c->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 14, 4}));
try
{
c->propagate_types();
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
c->assert_value_type(element::Float32::element_type(), Shape{2, 14, 4});
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
......@@ -139,10 +125,9 @@ TEST(type_prop, concat_deduce_wrong_rank)
Shape{2, 2});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
try
{
c->propagate_types();
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
......@@ -161,10 +146,9 @@ TEST(type_prop, concat_deduce_wrong_shape)
auto param0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2, 5});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
try
{
c->propagate_types();
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
......@@ -186,10 +170,9 @@ TEST(type_prop, concat_deduce_axis_oob)
auto param0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2, 5});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 3);
try
{
c->propagate_types();
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 3);
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
......@@ -210,7 +193,6 @@ TEST(type_prop, concat_deduce_axis_barely_in_bounds)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 8});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 12});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 2);
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Float32::element_type(), Shape{2, 3, 24}));
}
......@@ -220,10 +202,9 @@ TEST(type_prop, concat_deduce_elem_type_mismatch)
auto param0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto param1 = make_shared<op::Parameter>(element::Int32::element_type(), Shape{2, 7, 4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 2, 4});
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
try
{
c->propagate_types();
auto c = make_shared<op::Concat>(Nodes{param0, param1, param2}, 1);
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
......@@ -242,7 +223,6 @@ TEST(type_prop, convert_deduce)
// Deduce type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
......@@ -252,8 +232,6 @@ TEST(type_prop, convert_deduce_correct)
// Check deduced type against correctly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 3, 4}));
c->propagate_types();
auto c_vt = c->get_value_type();
ASSERT_EQ(*c_vt, TensorViewType(element::Int32::element_type(), Shape{2, 3, 4}));
}
......@@ -262,11 +240,10 @@ TEST(type_prop, convert_deduce_incorrect)
{
// Check deduced type against incorrectly specified type
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3, 4});
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->set_value_type(make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 14, 4}));
try
{
c->propagate_types();
auto c = make_shared<op::Convert>(param, element::Int32::element_type());
c->assert_value_type(element::Int32::element_type(), Shape{2, 14, 4});
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
}
......@@ -286,7 +263,6 @@ TEST(type_prop, dot_deduce_scalar_2d)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 5});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4, 5}));
}
......@@ -297,7 +273,6 @@ TEST(type_prop, dot_deduce_2d_scalar)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 5});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4, 5}));
}
......@@ -308,7 +283,6 @@ TEST(type_prop, dot_deduce_scalar_scalar)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
......@@ -319,7 +293,6 @@ TEST(type_prop, dot_deduce_scalar_1d)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{6});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{6}));
}
......@@ -330,7 +303,6 @@ TEST(type_prop, dot_deduce_1d)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
......@@ -341,7 +313,6 @@ TEST(type_prop, dot_deduce_2d)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4, 3}));
}
......@@ -352,7 +323,6 @@ TEST(type_prop, dot_deduce_different_rank)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
}
......@@ -363,9 +333,6 @@ TEST(type_prop, dot_deduce_different_rank_correct)
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 8, 4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1, 2, 3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 8, 4, 1, 3}));
}
......@@ -375,10 +342,9 @@ TEST(type_prop, dot_deduce_element_type_mismatch)
// Type deduction fails due to element type mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 2});
auto param2 = make_shared<op::Parameter>(element::Int32::element_type(), Shape{2, 5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Dot>(param1, param2);
// Should have thrown, so fail if it didn't
FAIL() << "Element type mismatch not detected";
}
......@@ -397,10 +363,9 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
// Type deduction fails due to reduction axes size mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3, 5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Dot>(param1, param2);
// Should have thrown, so fail if it didn't
FAIL() << "Dot reduction axes size mismatch not detected";
}
......@@ -417,11 +382,26 @@ TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
//
// Tests for binary elementwise ops.
//
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tp0_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tp1_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
auto test_binary_bad_arguments_tuple = [&](const shared_ptr<Node>& x,
const shared_ptr<Node>& y) {
try
{
node->propagate_types();
auto node = f(x, y);
//node->get_value_type();
// Should have thrown, so fail if it didn't
FAIL() << "Tuple argument not detected.";
}
......@@ -433,13 +413,17 @@ void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
};
void test_binary_bad_arguments_view_shapes(const shared_ptr<Node>& node)
{
test_binary_bad_arguments_tuple(tp0_param, tp1_param);
test_binary_bad_arguments_tuple(tp0_param, tv0_2_4_param_0);
test_binary_bad_arguments_tuple(tv0_2_4_param_0, tp0_param);
auto test_binary_bad_arguments_view_shapes = [&](const shared_ptr<Node>& x,
const shared_ptr<Node>& y) {
try
{
node->propagate_types();
auto node = f(x, y);
node->get_value_type();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
......@@ -451,13 +435,15 @@ void test_binary_bad_arguments_view_shapes(const shared_ptr<Node>& node)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
};
test_binary_bad_arguments_view_shapes(tv0_2_4_param_0, tv0_4_2_param);
void test_binary_bad_arguments_view_element_types(const shared_ptr<Node>& node)
{
auto test_binary_bad_arguments_view_element_types = [&](const shared_ptr<Node>& x,
const shared_ptr<Node>& y) {
try
{
node->propagate_types();
auto node = f(x, y);
node->get_value_type();
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible view arguments not detected.";
}
......@@ -470,34 +456,15 @@ void test_binary_bad_arguments_view_element_types(const shared_ptr<Node>& node)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
void test_binary_good_arguments(const shared_ptr<Node>& node)
{
node->propagate_types();
EXPECT_EQ(*node->get_value_type(), *node->get_arguments()[0]->get_value_type());
}
};
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y))
{
// Check for bad arguments
auto tp0_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tp1_param = make_shared<op::Parameter>(make_shared<TupleType>());
auto tv0_2_4_param_0 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 4}));
auto tv0_4_2_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 2}));
test_binary_bad_arguments_view_element_types(tv0_2_4_param_0, tv0_2_4_param_2);
test_binary_bad_arguments_tuple(f(tp0_param, tp1_param));
test_binary_bad_arguments_tuple(f(tp0_param, tv0_2_4_param_0));
test_binary_bad_arguments_tuple(f(tv0_2_4_param_0, tp0_param));
test_binary_bad_arguments_view_shapes(f(tv0_2_4_param_0, tv0_4_2_param));
test_binary_bad_arguments_view_element_types(f(tv0_2_4_param_0, tv0_2_4_param_2));
test_binary_good_arguments(f(tv0_2_4_param_0, tv0_2_4_param_1));
auto test_binary_good_arguments = [&](const shared_ptr<Node>& x, const shared_ptr<Node>& y) {
auto node = f(x, y);
EXPECT_EQ(*node->get_value_type(), *node->get_arguments()[0]->get_value_type());
};
test_binary_good_arguments(tv0_2_4_param_0, tv0_2_4_param_1);
}
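Since the helpers are now lambdas capturing the factory f, a caller supplies only the factory. A sketch of how the TESTs below presumably invoke it (the exact lambda form is an assumption; a capture-less lambda converts to test_binary's function-pointer parameter):
TEST(type_prop_sketch, add_arguments)
{
    test_binary([](const shared_ptr<Node>& x, const shared_ptr<Node>& y) -> shared_ptr<Node> {
        return make_shared<op::Add>(x, y);
    });
}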
TEST(type_prop, add_bad_arguments)
......@@ -536,7 +503,6 @@ TEST(type_prop, comparison_good)
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto eq = make_shared<op::Equal>(tv0_2_4_param_0, tv0_2_4_param_1);
TensorViewType expected_type{element::Bool::element_type(), Shape{2, 4}};
eq->propagate_types();
EXPECT_EQ(*eq->get_value_type(), expected_type);
}
......@@ -546,10 +512,9 @@ TEST(type_prop, binary_arithmetic_bad_argument_element_types)
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto tv0_2_4_param_1 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Add>(tv0_2_4_param_0, tv0_2_4_param_1);
try
{
bc->propagate_types();
auto bc = make_shared<op::Add>(tv0_2_4_param_0, tv0_2_4_param_1);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -568,10 +533,9 @@ TEST(type_prop, unary_arithmetic_bad_argument_element_types)
{
auto tv0_2_4_param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Bool::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Negative>(tv0_2_4_param);
try
{
bc->propagate_types();
auto bc = make_shared<op::Negative>(tv0_2_4_param);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -595,7 +559,6 @@ TEST(type_prop, select_deduce)
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
......@@ -609,8 +572,6 @@ TEST(type_prop, select_deduce_correct)
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
bc->set_value_type(make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
......@@ -623,10 +584,9 @@ TEST(type_prop, select_shape_mismatch_a)
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -648,10 +608,9 @@ TEST(type_prop, select_shape_mismatch_b)
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 5}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -673,10 +632,9 @@ TEST(type_prop, select_shape_mismatch_c)
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 5}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -698,10 +656,9 @@ TEST(type_prop, select_elem_mismatch_a)
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -725,10 +682,9 @@ TEST(type_prop, select_elem_mismatch_bc)
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2, 4}));
auto tv0_2_4_param_2 = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Int32::element_type(), Shape{2, 4}));
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
try
{
bc->propagate_types();
auto bc = make_shared<op::Select>(tv0_2_4_param_0, tv0_2_4_param_1, tv0_2_4_param_2);
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -758,19 +714,15 @@ TEST(type_prop, reduce_deduce)
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
r0->propagate_types();
ASSERT_EQ(*(r0->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{4}));
auto r1 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{1});
r1->propagate_types();
ASSERT_EQ(*(r1->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{2}));
auto r01 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0, 1});
r01->propagate_types();
ASSERT_EQ(*(r01->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
auto r_none = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{});
r_none->propagate_types();
ASSERT_EQ(*(r_none->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{2, 4}));
}
......@@ -790,8 +742,6 @@ TEST(type_prop, reduce_deduce_correct)
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
r0->set_value_type(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4}));
r0->propagate_types();
ASSERT_EQ(*(r0->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{4}));
}
......@@ -809,10 +759,9 @@ TEST(type_prop, reduce_nonscalar)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
try
{
r0->propagate_types();
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -840,10 +789,9 @@ TEST(type_prop, reduce_elem_type_mismatch)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
try
{
r0->propagate_types();
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -873,10 +821,9 @@ TEST(type_prop, reduce_function_return_type_mismatch)
auto f = make_shared<Function>(
make_shared<op::Equal>(f_param_0, f_param_1), rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
try
{
r0->propagate_types();
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -905,10 +852,9 @@ TEST(type_prop, reduce_function_arg0_type_mismatch)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
try
{
r0->propagate_types();
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -936,10 +882,9 @@ TEST(type_prop, reduce_function_arg1_type_mismatch)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_0, rt, op::Parameters{f_param_0, f_param_1});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
try
{
r0->propagate_types();
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -970,10 +915,9 @@ TEST(type_prop, reduce_function_arg_count_mismatch)
auto f = make_shared<Function>(
f_param_0 + f_param_1 + f_param_2, rt, op::Parameters{f_param_0, f_param_1, f_param_2});
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
try
{
r0->propagate_types();
auto r0 = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -1002,10 +946,9 @@ TEST(type_prop, reduce_axis_oob)
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
auto f = make_shared<Function>(f_param_0 + f_param_1, rt, op::Parameters{f_param_0, f_param_1});
auto r = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0, 2, 1});
try
{
r->propagate_types();
auto r = make_shared<op::Reduce>(param_0, param_1, f, AxisSet{0, 2, 1});
// Should have thrown, so fail if it didn't
FAIL() << "Did not detect incorrect element types for arithmetic operator";
}
......@@ -1036,8 +979,6 @@ TEST(type_prop, function_call_deduce)
auto r = make_shared<op::FunctionCall>(f, Nodes{X, Y, Z});
auto r_p_r = r + r;
r->propagate_types();
r_p_r->propagate_types();
auto r_p_r_vt = r_p_r->get_value_type();
ASSERT_EQ(*r_p_r_vt, TensorViewType(element::Float32::element_type(), shape));
}
......@@ -1047,7 +988,6 @@ TEST(type_prop, reshape_deduce_s2v)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{1}));
}
......@@ -1056,7 +996,6 @@ TEST(type_prop, reshape_deduce_s2m)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1, 1});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 1}));
}
......@@ -1066,7 +1005,6 @@ TEST(type_prop, reshape_deduce_s2t)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{}));
auto r = make_shared<op::Reshape>(param, AxisVector{}, Shape{1, 1, 1});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 1, 1}));
}
......@@ -1076,7 +1014,6 @@ TEST(type_prop, reshape_deduce_v2s)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{1}));
auto r = make_shared<op::Reshape>(param, AxisVector{0}, Shape{});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
......@@ -1085,7 +1022,6 @@ TEST(type_prop, reshape_deduce_m2s)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{1, 1}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1}, Shape{});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
......@@ -1094,7 +1030,6 @@ TEST(type_prop, reshape_deduce_t2s)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{1, 1, 1}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1, 2}, Shape{});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
......@@ -1103,7 +1038,6 @@ TEST(type_prop, reshape_deduce_m2v_01)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1}, Shape{12});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{12}));
}
......@@ -1112,7 +1046,6 @@ TEST(type_prop, reshape_deduce_m2v_10)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{12});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{12}));
}
......@@ -1121,7 +1054,6 @@ TEST(type_prop, reshape_deduce_t2v_012)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{0, 1, 2}, Shape{60});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
......@@ -1130,7 +1062,6 @@ TEST(type_prop, reshape_deduce_t2v_120)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
......@@ -1139,8 +1070,6 @@ TEST(type_prop, reshape_deduce_correct_t2v_120)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{60});
r->set_value_type(make_shared<TensorViewType>(element::Float32::element_type(), Shape{60}));
r->propagate_types();
ASSERT_EQ(*(r->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{60}));
}
......@@ -1148,10 +1077,9 @@ TEST(type_prop, reshape_deduce_not_enough_axes)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{60});
try
{
r->propagate_types();
auto r = make_shared<op::Reshape>(param, AxisVector{1, 0}, Shape{60});
// Should have thrown, so fail if it didn't
FAIL() << "Not enough axes not detected";
}
......@@ -1171,10 +1099,9 @@ TEST(type_prop, reshape_deduce_too_many_axes)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0, 3}, Shape{60});
try
{
r->propagate_types();
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0, 3}, Shape{60});
// Should have thrown, so fail if it didn't
FAIL() << "Too many axes not detected";
}
......@@ -1194,10 +1121,9 @@ TEST(type_prop, reshape_deduce_duplicate_axes)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 1, 0}, Shape{60});
try
{
r->propagate_types();
auto r = make_shared<op::Reshape>(param, AxisVector{1, 1, 0}, Shape{60});
// Should have thrown, so fail if it didn't
FAIL() << "Too many axes not detected";
}
......@@ -1217,10 +1143,9 @@ TEST(type_prop, reshape_deduce_wrong_output_shape)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{3, 4, 5}));
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{3, 3, 3});
try
{
r->propagate_types();
auto r = make_shared<op::Reshape>(param, AxisVector{1, 2, 0}, Shape{3, 3, 3});
// Should have thrown, so fail if it didn't
FAIL() << "Too many axes not detected";
}
......@@ -1241,7 +1166,6 @@ TEST(type_prop, slice_deduce_vector)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{2}, Coordinate{5});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{3}));
}
......@@ -1250,7 +1174,6 @@ TEST(type_prop, slice_deduce_matrix)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{3, 6}));
}
......@@ -1260,7 +1183,6 @@ TEST(type_prop, slice_deduce_matrix_strided)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7}, Shape{3, 2});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 3}));
}
......@@ -1270,7 +1192,6 @@ TEST(type_prop, slice_deduce_matrix_strided_uneven)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7}, Shape{3, 4});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 2}));
}
......@@ -1280,7 +1201,6 @@ TEST(type_prop, slice_deduce_vector_edge)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{6});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{6}));
}
......@@ -1289,7 +1209,6 @@ TEST(type_prop, slice_deduce_matrix_edge)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 8});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{6, 8}));
}
......@@ -1299,7 +1218,6 @@ TEST(type_prop, slice_deduce_matrix_zero_cols)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 0});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{6, 0}));
}
......@@ -1309,7 +1227,6 @@ TEST(type_prop, slice_deduce_matrix_zero_zero)
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{0, 0});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{0, 0}));
}
......@@ -1318,10 +1235,9 @@ TEST(type_prop, slice_deduce_vector_invalid_step)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{7}, Shape{1, 2});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{7}, Shape{1, 2});
// Should have thrown, so fail if it didn't
FAIL() << "Invalid slice step not detected";
}
......@@ -1342,10 +1258,9 @@ TEST(type_prop, slice_deduce_vector_edge_upper_oob)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{7});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{7});
// Should have thrown, so fail if it didn't
FAIL() << "Upper bound out of range not detected";
}
......@@ -1363,10 +1278,9 @@ TEST(type_prop, slice_deduce_matrix_edge_upper_oob)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 9});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 9});
// Should have thrown, so fail if it didn't
FAIL() << "Upper bound out of range not detected";
}
......@@ -1384,10 +1298,9 @@ TEST(type_prop, slice_deduce_vector_lower_above_upper)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{3}, Coordinate{2});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{3}, Coordinate{2});
// Should have thrown, so fail if it didn't
FAIL() << "Lower bound above upper not detected";
}
......@@ -1405,10 +1318,9 @@ TEST(type_prop, slice_deduce_matrix_lower_above_upper)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 5}, Coordinate{6, 4});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0, 5}, Coordinate{6, 4});
// Should have thrown, so fail if it didn't
FAIL() << "Lower bound above upper not detected";
}
......@@ -1426,10 +1338,9 @@ TEST(type_prop, slice_deduce_matrix_lower_missing)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{5, 5});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{5, 5});
// Should have thrown, so fail if it didn't
FAIL() << "Missing lower bound coordinate not detected";
}
......@@ -1450,10 +1361,9 @@ TEST(type_prop, slice_deduce_matrix_upper_missing)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{5});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{5});
// Should have thrown, so fail if it didn't
FAIL() << "Missing upper bound coordinate not detected";
}
......@@ -1474,10 +1384,9 @@ TEST(type_prop, slice_deduce_matrix_lower_extra)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0, 0}, Coordinate{5, 5});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0, 0}, Coordinate{5, 5});
// Should have thrown, so fail if it didn't
FAIL() << "Extra lower bound coordinate not detected";
}
......@@ -1498,10 +1407,9 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{5, 5, 5});
try
{
sl->propagate_types();
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{5, 5, 5});
// Should have thrown, so fail if it didn't
FAIL() << "Extra upper bound coordinate not detected";
}
......@@ -1521,14 +1429,12 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
TEST(type_prop, scalar_constant_deduce_float32)
{
auto c = make_shared<op::Constant>(element::Float32::element_type(), Shape{}, "208");
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, scalar_constant_deduce_bool)
{
auto c = make_shared<op::Constant>(element::Bool::element_type(), Shape{}, "1");
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{}));
}
......@@ -1537,7 +1443,6 @@ TEST(type_prop, tensor_constant_deduce_float32)
auto c = make_shared<op::Constant>(element::Float32::element_type(),
Shape{2, 2},
std::vector<std::string>{"208", "208", "208", "208"});
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{2, 2}));
}
......@@ -1546,18 +1451,16 @@ TEST(type_prop, tensor_constant_deduce_bool)
{
auto c = make_shared<op::Constant>(
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1", "1"});
c->propagate_types();
ASSERT_EQ(*(c->get_value_type()), TensorViewType(element::Bool::element_type(), Shape{2, 2}));
}
TEST(type_prop, tensor_constant_bad_parse)
{
try
{
auto c = make_shared<op::Constant>(element::Bool::element_type(),
Shape{2, 2},
std::vector<std::string>{"1", "grunk", "1", "1"});
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
......@@ -1573,12 +1476,11 @@ TEST(type_prop, tensor_constant_bad_parse)
TEST(type_prop, tensor_constant_bad_parse_float_for_int)
{
try
{
auto c = make_shared<op::Constant>(element::Int32::element_type(),
Shape{2, 2},
std::vector<std::string>{"1", "2.7", "1", "1"});
try
{
c->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Bad literal parse not detected";
}
......@@ -1594,11 +1496,10 @@ TEST(type_prop, tensor_constant_bad_parse_float_for_int)
TEST(type_prop, tensor_constant_bad_count)
{
auto c = make_shared<op::Constant>(
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1"});
try
{
c->propagate_types();
auto c = make_shared<op::Constant>(
element::Bool::element_type(), Shape{2, 2}, std::vector<std::string>{"1", "1", "1"});
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect number of literals not detected";
}
......