Commit 9d40c6b2 authored by Robert Kimball, committed by GitHub

Merge pull request #58 from NervanaSystems/cyphers/factory

Switch to factory functions, style on some more files.
parents b2f1ef60 b0c72a83
@@ -37,6 +37,7 @@ namespace nervana
         }
         constexpr const char* get_ptr(size_t offset) const { return &_string[offset]; }
         constexpr size_t size() const { return _size; }
+
     private:
         const char* _string;
         size_t _size;
@@ -44,8 +45,9 @@
     constexpr const char* find_last(conststring s, size_t offset, char ch)
     {
-        return offset == 0 ? s.get_ptr(0) : (s[offset] == ch ? s.get_ptr(offset + 1)
-                                                             : find_last(s, offset - 1, ch));
+        return offset == 0
+            ? s.get_ptr(0)
+            : (s[offset] == ch ? s.get_ptr(offset + 1) : find_last(s, offset - 1, ch));
     }

     constexpr const char* find_last(conststring s, char ch)
@@ -67,6 +69,7 @@
         ~log_helper();
         std::ostream& stream() { return _stream; }
+
     private:
         std::stringstream _stream;
     };
......
@@ -33,6 +33,8 @@ namespace ngraph
         Parameter(Function& function, size_t index);

+        std::string description() const override { return "Parameter"; }
+
     protected:
         Function& m_function;
         size_t m_index;
@@ -66,8 +68,11 @@
         Parameter::ptr parameter(size_t i) { return m_parameters[i]; }

+        std::string name() const override { return m_name; }
+
     protected:
         std::vector<Parameter::ptr> m_parameters;
         Result m_result;
+        std::string m_name;
     };
 }
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
#pragma once #pragma once
#include <vector>
#include <set> #include <set>
#include <string>
#include <vector>
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
@@ -31,27 +32,41 @@ namespace ngraph
     class Node : public TypedValueMixin
     {
     public:
         using ptr = std::shared_ptr<Node>;

+    protected:
         Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type = nullptr)
             : TypedValueMixin(type)
             , m_arguments(arguments)
         {
             // Add this node as a user of each argument.
-            for(auto node : m_arguments){
+            for (auto node : m_arguments)
+            {
                 node->m_users.insert(node.get());
             }
         }

+        virtual ~Node() {}
+
+    public:
+        /// A "one-liner" describing this node.
+        virtual std::string description() const = 0;
+
+        /// Propagate types and check arguments for consistency
+        // virtual void propagate_types() = 0;
+
         const std::vector<Node::ptr> arguments() const { return m_arguments; }
         std::vector<Node::ptr> arguments() { return m_arguments; }

         const std::multiset<Node*> users() const { return m_users; }
         std::multiset<Node*> users() { return m_users; }

+        std::string name() const { return m_name; }
+        void name(const std::string& name) { m_name = name; }
+
     protected:
         std::vector<Node::ptr> m_arguments;
         std::multiset<Node*> m_users;
+        std::string m_name;
     };
 }
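Aside (not part of the diff): with the constructor now protected and description() pure virtual, Node is only usable through subclasses such as Parameter and Call elsewhere in this change. A minimal sketch of what a concrete subclass must provide under the new interface; the ConstantStub name is hypothetical and used only for illustration.

// Hypothetical leaf node with no arguments, shown only to illustrate the new Node interface.
class ConstantStub : public Node
{
public:
    ConstantStub(ValueType::ptr type = nullptr)
        : Node({}, type) // no arguments, so no users are registered
    {
    }

    // Required: every concrete node supplies a one-line description.
    std::string description() const override { return "ConstantStub"; }
};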
@@ -21,12 +21,22 @@
 namespace ngraph
 {
-    class Op;
+    /**
+     ** Every instance of Op corresponds to a unique defined operation.
+     **/
+    class Op
+    {
+    protected:
+        virtual ~Op() {}
+
+    public:
+        virtual std::string name() const = 0;
+    };

     /**
      ** Call nodes are nodes whose value is the result of some operation, the op,
      ** applied to its arguments. We use the op as a callable to construct the
-     ** call nodes.
+     ** call nodes. For calls to user functions, the op will be the user function.
      **/
     class Call : public Node
     {
@@ -39,58 +49,90 @@ namespace ngraph
         {
         }

+        virtual std::string description() const override { return m_op->name(); }
+
     protected:
         std::shared_ptr<Op> m_op;
     };

     /**
-     ** The Op class provides the behavior for a Call.
+     ** There is exactly one instance of builtin op for each pre-defined operation. These
+     ** are intended to be used when matching calls in different graphs; every FooCall
+     ** will have the same op.
      **/
-    class Op
-    {
-    };
-
-    class Broadcast : public Op, public std::enable_shared_from_this<Broadcast>
-    {
-    protected:
-        class BroadcastCall : public Call
-        {
-            friend class Broadcast;
-
-        public:
-            BroadcastCall(const std::shared_ptr<Op>& op, const Node::ptr& arg, size_t axis)
-                : Call(op, {arg})
-                , m_axis(axis)
-            {
-            }
-
-        protected:
-            size_t m_axis;
-        };
-
-    public:
-        std::shared_ptr<BroadcastCall> operator()(const Node::ptr& tensor, size_t axis)
-        {
-            return std::make_shared<BroadcastCall>(shared_from_this(), tensor, axis);
-        }
-    };
+    class BuiltinOp : public Op
+    {
+        friend class Call;
+
+    public:
+        BuiltinOp(const std::string& name)
+            : m_name(name)
+        {
+        }
+
+    public:
+        std::string name() const override { return m_name; }
+
+    protected:
+        std::string m_name;
+    };
+
+    class BuiltinCall : public Call
+    {
+    public:
+        virtual std::string description() const override { return "BuiltinCall"; }
+
+    protected:
+        BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args)
+            : Call(op, args)
+        {
+        }
+    };

     namespace op
     {
-        extern decltype(*std::shared_ptr<Broadcast>()) broadcast;
+        std::shared_ptr<Node> broadcast(const Node::ptr& tensor,
+                                        const Shape& shape,
+                                        const std::vector<size_t>& broadcast_axes);
     }

-    class Dot : public Op, public std::enable_shared_from_this<Dot>
-    {
-    public:
-        Call::ptr operator()(const Node::ptr& arg0, const Node::ptr& arg1)
-        {
-            return std::make_shared<Call>(shared_from_this(), std::vector<Node::ptr>{arg0, arg1});
-        }
-    };
+    class BroadcastCall : public BuiltinCall
+    {
+    public:
+        /**
+         ** /param arg The tensor view to be broadcast.
+         ** /param shape The shape of the result
+         ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
+         ** the remaining axes in shape must be the same as the shape of arg.
+         **/
+        BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
+            : BuiltinCall(s_op, {arg})
+            , m_shape(shape)
+            , m_broadcast_axes(broadcast_axes)
+        {
+        }
+
+        Shape m_shape;
+        std::vector<size_t> m_broadcast_axes;
+
+    protected:
+        static std::shared_ptr<BuiltinOp> s_op;
+    };

     namespace op
     {
-        extern decltype(*std::shared_ptr<Dot>()) dot;
+        std::shared_ptr<Node> dot(const Node::ptr& arg0, const Node::ptr& arg1);
     }
+
+    class DotCall : public BuiltinCall
+    {
+    public:
+        /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
+        DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
+            : BuiltinCall(s_op, {arg0, arg1})
+        {
+        }
+
+    protected:
+        static std::shared_ptr<BuiltinOp> s_op;
+    };
 }
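Aside (not part of the diff): the net effect of this header change on graph-building code is that ops are no longer singleton callable objects invoked through operator(); callers now go through plain factory functions in namespace op that return std::shared_ptr<Node>. A rough before/after sketch, assuming two pre-existing Node::ptr values a and b that are not part of this commit:

// Illustrative caller-side comparison only; a and b are assumed Node::ptr values.

// Before this commit: op objects were invoked directly.
//   auto bc = op::broadcast(a, 1);   // Broadcast::operator()(tensor, axis)
//   auto dp = op::dot(a, b);         // Dot::operator()(arg0, arg1)

// After this commit: factory functions construct the call nodes.
auto bc = op::broadcast(a, {10, 32, 7}, {0}); // result shape + broadcast axes
auto dp = op::dot(a, b);                      // returns std::shared_ptr<Node>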
@@ -26,6 +26,7 @@ Parameter::Parameter(Function& function, size_t index)

 Function::Function(size_t n_parameters)
     : m_parameters(n_parameters)
+    , m_name("Function")
 {
     for (int i = 0; i < n_parameters; i++)
     {
......
@@ -15,7 +15,27 @@
 #include "ngraph/ngraph.hpp"

 using namespace ngraph;
+using namespace std;

-decltype(*std::shared_ptr<Broadcast>()) ngraph::op::broadcast = *std::make_shared<Broadcast>();
-decltype(*std::shared_ptr<Dot>()) ngraph::op::dot = *std::make_shared<Dot>();
+std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast");
+
+/**
+ ** /param arg The tensor view to be broadcast.
+ ** /param shape The shape of the result
+ ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
+ ** the remaining axes in shape must be the same as the shape of arg.
+ **/
+shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor,
+                                       const Shape& shape,
+                                       const vector<size_t>& broadcast_axes)
+{
+    return make_shared<BroadcastCall>(tensor, shape, broadcast_axes);
+}
+
+std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot");
+
+/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
+shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
+{
+    return make_shared<DotCall>(arg0, arg1);
+}
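Aside (not part of the diff): each builtin follows the same recipe across the header and source hunks above: a *Call class holding a static s_op shared by every instance (so calls can be matched across graphs by op identity, per the BuiltinOp comment), plus a free factory function in namespace op. A sketch of what a further builtin could look like under this pattern; AddCall, the "add" name, and op::add are hypothetical and not introduced by this commit.

// Header side (sketch):
class AddCall : public BuiltinCall
{
public:
    AddCall(const Node::ptr& arg0, const Node::ptr& arg1)
        : BuiltinCall(s_op, {arg0, arg1})
    {
    }

protected:
    static std::shared_ptr<BuiltinOp> s_op; // one shared op instance for all AddCall nodes
};

namespace op
{
    std::shared_ptr<Node> add(const Node::ptr& arg0, const Node::ptr& arg1);
}

// Source side (sketch):
std::shared_ptr<BuiltinOp> AddCall::s_op = make_shared<BuiltinOp>("add");

shared_ptr<Node> ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
    return make_shared<AddCall>(arg0, arg1);
}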
@@ -148,6 +148,7 @@ namespace ngraph
         size_t get_total_milliseconds() const { return get_total_nanoseconds() / 1e6; }
         size_t get_total_microseconds() const { return get_total_nanoseconds() / 1e3; }
         size_t get_total_nanoseconds() const { return m_total_time.count(); }
+
     private:
         std::chrono::high_resolution_clock m_clock;
         std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
......
@@ -29,8 +29,8 @@ TEST(DISABLED_graph, build_simple)
     cluster_0->parameter(2)->type(element::float32_t, {32, 7});
     cluster_0->parameter(3)->type(element::float32_t, {32, 7});
     auto arg3 = cluster_0->parameter(3);
-    // call broadcast op on arg3, broadcasting on axis 1.
-    auto broadcast_1 = op::broadcast(arg3, 1);
+    // call broadcast op on arg3, broadcasting on axis 0.
+    auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
     auto arg2 = cluster_0->parameter(2);
     auto arg0 = cluster_0->parameter(0);
     // call dot op
......