Commit c302ab30 authored by Scott Cyphers's avatar Scott Cyphers

Simplify op/call

parent 3a53de5e
...@@ -22,6 +22,10 @@ namespace ngraph ...@@ -22,6 +22,10 @@ namespace ngraph
{ {
class Function; class Function;
/**
** One parameter of a function. Within the function's graph
** the parameter is a node that represents the argument in a call.
**/
class Parameter : public Node class Parameter : public Node
{ {
public: public:
...@@ -34,6 +38,10 @@ namespace ngraph ...@@ -34,6 +38,10 @@ namespace ngraph
size_t m_index; size_t m_index;
}; };
/**
** The result of a function. The ndoe addociated with the result
** supplies the return value when the function is called.
**/
class Result : public TypedValueMixin class Result : public TypedValueMixin
{ {
public: public:
...@@ -46,6 +54,9 @@ namespace ngraph ...@@ -46,6 +54,9 @@ namespace ngraph
Node::ptr m_value; Node::ptr m_value;
}; };
/**
** A user-defined function.
**/
class Function : public Op class Function : public Op
{ {
public: public:
...@@ -53,7 +64,7 @@ namespace ngraph ...@@ -53,7 +64,7 @@ namespace ngraph
Result* result() { return &m_result; } Result* result() { return &m_result; }
std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; } Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
protected: protected:
std::vector<Parameter::ptr> m_parameters; std::vector<Parameter::ptr> m_parameters;
......
...@@ -40,15 +40,4 @@ namespace ngraph ...@@ -40,15 +40,4 @@ namespace ngraph
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
}; };
class Call : public Node
{
public:
virtual Op& op() const = 0;
protected:
Call(const std::vector<Node::ptr>& arguments)
: Node(arguments, 0)
{
}
};
} }
...@@ -23,63 +23,78 @@ namespace ngraph ...@@ -23,63 +23,78 @@ namespace ngraph
{ {
class Op class Op
{ {
public:
using ptr = std::shared_ptr<Op>;
using ref = decltype(*std::shared_ptr<Op>());
}; };
class Broadcast : public Op class Call : public Node
{
public:
using ptr = std::shared_ptr<Call>;
Op::ptr op() const { return m_op; }
Call(const Op::ptr& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, 0)
, m_op(op)
{ {
}
protected:
Op::ptr m_op;
};
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
{
public:
using ptr = std::shared_ptr<Broadcast>;
using ref = decltype(*std::shared_ptr<Broadcast>());
protected:
class BroadcastCall : public Call class BroadcastCall : public Call
{ {
friend class Broadcast; friend class Broadcast;
public: public:
BroadcastCall(const Node::ptr& arg, size_t axis) BroadcastCall(const Op::ptr& op, const Node::ptr& arg, size_t axis)
: Call({arg}) : Call(op, {arg})
, m_axis(axis) , m_axis(axis)
{ {
} }
Op& op() const override;
protected: protected:
size_t m_axis; size_t m_axis;
}; };
public: public:
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis) std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{ {
return std::make_shared<BroadcastCall>(tensor, axis); return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
} }
}; };
namespace op namespace op
{ {
extern Broadcast broadcast; extern Broadcast::ref broadcast;
} }
class Dot : public Op class Dot : public Op, public std::enable_shared_from_this<Dot>
{
class DotCall : public Call
{ {
friend class Dot;
public: public:
DotCall(const std::shared_ptr<Node>& arg0, const Node::ptr& arg1) using ptr = std::shared_ptr<Dot>;
: Call({arg0, arg1}) using ref = decltype(*std::shared_ptr<Dot>());
{
}
Op& op() const override;
};
public: public:
std::shared_ptr<DotCall> operator()(const Node::ptr& arg0, const Node::ptr& arg1) Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return std::make_shared<DotCall>(arg0, arg1); return Call::ptr::make_shared(shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
} }
}; };
namespace op namespace op
{ {
extern Dot dot; extern Dot::ref dot;
} }
} }
...@@ -16,16 +16,7 @@ ...@@ -16,16 +16,7 @@
using namespace ngraph; using namespace ngraph;
Broadcast ngraph::op::broadcast{}; Broadcast::ref ngraph::op::broadcast = *std::make_shared<Broadcast>();
Op& ngraph::Broadcast::BroadcastCall::op() const Dot::ref ngraph::op::dot = *std::make_shared<Dot>();
{
return op::broadcast;
}
Dot ngraph::op::dot{};
Op& ngraph::Dot::DotCall::op() const
{
return op::dot;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment