Commit 304f1219 authored by Scott Cyphers's avatar Scott Cyphers

style

parent 966d2cc2
...@@ -21,7 +21,6 @@ namespace ngraph ...@@ -21,7 +21,6 @@ namespace ngraph
/// Base error for ngraph runtime errors. /// Base error for ngraph runtime errors.
struct ngraph_error : std::runtime_error struct ngraph_error : std::runtime_error
{ {
explicit ngraph_error(const std::string& what_arg) explicit ngraph_error(const std::string& what_arg)
: std::runtime_error(what_arg) : std::runtime_error(what_arg)
{ {
......
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
namespace ngraph namespace ngraph
{ {
/** /**
** A user-defined function. ** A user-defined function.
**/ **/
class Function class Function
{ {
public: public:
Function(const Node::ptr& result, const std::vector<std::shared_ptr<Parameter>>& parameters); Function(const Node::ptr& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
Node::ptr result() { return m_result; } Node::ptr result() { return m_result; }
...@@ -37,14 +37,18 @@ namespace ngraph ...@@ -37,14 +37,18 @@ namespace ngraph
std::string name() const { return m_name; } std::string name() const { return m_name; }
protected: protected:
Node::ptr m_result; Node::ptr m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
namespace op namespace op
{ {
std::shared_ptr<Function> function(const Node::ptr& result, const std::initializer_list<std::shared_ptr<Parameter>>& parameters); std::shared_ptr<Function>
std::shared_ptr<Function> function(const Node::ptr& result, const std::vector<std::shared_ptr<Parameter>>& parameters); function(const Node::ptr& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function>
function(const Node::ptr& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
} }
} }
...@@ -69,7 +69,10 @@ namespace ngraph ...@@ -69,7 +69,10 @@ namespace ngraph
** will be used by the pattern matcher when comparing a pattern ** will be used by the pattern matcher when comparing a pattern
** graph against the graph. ** graph against the graph.
**/ **/
bool is_same_op_type(const Node::ptr& node) const { return typeid(*this) == typeid(*node.get()); } bool is_same_op_type(const Node::ptr& node) const
{
return typeid(*this) == typeid(*node.get());
}
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
......
...@@ -67,13 +67,12 @@ namespace ngraph ...@@ -67,13 +67,12 @@ namespace ngraph
/** /**
** Op nodes are nodes whose value is the result of some operation ** Op nodes are nodes whose value is the result of some operation
** applied to its arguments. For calls to user functions, the op will ** applied to its arguments. For calls to user functions, the op will
** reference the user function. ** reference the user function.
**/ **/
class Op : public Node class Op : public Node
{ {
public: public:
Op(const std::vector<Node::ptr>& arguments) Op(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
{ {
...@@ -86,8 +85,8 @@ namespace ngraph ...@@ -86,8 +85,8 @@ namespace ngraph
**/ **/
class FunctionOp : public Op class FunctionOp : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
Node::ptr m_function; Node::ptr m_function;
}; };
...@@ -102,7 +101,7 @@ namespace ngraph ...@@ -102,7 +101,7 @@ namespace ngraph
virtual std::string description() const override { return "BuiltinOp"; } virtual std::string description() const override { return "BuiltinOp"; }
/// Name of the builtin op, for debugging and logging. /// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0; virtual std::string op_name() const = 0;
// TODO: Implement for each op // TODO: Implement for each op
virtual void propagate_types() override {} virtual void propagate_types() override {}
...@@ -122,7 +121,7 @@ namespace ngraph ...@@ -122,7 +121,7 @@ namespace ngraph
} }
virtual std::string op_name() const override { return "abs"; } virtual std::string op_name() const override { return "abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class AddOp : public BuiltinOp class AddOp : public BuiltinOp
...@@ -133,7 +132,7 @@ namespace ngraph ...@@ -133,7 +132,7 @@ namespace ngraph
{ {
} }
virtual std::string op_name() const override { return "add"; } virtual std::string op_name() const override { return "add"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class BroadcastOp : public BuiltinOp class BroadcastOp : public BuiltinOp
...@@ -153,7 +152,7 @@ namespace ngraph ...@@ -153,7 +152,7 @@ namespace ngraph
} }
virtual std::string op_name() const override { return "broadcast"; } virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
...@@ -192,9 +191,9 @@ namespace ngraph ...@@ -192,9 +191,9 @@ namespace ngraph
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "dot"; } virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
class EqualOp : public BuiltinOp class EqualOp : public BuiltinOp
...@@ -216,7 +215,7 @@ namespace ngraph ...@@ -216,7 +215,7 @@ namespace ngraph
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
virtual std::string op_name() const override { return "exp"; } virtual std::string op_name() const override { return "exp"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -228,7 +227,7 @@ namespace ngraph ...@@ -228,7 +227,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "floor"; } virtual std::string op_name() const override { return "floor"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -252,7 +251,7 @@ namespace ngraph ...@@ -252,7 +251,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "less"; } virtual std::string op_name() const override { return "less"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -324,7 +323,7 @@ namespace ngraph ...@@ -324,7 +323,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "power"; } virtual std::string op_name() const override { return "power"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
......
...@@ -19,9 +19,8 @@ ...@@ -19,9 +19,8 @@
namespace ngraph namespace ngraph
{ {
class Function; class Function;
/** /**
** Parameters are nodes that represent the arguments that will be passed to user-defined functions. ** Parameters are nodes that represent the arguments that will be passed to user-defined functions.
** Function creation requires a sequence of parameters. ** Function creation requires a sequence of parameters.
...@@ -30,6 +29,7 @@ namespace ngraph ...@@ -30,6 +29,7 @@ namespace ngraph
class Parameter : public Node class Parameter : public Node
{ {
friend class Function; friend class Function;
protected: protected:
// Called by the Function constructor to associate this parameter with the function. // Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function. // It is an error to try to associate a parameter with more than one function.
...@@ -50,8 +50,9 @@ namespace ngraph ...@@ -50,8 +50,9 @@ namespace ngraph
namespace op namespace op
{ {
/// Factory for frameworks /// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type=nullptr); std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr);
/// Convenience factory for tests /// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type, const Shape& shape); std::shared_ptr<ngraph::Parameter> parameter(const ngraph::element::Type element_type,
const Shape& shape);
} }
} }
...@@ -63,7 +63,7 @@ namespace ngraph ...@@ -63,7 +63,7 @@ namespace ngraph
} }
const element::Type& element_type() const { return m_element_type; } const element::Type& element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; } const Shape& shape() const { return m_shape; }
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const Node::ptr& result, const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters) Function::Function(const Node::ptr& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
, m_name("Function") , m_name("Function")
...@@ -29,12 +30,14 @@ Function::Function(const Node::ptr& result, const std::vector<std::shared_ptr<ng ...@@ -29,12 +30,14 @@ Function::Function(const Node::ptr& result, const std::vector<std::shared_ptr<ng
} }
} }
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, const initializer_list<shared_ptr<Parameter>>& parameters) shared_ptr<Function> ngraph::op::function(const Node::ptr& result,
const initializer_list<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
} }
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, const vector<shared_ptr<Parameter>>& parameters) shared_ptr<Function> ngraph::op::function(const Node::ptr& result,
const vector<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
} }
...@@ -26,23 +26,23 @@ Parameter::Parameter(const ValueType::ptr& value_type) ...@@ -26,23 +26,23 @@ Parameter::Parameter(const ValueType::ptr& value_type)
void Parameter::assign_function(Function* function, size_t index) void Parameter::assign_function(Function* function, size_t index)
{ {
if (nullptr != m_function){ if (nullptr != m_function)
{
throw ngraph_error("Re-assigning function to a parameter."); throw ngraph_error("Re-assigning function to a parameter.");
} }
m_function = function; m_function = function;
m_index = index; m_index = index;
} }
void Parameter::propagate_types() void Parameter::propagate_types() {}
{
}
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type) shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type)
{ {
return make_shared<Parameter>(value_type); return make_shared<Parameter>(value_type);
} }
shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type, const Shape& shape) shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type,
const Shape& shape)
{ {
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape)); return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
} }
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list; std::map<std::string, ngraph::element::Type> ngraph::element::Type::m_element_list;
ngraph::element::Type::Type(size_t bitwidth, ngraph::element::Type::Type(size_t bitwidth,
bool is_float, bool is_float,
bool is_signed, bool is_signed,
const std::string& cname) const std::string& cname)
: m_bitwidth{bitwidth} : m_bitwidth{bitwidth}
, m_is_float{is_float} , m_is_float{is_float}
, m_is_signed{is_signed} , m_is_signed{is_signed}
......
...@@ -22,16 +22,17 @@ using namespace ngraph; ...@@ -22,16 +22,17 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = op::parameter(element::float32_t, {7, 3}); auto arg0 = op::parameter(element::float32_t, {7, 3});
auto arg1 = op::parameter(element::float32_t, {3}); auto arg1 = op::parameter(element::float32_t, {3});
auto arg2 = op::parameter(element::float32_t, {32, 7}); auto arg2 = op::parameter(element::float32_t, {32, 7});
auto arg3 = op::parameter(element::float32_t, {32, 7}); auto arg3 = op::parameter(element::float32_t, {32, 7});
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0}); auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto dot = op::dot(arg2, arg0); auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3}); auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->result(), dot); ASSERT_EQ(cluster_0->result(), dot);
} }
...@@ -40,14 +41,14 @@ TEST(build_graph, as_type) ...@@ -40,14 +41,14 @@ TEST(build_graph, as_type)
{ {
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5}); ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt); auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv); ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt); auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp); ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple. // Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt}); ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt); auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv); ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt); auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
ASSERT_EQ(tp_vt, tp_tp); ASSERT_EQ(tp_vt, tp_tp);
...@@ -63,17 +64,13 @@ TEST(build_graph, node_comparison) ...@@ -63,17 +64,13 @@ TEST(build_graph, node_comparison)
auto dot = op::dot(arg0, arg1); auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2); auto add = op::add(dot, arg2);
auto parg = op::parameter(element::float32_t, {}); auto parg = op::parameter(element::float32_t, {});
auto pattern_dot = op::dot(parg, parg); auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong. // Need to figure out what's wrong.
ASSERT_FALSE(pattern_dot->is_same_op_type(add)); ASSERT_FALSE(pattern_dot->is_same_op_type(add));
} }
// Check argument inverses // Check argument inverses
TEST(build_graph, arg_inverse) TEST(build_graph, arg_inverse) {}
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment