Commit 520d9d5d authored by Scott Cyphers's avatar Scott Cyphers

style

parent 14436c81
......@@ -19,16 +19,17 @@
class NGraph
{
public:
void add_params(const std::vector<std::string>& paramList);
void add_params(const std::vector<std::string>& paramList);
const std::vector<std::string>& get_params() const;
std::string get_name() const { return "NGraph Implementation Object"; }
private:
std::vector<std::string> m_params;
};
// Factory methods
extern "C" NGraph* create_ngraph_object();
extern "C" void destroy_ngraph_object(NGraph* pObj);
extern "C" void destroy_ngraph_object(NGraph* pObj);
// FUnction pointers to the factory methods
typedef NGraph* (*CreateNGraphObjPfn)();
......
......@@ -70,5 +70,4 @@ namespace ngraph
std::vector<Parameter::ptr> m_parameters;
Result m_result;
};
}
......@@ -22,6 +22,11 @@ namespace ngraph
{
class Op;
/**
** Nodes are the backbone of the graph of Value dataflow. Every node has
** zero or more nodes as arguments and one value, which is either a tensor
** view or a (possibly empty) tuple of values.
**/
class Node : public TypedValueMixin
{
public:
......@@ -33,11 +38,10 @@ namespace ngraph
{
}
virtual ~Node() {}
virtual std::vector<Node::ptr> dependents() { return m_arguments; }
const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
protected:
std::vector<Node::ptr> m_arguments;
};
}
......@@ -21,43 +21,44 @@
namespace ngraph
{
class Op
{
public:
using ptr = std::shared_ptr<Op>;
using ref = decltype(*std::shared_ptr<Op>());
};
class Op;
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
** call nodes.
**/
class Call : public Node
{
public:
using ptr = std::shared_ptr<Call>;
std::shared_ptr<Op> op() const { return m_op; }
Op::ptr op() const { return m_op; }
Call(const Op::ptr& op, const std::vector<Node::ptr>& arguments)
Call(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
, m_op(op)
{
}
protected:
Op::ptr m_op;
std::shared_ptr<Op> m_op;
};
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
/**
** The Op class provides the behavior for a Call.
**/
class Op
{
public:
using ptr = std::shared_ptr<Broadcast>;
using ref = decltype(*std::shared_ptr<Broadcast>());
};
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
{
protected:
class BroadcastCall : public Call
{
friend class Broadcast;
public:
BroadcastCall(const Op::ptr& op, const Node::ptr& arg, size_t axis)
BroadcastCall(const std::shared_ptr<Op>& op, const Node::ptr& arg, size_t axis)
: Call(op, {arg})
, m_axis(axis)
{
......@@ -68,7 +69,6 @@ namespace ngraph
};
public:
std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
{
return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
......@@ -77,15 +77,11 @@ namespace ngraph
namespace op
{
extern Broadcast::ref broadcast;
extern decltype(*std::shared_ptr<Broadcast>()) broadcast;
}
class Dot : public Op, public std::enable_shared_from_this<Dot>
{
public:
using ptr = std::shared_ptr<Dot>;
using ref = decltype(*std::shared_ptr<Dot>());
public:
Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1)
{
......@@ -95,6 +91,6 @@ namespace ngraph
namespace op
{
extern Dot::ref dot;
extern decltype(*std::shared_ptr<Dot>()) dot;
}
}
......@@ -16,7 +16,6 @@
using namespace ngraph;
Broadcast::ref ngraph::op::broadcast = *std::make_shared<Broadcast>();
Dot::ref ngraph::op::dot = *std::make_shared<Dot>();
decltype(*std::shared_ptr<Broadcast>()) ngraph::op::broadcast = *std::make_shared<Broadcast>();
decltype(*std::shared_ptr<Dot>()) ngraph::op::dot = *std::make_shared<Dot>();
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment