Commit c302ab30 authored by Scott Cyphers's avatar Scott Cyphers

Simplify op/call

parent 3a53de5e
......@@ -22,6 +22,10 @@ namespace ngraph
{
class Function;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class Parameter : public Node
{
public:
......@@ -34,6 +38,10 @@ namespace ngraph
size_t m_index;
};
/**
** The result of a function. The ndoe addociated with the result
** supplies the return value when the function is called.
**/
class Result : public TypedValueMixin
{
public:
......@@ -46,6 +54,9 @@ namespace ngraph
Node::ptr m_value;
};
/**
** A user-defined function.
**/
class Function : public Op
{
public:
......@@ -53,7 +64,7 @@ namespace ngraph
Result* result() { return &m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
protected:
std::vector<Parameter::ptr> m_parameters;
......
......@@ -40,15 +40,4 @@ namespace ngraph
std::vector<Node::ptr> m_arguments;
};
class Call : public Node
{
public:
virtual Op& op() const = 0;
protected:
Call(const std::vector<Node::ptr>& arguments)
: Node(arguments, 0)
{
}
};
}
......@@ -23,63 +23,78 @@ namespace ngraph
{
class Op
{
public:
using ptr = std::shared_ptr<Op>;
using ref = decltype(*std::shared_ptr<Op>());
};
class Broadcast : public Op
class Call : public Node
{
public:
using ptr = std::shared_ptr<Call>;
Op::ptr op() const { return m_op; }
Call(const Op::ptr& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, 0)
, m_op(op)
{
}
protected:
Op::ptr m_op;
};
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
{
public:
using ptr = std::shared_ptr<Broadcast>;
using ref = decltype(*std::shared_ptr<Broadcast>());
protected:
class BroadcastCall : public Call
{
friend class Broadcast;
public:
BroadcastCall(const Node::ptr& arg, size_t axis)
: Call({arg})
BroadcastCall(const Op::ptr& op, const Node::ptr& arg, size_t axis)
: Call(op, {arg})
, m_axis(axis)
{
}
Op& op() const override;
protected:
size_t m_axis;
};
public:
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(tensor, axis);
return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
}
};
namespace op
{
extern Broadcast broadcast;
extern Broadcast::ref broadcast;
}
class Dot : public Op
{
class DotCall : public Call
class Dot : public Op, public std::enable_shared_from_this<Dot>
{
friend class Dot;
public:
DotCall(const std::shared_ptr<Node>& arg0, const Node::ptr& arg1)
: Call({arg0, arg1})
{
}
Op& op() const override;
};
using ptr = std::shared_ptr<Dot>;
using ref = decltype(*std::shared_ptr<Dot>());
public:
std::shared_ptr<DotCall> operator()(const Node::ptr& arg0, const Node::ptr& arg1)
Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1)
{
return std::make_shared<DotCall>(arg0, arg1);
return Call::ptr::make_shared(shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
}
};
namespace op
{
extern Dot dot;
extern Dot::ref dot;
}
}
......@@ -16,16 +16,7 @@
using namespace ngraph;
Broadcast ngraph::op::broadcast{};
Broadcast::ref ngraph::op::broadcast = *std::make_shared<Broadcast>();
Op& ngraph::Broadcast::BroadcastCall::op() const
{
return op::broadcast;
}
Dot::ref ngraph::op::dot = *std::make_shared<Dot>();
Dot ngraph::op::dot{};
Op& ngraph::Dot::DotCall::op() const
{
return op::dot;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment