Commit 689c22d8 authored by Scott Cyphers's avatar Scott Cyphers

Review comments

parent a136956b
...@@ -55,11 +55,9 @@ namespace ngraph ...@@ -55,11 +55,9 @@ namespace ngraph
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
virtual void propagate_types() = 0; virtual void propagate_types() = 0;
const std::vector<Node::ptr> arguments() const { return m_arguments; } const std::vector<Node::ptr>& arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
const std::multiset<Node*> users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
std::multiset<Node*> users() { return m_users; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; } void name(const std::string& name) { m_name = name; }
......
...@@ -84,18 +84,18 @@ namespace ngraph ...@@ -84,18 +84,18 @@ namespace ngraph
class Call : public Node class Call : public Node
{ {
public: public:
std::shared_ptr<Op> op() const { return m_op; } const Op& op() const { return m_op; }
Call(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& arguments) Call(const Op& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
, m_op(op) , m_op(op)
{ {
} }
virtual std::string description() const override { return m_op->name(); } virtual std::string description() const override { return m_op.name(); }
protected: protected:
std::shared_ptr<Op> m_op; const Op& m_op;
}; };
/** /**
...@@ -129,7 +129,7 @@ namespace ngraph ...@@ -129,7 +129,7 @@ namespace ngraph
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args) BuiltinCall(const Op& op, const std::vector<Node::ptr>& args)
: Call(op, args) : Call(op, args)
{ {
} }
...@@ -144,7 +144,7 @@ namespace ngraph ...@@ -144,7 +144,7 @@ namespace ngraph
} }
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class AddCall : public BuiltinCall class AddCall : public BuiltinCall
...@@ -156,7 +156,7 @@ namespace ngraph ...@@ -156,7 +156,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class BroadcastCall : public BuiltinCall class BroadcastCall : public BuiltinCall
...@@ -181,7 +181,7 @@ namespace ngraph ...@@ -181,7 +181,7 @@ namespace ngraph
Shape m_shape; Shape m_shape;
std::vector<size_t> m_broadcast_axes; std::vector<size_t> m_broadcast_axes;
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class CeilingCall : public BuiltinCall class CeilingCall : public BuiltinCall
...@@ -193,7 +193,7 @@ namespace ngraph ...@@ -193,7 +193,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class DivideCall : public BuiltinCall class DivideCall : public BuiltinCall
...@@ -205,7 +205,7 @@ namespace ngraph ...@@ -205,7 +205,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class DotCall : public BuiltinCall class DotCall : public BuiltinCall
...@@ -219,7 +219,7 @@ namespace ngraph ...@@ -219,7 +219,7 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class EqualCall : public BuiltinCall class EqualCall : public BuiltinCall
...@@ -231,7 +231,7 @@ namespace ngraph ...@@ -231,7 +231,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class ExponentialCall : public BuiltinCall class ExponentialCall : public BuiltinCall
...@@ -243,7 +243,7 @@ namespace ngraph ...@@ -243,7 +243,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class FloorCall : public BuiltinCall class FloorCall : public BuiltinCall
...@@ -255,7 +255,7 @@ namespace ngraph ...@@ -255,7 +255,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class GreaterCall : public BuiltinCall class GreaterCall : public BuiltinCall
...@@ -267,7 +267,7 @@ namespace ngraph ...@@ -267,7 +267,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class LessCall : public BuiltinCall class LessCall : public BuiltinCall
...@@ -279,7 +279,7 @@ namespace ngraph ...@@ -279,7 +279,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class LogCall : public BuiltinCall class LogCall : public BuiltinCall
...@@ -291,7 +291,7 @@ namespace ngraph ...@@ -291,7 +291,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class MaximumCall : public BuiltinCall class MaximumCall : public BuiltinCall
...@@ -303,7 +303,7 @@ namespace ngraph ...@@ -303,7 +303,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class MinimumCall : public BuiltinCall class MinimumCall : public BuiltinCall
...@@ -315,7 +315,7 @@ namespace ngraph ...@@ -315,7 +315,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class MultiplyCall : public BuiltinCall class MultiplyCall : public BuiltinCall
...@@ -327,7 +327,7 @@ namespace ngraph ...@@ -327,7 +327,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class NegateCall : public BuiltinCall class NegateCall : public BuiltinCall
...@@ -339,7 +339,7 @@ namespace ngraph ...@@ -339,7 +339,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class PowerCall : public BuiltinCall class PowerCall : public BuiltinCall
...@@ -351,7 +351,7 @@ namespace ngraph ...@@ -351,7 +351,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class RemainderCall : public BuiltinCall class RemainderCall : public BuiltinCall
...@@ -363,7 +363,7 @@ namespace ngraph ...@@ -363,7 +363,7 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class ReshapeCall : public BuiltinCall class ReshapeCall : public BuiltinCall
...@@ -378,7 +378,7 @@ namespace ngraph ...@@ -378,7 +378,7 @@ namespace ngraph
protected: protected:
Shape m_shape; Shape m_shape;
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
class SubtractCall : public BuiltinCall class SubtractCall : public BuiltinCall
...@@ -390,6 +390,6 @@ namespace ngraph ...@@ -390,6 +390,6 @@ namespace ngraph
} }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
static std::shared_ptr<BuiltinOp> s_op; static BuiltinOp s_op;
}; };
} }
...@@ -65,7 +65,7 @@ namespace ngraph ...@@ -65,7 +65,7 @@ namespace ngraph
} }
const element::Type& element_type() const { return m_element_type; } const element::Type& element_type() const { return m_element_type; }
const Shape shape() const { return m_shape; } const Shape& shape() const { return m_shape; }
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
......
...@@ -19,21 +19,21 @@ ...@@ -19,21 +19,21 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
std::shared_ptr<BuiltinOp> AbsCall::s_op = make_shared<BuiltinOp>("abs"); BuiltinOp AbsCall::s_op = BuiltinOp("abs");
Node::ptr ngraph::op::abs(const Node::ptr& arg) Node::ptr ngraph::op::abs(const Node::ptr& arg)
{ {
return make_shared<AbsCall>(arg); return make_shared<AbsCall>(arg);
} }
std::shared_ptr<BuiltinOp> AddCall::s_op = make_shared<BuiltinOp>("add"); BuiltinOp AddCall::s_op = BuiltinOp("add");
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<AddCall>(arg0, arg1); return make_shared<AddCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast"); BuiltinOp BroadcastCall::s_op = BuiltinOp("broadcast");
/** /**
** /param arg The tensor view to be broadcast. ** /param arg The tensor view to be broadcast.
...@@ -74,7 +74,7 @@ void BroadcastCall::propagate_types() ...@@ -74,7 +74,7 @@ void BroadcastCall::propagate_types()
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape); m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
} }
std::shared_ptr<BuiltinOp> CeilingCall::s_op = make_shared<BuiltinOp>("ceiling"); BuiltinOp CeilingCall::s_op = BuiltinOp("ceiling");
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
...@@ -86,14 +86,14 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -86,14 +86,14 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert', // 'convert',
// 'convolution', // 'convolution',
std::shared_ptr<BuiltinOp> DivideCall::s_op = make_shared<BuiltinOp>("divide"); BuiltinOp DivideCall::s_op = BuiltinOp("divide");
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DivideCall>(arg0, arg1); return make_shared<DivideCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot"); BuiltinOp DotCall::s_op = BuiltinOp("dot");
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
...@@ -139,49 +139,49 @@ void DotCall::propagate_types() ...@@ -139,49 +139,49 @@ void DotCall::propagate_types()
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape); m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
} }
std::shared_ptr<BuiltinOp> ExponentialCall::s_op = make_shared<BuiltinOp>("exponential"); BuiltinOp ExponentialCall::s_op = BuiltinOp("exponential");
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{ {
return make_shared<ExponentialCall>(arg0); return make_shared<ExponentialCall>(arg0);
} }
std::shared_ptr<BuiltinOp> FloorCall::s_op = make_shared<BuiltinOp>("floor"); BuiltinOp FloorCall::s_op = BuiltinOp("floor");
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<FloorCall>(arg0, arg1); return make_shared<FloorCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> LogCall::s_op = make_shared<BuiltinOp>("log"); BuiltinOp LogCall::s_op = BuiltinOp("log");
Node::ptr ngraph::op::log(const Node::ptr& arg0) Node::ptr ngraph::op::log(const Node::ptr& arg0)
{ {
return make_shared<LogCall>(arg0); return make_shared<LogCall>(arg0);
} }
std::shared_ptr<BuiltinOp> MaximumCall::s_op = make_shared<BuiltinOp>("maximum"); BuiltinOp MaximumCall::s_op = BuiltinOp("maximum");
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MaximumCall>(arg0, arg1); return make_shared<MaximumCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> MinimumCall::s_op = make_shared<BuiltinOp>("minimum"); BuiltinOp MinimumCall::s_op = BuiltinOp("minimum");
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MinimumCall>(arg0, arg1); return make_shared<MinimumCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> MultiplyCall::s_op = make_shared<BuiltinOp>("multiply"); BuiltinOp MultiplyCall::s_op = BuiltinOp("multiply");
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MultiplyCall>(arg0, arg1); return make_shared<MultiplyCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> NegateCall::s_op = make_shared<BuiltinOp>("negate"); BuiltinOp NegateCall::s_op = BuiltinOp("negate");
Node::ptr ngraph::op::negate(const Node::ptr& arg0) Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{ {
...@@ -191,7 +191,7 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0) ...@@ -191,7 +191,7 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
// 'pad', // 'pad',
// 'parameter', // 'parameter',
std::shared_ptr<BuiltinOp> PowerCall::s_op = make_shared<BuiltinOp>("power"); BuiltinOp PowerCall::s_op = BuiltinOp("power");
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
...@@ -200,14 +200,14 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -200,14 +200,14 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
//'reduce', //'reduce',
std::shared_ptr<BuiltinOp> RemainderCall::s_op = make_shared<BuiltinOp>("remainder"); BuiltinOp RemainderCall::s_op = BuiltinOp("remainder");
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<RemainderCall>(arg0, arg1); return make_shared<RemainderCall>(arg0, arg1);
} }
std::shared_ptr<BuiltinOp> ReshapeCall::s_op = make_shared<BuiltinOp>("reshape"); BuiltinOp ReshapeCall::s_op = BuiltinOp("reshape");
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{ {
...@@ -219,7 +219,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) ...@@ -219,7 +219,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select', // 'select',
//'slice', //'slice',
std::shared_ptr<BuiltinOp> SubtractCall::s_op = make_shared<BuiltinOp>("subtract"); BuiltinOp SubtractCall::s_op = BuiltinOp("subtract");
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment