Commit 520d9d5d authored by Scott Cyphers's avatar Scott Cyphers

style

parent 14436c81
...@@ -19,16 +19,17 @@ ...@@ -19,16 +19,17 @@
class NGraph class NGraph
{ {
public: public:
void add_params(const std::vector<std::string>& paramList); void add_params(const std::vector<std::string>& paramList);
const std::vector<std::string>& get_params() const; const std::vector<std::string>& get_params() const;
std::string get_name() const { return "NGraph Implementation Object"; } std::string get_name() const { return "NGraph Implementation Object"; }
private: private:
std::vector<std::string> m_params; std::vector<std::string> m_params;
}; };
// Factory methods // Factory methods
extern "C" NGraph* create_ngraph_object(); extern "C" NGraph* create_ngraph_object();
extern "C" void destroy_ngraph_object(NGraph* pObj); extern "C" void destroy_ngraph_object(NGraph* pObj);
// FUnction pointers to the factory methods // FUnction pointers to the factory methods
typedef NGraph* (*CreateNGraphObjPfn)(); typedef NGraph* (*CreateNGraphObjPfn)();
......
...@@ -70,5 +70,4 @@ namespace ngraph ...@@ -70,5 +70,4 @@ namespace ngraph
std::vector<Parameter::ptr> m_parameters; std::vector<Parameter::ptr> m_parameters;
Result m_result; Result m_result;
}; };
} }
...@@ -22,6 +22,11 @@ namespace ngraph ...@@ -22,6 +22,11 @@ namespace ngraph
{ {
class Op; class Op;
/**
** Nodes are the backbone of the graph of Value dataflow. Every node has
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class Node : public TypedValueMixin class Node : public TypedValueMixin
{ {
public: public:
...@@ -33,11 +38,10 @@ namespace ngraph ...@@ -33,11 +38,10 @@ namespace ngraph
{ {
} }
virtual ~Node() {} const std::vector<Node::ptr> arguments() const { return m_arguments; }
virtual std::vector<Node::ptr> dependents() { return m_arguments; } std::vector<Node::ptr> arguments() { return m_arguments; }
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
}; };
} }
...@@ -21,43 +21,44 @@ ...@@ -21,43 +21,44 @@
namespace ngraph namespace ngraph
{ {
class Op class Op;
{
public:
using ptr = std::shared_ptr<Op>;
using ref = decltype(*std::shared_ptr<Op>());
};
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
** call nodes.
**/
class Call : public Node class Call : public Node
{ {
public: public:
using ptr = std::shared_ptr<Call>; std::shared_ptr<Op> op() const { return m_op; }
Op::ptr op() const { return m_op; } Call(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& arguments)
Call(const Op::ptr& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
, m_op(op) , m_op(op)
{ {
} }
protected: protected:
Op::ptr m_op; std::shared_ptr<Op> m_op;
}; };
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast> /**
** The Op class provides the behavior for a Call.
**/
class Op
{ {
public: };
using ptr = std::shared_ptr<Broadcast>;
using ref = decltype(*std::shared_ptr<Broadcast>());
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
{
protected: protected:
class BroadcastCall : public Call class BroadcastCall : public Call
{ {
friend class Broadcast; friend class Broadcast;
public: public:
BroadcastCall(const Op::ptr& op, const Node::ptr& arg, size_t axis) BroadcastCall(const std::shared_ptr<Op>& op, const Node::ptr& arg, size_t axis)
: Call(op, {arg}) : Call(op, {arg})
, m_axis(axis) , m_axis(axis)
{ {
...@@ -68,7 +69,6 @@ namespace ngraph ...@@ -68,7 +69,6 @@ namespace ngraph
}; };
public: public:
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis) std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{ {
return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis); return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
...@@ -77,15 +77,11 @@ namespace ngraph ...@@ -77,15 +77,11 @@ namespace ngraph
namespace op namespace op
{ {
extern Broadcast::ref broadcast; extern decltype(*std::shared_ptr<Broadcast>()) broadcast;
} }
class Dot : public Op, public std::enable_shared_from_this<Dot> class Dot : public Op, public std::enable_shared_from_this<Dot>
{ {
public:
using ptr = std::shared_ptr<Dot>;
using ref = decltype(*std::shared_ptr<Dot>());
public: public:
Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1) Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
...@@ -95,6 +91,6 @@ namespace ngraph ...@@ -95,6 +91,6 @@ namespace ngraph
namespace op namespace op
{ {
extern Dot::ref dot; extern decltype(*std::shared_ptr<Dot>()) dot;
} }
} }
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
using namespace ngraph; using namespace ngraph;
Broadcast::ref ngraph::op::broadcast = *std::make_shared<Broadcast>(); decltype(*std::shared_ptr<Broadcast>()) ngraph::op::broadcast = *std::make_shared<Broadcast>();
Dot::ref ngraph::op::dot = *std::make_shared<Dot>();
decltype(*std::shared_ptr<Dot>()) ngraph::op::dot = *std::make_shared<Dot>();
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment