Commit 6d6e923b authored by Scott Cyphers's avatar Scott Cyphers

Try universal node creator

parent 42f599c6
......@@ -23,12 +23,25 @@ namespace ngraph
{
class Node;
class Parameter;
class ValueType;
template<typename T, typename ...A>
std::shared_ptr<T> node(A&&... args)
{
return std::make_shared<T>(args...);
}
/// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
/// Zero or more nodes
using Nodes = std::vector<std::shared_ptr<Node>>;
/// A sequence of axes
using AxisList = std::vector<size_t>;
/// A set of indices, for example, reduction axes
using IndexSet = std::set<size_t>;
/// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
......
......@@ -19,15 +19,13 @@ namespace ngraph
class BroadcastOp : public BuiltinOp
{
public:
using Axes = std::vector<size_t>;
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastOp(const Node::ptr& arg, const Shape& shape, const Axes& broadcast_axes)
BroadcastOp(const Node::ptr& arg, const Shape& shape, AxisSet& broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
......@@ -39,13 +37,13 @@ namespace ngraph
protected:
Shape m_shape;
Axes m_broadcast_axes;
AxisSet m_broadcast_axes;
};
namespace op
{
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes);
AxisSet&& broadcast_axes);
}
}
......@@ -59,14 +59,6 @@ namespace ngraph
typename T::type value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
template <typename U>
static std::shared_ptr<ScalarConstant<T>> make(U value)
{
return std::make_shared<ScalarConstant<T>>(value);
}
protected:
typename T::type m_value;
};
......
......@@ -37,6 +37,7 @@ namespace ngraph
public:
Parameter(const ValueType::ptr& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
......
......@@ -25,7 +25,7 @@ using namespace ngraph;
**/
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const BroadcastOp::Axes&& broadcast_axes)
AxisSet&& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
......
......@@ -26,6 +26,11 @@ Parameter::Parameter(const ValueType::ptr& value_type)
{
}
Parameter::Parameter(const ngraph::element::Type element_type, const Shape& shape)
: Parameter(make_shared<TensorViewType>(element_type, shape))
{
}
void Parameter::assign_function(Function* function, size_t index)
{
if (nullptr != m_function)
......
......@@ -20,29 +20,16 @@
using namespace std;
using namespace ngraph;
template<typename T, typename ...A>
std::shared_ptr<T> myfun(A&&... args)
{
return std::make_shared<T>(args...);
}
template<>
std::shared_ptr<Parameter> myfun<Parameter> (ngraph::element::Type&& element_type, Shape&& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = op::parameter(element::Float::element_type(), Shape{7, 3});
auto arg1 = op::parameter(element::Float::element_type(), Shape{3});
auto arg2 = op::parameter(element::Float::element_type(), Shape{32, 7});
auto arg3 = op::parameter(element::Float::element_type(), Shape{32, 7});
auto broadcast_1 = op::broadcast(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto b1 = myfun<BroadcastOp>(arg3, Shape{10, 32, 7}, BroadcastOp::Axes{0});
auto dot = op::dot(arg2, arg0);
auto arg0 = node<Parameter>(element::Float::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
......@@ -55,14 +42,14 @@ TEST(build_graph, build_simple)
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
auto tv_vt = make_shared<TensorViewType>(element::Float::element_type(), Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
auto tp_vt = make_shared<TupleType>(ValueTypes{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
......@@ -72,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto arg0 = op::parameter(element::Float::element_type(), {32, 3});
auto arg1 = op::parameter(element::Float::element_type(), {3});
auto arg2 = op::parameter(element::Float::element_type(), {32});
auto arg0 = node<Parameter>(element::Float::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto parg = op::parameter(element::Float::element_type(), {});
auto pattern_dot = op::dot(parg, parg);
auto parg = node<Parameter>(element::Float::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
......@@ -90,20 +77,21 @@ TEST(build_graph, node_comparison)
TEST(build_graph, literal)
{
// float scalar from a float
auto float0 = FloatScalarConstant::make(3.0);
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<FloatScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->value_type(), float_scalar_type);
auto d = op::dot(float0, float0);
auto d = node<DotOp>(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0);
// float scalar from an int
auto float1 = FloatScalarConstant::make(3);
auto float1 = node<FloatScalarConstant>(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->value_type(), float_scalar_type);
auto int32_0 = Int32ScalarConstant::make(3.0);
auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->value_type(), int32_scalar_type);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment