Commit 689c22d8 authored by Scott Cyphers's avatar Scott Cyphers

Review comments

parent a136956b
......@@ -55,11 +55,9 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
const std::vector<Node::ptr>& arguments() const { return m_arguments; }
const std::multiset<Node*> users() const { return m_users; }
std::multiset<Node*> users() { return m_users; }
const std::multiset<Node*>& users() const { return m_users; }
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
......
......@@ -84,18 +84,18 @@ namespace ngraph
class Call : public Node
{
public:
std::shared_ptr<Op> op() const { return m_op; }
const Op& op() const { return m_op; }
Call(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& arguments)
Call(const Op& op, const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
, m_op(op)
{
}
virtual std::string description() const override { return m_op->name(); }
virtual std::string description() const override { return m_op.name(); }
protected:
std::shared_ptr<Op> m_op;
const Op& m_op;
};
/**
......@@ -129,7 +129,7 @@ namespace ngraph
virtual void propagate_types() override {}
protected:
BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args)
BuiltinCall(const Op& op, const std::vector<Node::ptr>& args)
: Call(op, args)
{
}
......@@ -144,7 +144,7 @@ namespace ngraph
}
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class AddCall : public BuiltinCall
......@@ -156,7 +156,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class BroadcastCall : public BuiltinCall
......@@ -181,7 +181,7 @@ namespace ngraph
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class CeilingCall : public BuiltinCall
......@@ -193,7 +193,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class DivideCall : public BuiltinCall
......@@ -205,7 +205,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class DotCall : public BuiltinCall
......@@ -219,7 +219,7 @@ namespace ngraph
virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class EqualCall : public BuiltinCall
......@@ -231,7 +231,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class ExponentialCall : public BuiltinCall
......@@ -243,7 +243,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class FloorCall : public BuiltinCall
......@@ -255,7 +255,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class GreaterCall : public BuiltinCall
......@@ -267,7 +267,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class LessCall : public BuiltinCall
......@@ -279,7 +279,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class LogCall : public BuiltinCall
......@@ -291,7 +291,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class MaximumCall : public BuiltinCall
......@@ -303,7 +303,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class MinimumCall : public BuiltinCall
......@@ -315,7 +315,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class MultiplyCall : public BuiltinCall
......@@ -327,7 +327,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class NegateCall : public BuiltinCall
......@@ -339,7 +339,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class PowerCall : public BuiltinCall
......@@ -351,7 +351,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class RemainderCall : public BuiltinCall
......@@ -363,7 +363,7 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class ReshapeCall : public BuiltinCall
......@@ -378,7 +378,7 @@ namespace ngraph
protected:
Shape m_shape;
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
class SubtractCall : public BuiltinCall
......@@ -390,6 +390,6 @@ namespace ngraph
}
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
static BuiltinOp s_op;
};
}
......@@ -65,7 +65,7 @@ namespace ngraph
}
const element::Type& element_type() const { return m_element_type; }
const Shape shape() const { return m_shape; }
const Shape& shape() const { return m_shape; }
protected:
const element::Type& m_element_type;
......
......@@ -19,21 +19,21 @@
using namespace ngraph;
using namespace std;
std::shared_ptr<BuiltinOp> AbsCall::s_op = make_shared<BuiltinOp>("abs");
BuiltinOp AbsCall::s_op = BuiltinOp("abs");
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsCall>(arg);
}
std::shared_ptr<BuiltinOp> AddCall::s_op = make_shared<BuiltinOp>("add");
BuiltinOp AddCall::s_op = BuiltinOp("add");
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<AddCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast");
BuiltinOp BroadcastCall::s_op = BuiltinOp("broadcast");
/**
** /param arg The tensor view to be broadcast.
......@@ -74,7 +74,7 @@ void BroadcastCall::propagate_types()
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
std::shared_ptr<BuiltinOp> CeilingCall::s_op = make_shared<BuiltinOp>("ceiling");
BuiltinOp CeilingCall::s_op = BuiltinOp("ceiling");
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
......@@ -86,14 +86,14 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert',
// 'convolution',
std::shared_ptr<BuiltinOp> DivideCall::s_op = make_shared<BuiltinOp>("divide");
BuiltinOp DivideCall::s_op = BuiltinOp("divide");
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DivideCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot");
BuiltinOp DotCall::s_op = BuiltinOp("dot");
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
......@@ -139,49 +139,49 @@ void DotCall::propagate_types()
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
std::shared_ptr<BuiltinOp> ExponentialCall::s_op = make_shared<BuiltinOp>("exponential");
BuiltinOp ExponentialCall::s_op = BuiltinOp("exponential");
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialCall>(arg0);
}
std::shared_ptr<BuiltinOp> FloorCall::s_op = make_shared<BuiltinOp>("floor");
BuiltinOp FloorCall::s_op = BuiltinOp("floor");
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<FloorCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> LogCall::s_op = make_shared<BuiltinOp>("log");
BuiltinOp LogCall::s_op = BuiltinOp("log");
Node::ptr ngraph::op::log(const Node::ptr& arg0)
{
return make_shared<LogCall>(arg0);
}
std::shared_ptr<BuiltinOp> MaximumCall::s_op = make_shared<BuiltinOp>("maximum");
BuiltinOp MaximumCall::s_op = BuiltinOp("maximum");
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MaximumCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> MinimumCall::s_op = make_shared<BuiltinOp>("minimum");
BuiltinOp MinimumCall::s_op = BuiltinOp("minimum");
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MinimumCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> MultiplyCall::s_op = make_shared<BuiltinOp>("multiply");
BuiltinOp MultiplyCall::s_op = BuiltinOp("multiply");
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MultiplyCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> NegateCall::s_op = make_shared<BuiltinOp>("negate");
BuiltinOp NegateCall::s_op = BuiltinOp("negate");
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{
......@@ -191,7 +191,7 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
// 'pad',
// 'parameter',
std::shared_ptr<BuiltinOp> PowerCall::s_op = make_shared<BuiltinOp>("power");
BuiltinOp PowerCall::s_op = BuiltinOp("power");
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
......@@ -200,14 +200,14 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
//'reduce',
std::shared_ptr<BuiltinOp> RemainderCall::s_op = make_shared<BuiltinOp>("remainder");
BuiltinOp RemainderCall::s_op = BuiltinOp("remainder");
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<RemainderCall>(arg0, arg1);
}
std::shared_ptr<BuiltinOp> ReshapeCall::s_op = make_shared<BuiltinOp>("reshape");
BuiltinOp ReshapeCall::s_op = BuiltinOp("reshape");
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{
......@@ -219,7 +219,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select',
//'slice',
std::shared_ptr<BuiltinOp> SubtractCall::s_op = make_shared<BuiltinOp>("subtract");
BuiltinOp SubtractCall::s_op = BuiltinOp("subtract");
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment