Commit 2e420183 authored by Scott Cyphers's avatar Scott Cyphers

Calls will be the focus of op behavior; shrink op to an identifier.

update broadcast op args
parent 5817c0a7
......@@ -33,6 +33,12 @@ namespace ngraph
Parameter(Function& function, size_t index);
const std::string& description() const override
{
static std::string name{"Parameter"};
return name;
}
protected:
Function& m_function;
size_t m_index;
......@@ -66,8 +72,11 @@ namespace ngraph
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
const std::string& name() const override { return m_name; }
protected:
std::vector<Parameter::ptr> m_parameters;
Result m_result;
std::string m_name;
};
}
......@@ -15,6 +15,7 @@
#pragma once
#include <set>
#include <string>
#include <vector>
#include "ngraph/type.hpp"
......@@ -33,6 +34,7 @@ namespace ngraph
public:
using ptr = std::shared_ptr<Node>;
protected:
Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
: TypedValueMixin(type)
, m_arguments(arguments)
......@@ -44,14 +46,27 @@ namespace ngraph
}
}
virtual ~Node() {}
public:
/// A "one-liner" describing this node.
virtual const std::string& description() const = 0;
/// Propagate types and check arguments for consistency
// virtual void propagate_types() = 0;
const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
const std::multiset<Node*> users() const { return m_users; }
std::multiset<Node*> users() { return m_users; }
const std::string& name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
protected:
std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users;
std::string m_name;
};
}
......@@ -21,7 +21,17 @@
namespace ngraph
{
class Op;
/**
** Every instance of Op corresponds to a unique defined operation.
**/
class Op
{
protected:
virtual ~Op() {}
public:
virtual const std::string& name() const = 0;
};
/**
** Call nodes are nodes whose value is the result of some operation, the op,
......@@ -39,43 +49,69 @@ namespace ngraph
{
}
const std::string& description() const override { return m_op->name(); }
protected:
std::shared_ptr<Op> m_op;
};
/**
** The Op class provides the behavior for a Call.
** There is exactly one instance of builtin op for each pre-defined operation.
**/
class Op
class BuiltinOp : public Op
{
};
friend class Call;
namespace op
{
std::shared_ptr<Node> broadcast(const Node::ptr& tensor, size_t axis);
}
public:
BuiltinOp(const std::string& name)
: m_name(name)
{
}
public:
const std::string& name() const override { return m_name; }
protected:
std::string m_name;
};
class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
class BuiltinCall : public Call
{
friend std::shared_ptr<Node> op::broadcast(const Node::ptr& tensor, size_t axis);
public:
const std::string& description() const override
{
static std::string name{"BuiltinCall "};
return name;
}
protected:
class BroadcastCall : public Call
BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args)
: Call(op, args)
{
friend class Broadcast;
}
};
public:
BroadcastCall(const std::shared_ptr<Op>& op, const Node::ptr& arg, size_t axis)
: Call(op, {arg})
, m_axis(axis)
{
}
namespace op
{
std::shared_ptr<Node> broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
}
protected:
size_t m_axis;
};
class BroadcastCall : public BuiltinCall
{
public:
BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinCall(s_op, {arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
static std::shared_ptr<Broadcast> s_op;
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
namespace op
......@@ -83,9 +119,15 @@ namespace ngraph
std::shared_ptr<Node> dot(const Node::ptr& arg0, const Node::ptr& arg1);
}
class Dot : public Op, public std::enable_shared_from_this<Dot>
class DotCall : public BuiltinCall
{
friend std::shared_ptr<Node> op::dot(const Node::ptr& arg0, const Node::ptr& arg1);
static std::shared_ptr<Dot> s_op;
public:
DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
{
}
protected:
static std::shared_ptr<BuiltinOp> s_op;
};
}
......@@ -26,6 +26,7 @@ Parameter::Parameter(Function& function, size_t index)
Function::Function(size_t n_parameters)
: m_parameters(n_parameters)
, m_name("Function")
{
for (int i = 0; i < n_parameters; i++)
{
......
......@@ -17,16 +17,18 @@
using namespace ngraph;
using namespace std;
shared_ptr<Broadcast> ngraph::Broadcast::s_op = make_shared<ngraph::Broadcast>();
std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast");
shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor, size_t axis)
shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<Broadcast::BroadcastCall>(Broadcast::s_op->shared_from_this(), tensor, axis);
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes);
}
shared_ptr<Dot> ngraph::Dot::s_op = make_shared<ngraph::Dot>();
std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot");
shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<Call>(Dot::s_op->shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
return make_shared<DotCall>(arg0, arg1);
}
......@@ -30,7 +30,7 @@ TEST(graph, build_simple)
cluster_0->parameter(3)->type(element_type_float, {32, 7});
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 1.
auto broadcast_1 = op::broadcast(arg3, 1);
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment