Commit 6d6e923b authored by Scott Cyphers

Try universal node creator

parent 42f599c6
@@ -23,12 +23,25 @@ namespace ngraph
 {
     class Node;
     class Parameter;
+    class ValueType;
+
+    template<typename T, typename ...A>
+    std::shared_ptr<T> node(A&&... args)
+    {
+        return std::make_shared<T>(args...);
+    }
+
+    /// Zero or more value types
+    using ValueTypes = std::vector<std::shared_ptr<ValueType>>;

     /// Zero or more nodes
     using Nodes = std::vector<std::shared_ptr<Node>>;

-    /// A set of indices, for example, reduction axes
-    using IndexSet = std::set<size_t>;
+    /// A sequence of axes
+    using AxisList = std::vector<size_t>;
+
+    /// A set of axes, for example, reduction axes
+    using AxisSet = std::set<size_t>;

     /// A list of parameters
     using Parameters = std::vector<std::shared_ptr<Parameter>>;
...
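The `node<T>` helper added above is the "universal node creator" of the commit title: one variadic factory that hands its arguments to std::make_shared, so every node type can be built through a single entry point instead of per-class `make` helpers. A minimal self-contained sketch of the same pattern is shown below; `create_node`, `Widget`, and `main` are illustrative stand-ins and not part of this commit, and the sketch forwards with std::forward, whereas the committed helper passes `args...` directly.

#include <memory>
#include <utility>

// Generic factory in the style of the node<T> helper added above.
template <typename T, typename... A>
std::shared_ptr<T> create_node(A&&... args)
{
    // std::forward preserves the value category of each argument.
    return std::make_shared<T>(std::forward<A>(args)...);
}

// Illustrative stand-in type (not part of the commit).
struct Widget
{
    Widget(int id, double scale) : id(id), scale(scale) {}
    int    id;
    double scale;
};

int main()
{
    auto w = create_node<Widget>(7, 2.5); // same effect as std::make_shared<Widget>(7, 2.5)
    return (w->id == 7 && w->scale == 2.5) ? 0 : 1;
}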
@@ -19,15 +19,13 @@ namespace ngraph
     class BroadcastOp : public BuiltinOp
     {
     public:
-        using Axes = std::vector<size_t>;
-
         /**
          ** /param arg The tensor view to be broadcast.
          ** /param shape The shape of the result
          ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
          **  the remaining axes in shape must be the same as the shape of arg.
          **/
-        BroadcastOp(const Node::ptr& arg, const Shape& shape, const Axes& broadcast_axes)
+        BroadcastOp(const Node::ptr& arg, const Shape& shape, AxisSet& broadcast_axes)
            : BuiltinOp({arg})
            , m_shape(shape)
            , m_broadcast_axes(broadcast_axes)
@@ -39,13 +37,13 @@ namespace ngraph
    protected:
        Shape   m_shape;
-       Axes    m_broadcast_axes;
+       AxisSet m_broadcast_axes;
    };

    namespace op
    {
        Node::ptr broadcast(const Node::ptr& tensor,
                            const Shape& shape,
-                           const BroadcastOp::Axes&& broadcast_axes);
+                           AxisSet&& broadcast_axes);
    }
}
...
@@ -59,14 +59,6 @@ namespace ngraph
            typename T::type value() const { return m_value; }

-           // Make a constant from any value that can be converted to the C++ type we use
-           // to represent the values.
-           template <typename U>
-           static std::shared_ptr<ScalarConstant<T>> make(U value)
-           {
-               return std::make_shared<ScalarConstant<T>>(value);
-           }
-
        protected:
            typename T::type m_value;
        };
...
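The static `make` helper removed here is superseded by the generic `node<T>` factory from the first hunk; the effect on call sites, taken from the test changes later in this commit, looks roughly like this (a sketch assuming the ngraph headers at this revision):

// Before this commit: per-class factory helper.
auto c_old = FloatScalarConstant::make(3.0);
// After this commit: the one generic creator covers every node type.
auto c_new = node<FloatScalarConstant>(3.0);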
@@ -37,6 +37,7 @@ namespace ngraph
    public:
        Parameter(const ValueType::ptr& value_type);
+       Parameter(const ngraph::element::Type element_type, const Shape& shape);

        std::string description() const override { return "Parameter"; }
        virtual void propagate_types() override;
...
@@ -25,7 +25,7 @@ using namespace ngraph;
 **/
 Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
                                 const Shape& shape,
-                                const BroadcastOp::Axes&& broadcast_axes)
+                                AxisSet&& broadcast_axes)
 {
     return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
 }
...
@@ -26,6 +26,11 @@ Parameter::Parameter(const ValueType::ptr& value_type)
 {
 }

+Parameter::Parameter(const ngraph::element::Type element_type, const Shape& shape)
+    : Parameter(make_shared<TensorViewType>(element_type, shape))
+{
+}
+
 void Parameter::assign_function(Function* function, size_t index)
 {
     if (nullptr != m_function)
...
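The delegating constructor added above lets callers build a Parameter directly from an element type and a shape instead of wrapping them in a TensorViewType by hand. A hedged usage sketch follows, assuming the ngraph headers at this revision; the shape values are illustrative only:

// New convenience form introduced by this commit:
auto p = node<Parameter>(element::Float::element_type(), Shape{4, 8});
// Equivalent long form that was required before:
auto q = make_shared<Parameter>(
    make_shared<TensorViewType>(element::Float::element_type(), Shape{4, 8}));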
@@ -20,29 +20,16 @@
 using namespace std;
 using namespace ngraph;

-template<typename T, typename ...A>
-std::shared_ptr<T> myfun(A&&... args)
-{
-    return std::make_shared<T>(args...);
-}
-
-template<>
-std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_type, Shape&& shape)
-{
-    return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
-}
-
 TEST(build_graph, build_simple)
 {
     // Function with 4 parameters
-    auto arg0 = op::parameter(element::Float::element_type(), Shape{7, 3});
-    auto arg1 = op::parameter(element::Float::element_type(), Shape{3});
-    auto arg2 = op::parameter(element::Float::element_type(), Shape{32, 7});
-    auto arg3 = op::parameter(element::Float::element_type(), Shape{32, 7});
-    auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
-    auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
-    auto dot = op::dot(arg2, arg0);
+    auto arg0 = node<Parameter>(element::Float::element_type(), Shape{7, 3});
+    auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3});
+    auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
+    auto arg3 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
+    auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
+    auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
+    auto dot = node<DotOp>(arg2, arg0);

     ASSERT_EQ(dot->arguments()[0], arg2);
     ASSERT_EQ(dot->arguments()[1], arg0);
@@ -55,14 +42,14 @@ TEST(build_graph, build_simple)
 TEST(build_graph, as_type)
 {
     // Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
-    ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
+    auto tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
     auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
     ASSERT_EQ(tv_vt, tv_tv);
     auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
     ASSERT_EQ(nullptr, tv_tp);

     // Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
-    ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
+    auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt});
     auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
     ASSERT_EQ(nullptr, tp_tv);
     auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
@@ -72,15 +59,15 @@ TEST(build_graph, as_type)
 // Check node comparisons
 TEST(build_graph, node_comparison)
 {
-    auto arg0 = op::parameter(element::Float::element_type(), {32, 3});
-    auto arg1 = op::parameter(element::Float::element_type(), {3});
-    auto arg2 = op::parameter(element::Float::element_type(), {32});
+    auto arg0 = node<Parameter>(element::Float::element_type(), Shape{32, 3});
+    auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3});
+    auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32});

     auto dot = op::dot(arg0, arg1);
     auto add = op::add(dot, arg2);

-    auto parg = op::parameter(element::Float::element_type(), {});
-    auto pattern_dot = op::dot(parg, parg);
+    auto parg = node<Parameter>(element::Float::element_type(), Shape{});
+    auto pattern_dot = node<DotOp>(parg, parg);
     ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
     // TODO This passes because typeid is not behaving as documented.
     // Need to figure out what's wrong.
@@ -90,20 +77,21 @@ TEST(build_graph, node_comparison)
 TEST(build_graph, literal)
 {
     // float scalar from a float
-    auto float0 = FloatScalarConstant::make(3.0);
+    //auto float0 = FloatScalarConstant::make(3.0);
+    auto float0 = node<FloatScalarConstant>(3.0);
     auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
     ASSERT_EQ(float0->value(), 3.0);
     ASSERT_EQ(*float0->value_type(), float_scalar_type);
-    auto d = op::dot(float0, float0);
+    auto d = node<DotOp>(float0, float0);
     ASSERT_EQ(d->arguments().at(0), float0);
     ASSERT_EQ(d->arguments().at(1), float0);

     // float scalar from an int
-    auto float1 = FloatScalarConstant::make(3);
+    auto float1 = node<FloatScalarConstant>(3);
     ASSERT_EQ(float1->value(), 3);
     ASSERT_EQ(*float1->value_type(), float_scalar_type);

-    auto int32_0 = Int32ScalarConstant::make(3.0);
+    auto int32_0 = node<Int32ScalarConstant>(3.0);
     auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
     ASSERT_EQ(int32_0->value(), 3);
     ASSERT_EQ(*int32_0->value_type(), int32_scalar_type);
...