Commit 235c8ea0 authored by Scott Cyphers's avatar Scott Cyphers

Remove ::ptr

parent 8b44c451
...@@ -25,14 +25,14 @@ namespace ngraph ...@@ -25,14 +25,14 @@ namespace ngraph
class Function class Function
{ {
public: public:
Function(const Node::ptr& result, Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
Node::ptr result() { return m_result; } std::shared_ptr<Node> result() { return m_result; }
Parameter::ptr parameter(size_t i) { return m_parameters[i]; } std::shared_ptr<Parameter> parameter(size_t i) { return m_parameters[i]; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
protected: protected:
Node::ptr m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
...@@ -40,10 +40,10 @@ namespace ngraph ...@@ -40,10 +40,10 @@ namespace ngraph
namespace op namespace op
{ {
std::shared_ptr<Function> std::shared_ptr<Function>
function(const Node::ptr& result, function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters); const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function> std::shared_ptr<Function>
function(const Node::ptr& result, function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<Parameter>>& parameters);
} }
} }
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
size_t ngraph::Node::m_next_instance_id = 0; size_t ngraph::Node::m_next_instance_id = 0;
ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type) ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments, std::shared_ptr<ValueType> type)
: TypedValueMixin(type) : TypedValueMixin(type)
, m_arguments(arguments) , m_arguments(arguments)
, m_instance_id(m_next_instance_id++) , m_instance_id(m_next_instance_id++)
......
...@@ -32,11 +32,9 @@ namespace ngraph ...@@ -32,11 +32,9 @@ namespace ngraph
/// view or a (possibly empty) tuple of values. /// view or a (possibly empty) tuple of values.
class Node : public TypedValueMixin class Node : public TypedValueMixin
{ {
public:
using ptr = std::shared_ptr<Node>;
protected: protected:
Node(const Nodes& arguments, ValueType::ptr type = nullptr); Node(const Nodes& arguments, std::shared_ptr<ValueType> type = nullptr);
virtual ~Node() {} virtual ~Node() {}
public: public:
...@@ -58,7 +56,7 @@ namespace ngraph ...@@ -58,7 +56,7 @@ namespace ngraph
/// Return true if this has the same implementing class as node. This /// Return true if this has the same implementing class as node. This
/// will be used by the pattern matcher when comparing a pattern /// will be used by the pattern matcher when comparing a pattern
/// graph against the graph. /// graph against the graph.
bool is_same_op_type(const Node::ptr& node) const bool is_same_op_type(const std::shared_ptr<Node>& node) const
{ {
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
} }
...@@ -76,6 +74,4 @@ namespace ngraph ...@@ -76,6 +74,4 @@ namespace ngraph
size_t m_instance_id; size_t m_instance_id;
static size_t m_next_instance_id; static size_t m_next_instance_id;
}; };
using node_ptr = std::shared_ptr<Node>;
} }
...@@ -25,40 +25,40 @@ namespace ngraph ...@@ -25,40 +25,40 @@ namespace ngraph
namespace op namespace op
{ {
Node::ptr abs(const Node::ptr& arg); std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//Node::ptr convert(); //std::shared_ptr<Node> convert();
//Node::ptr convolution(); //std::shared_ptr<Node> convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr exp(const Node::ptr& arg0); std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//Node::ptr get_tuple_element(); //std::shared_ptr<Node> get_tuple_element();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//Node::ptr greater_equal(const Node::ptr& arg0, const Node::ptr& arg1); //std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//Node::ptr less_equal(const Node::ptr& arg0, const Node::ptr& arg1); //std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr log(const Node::ptr& arg0); std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//Node::ptr logical(); and, or, not //std::shared_ptr<Node> logical(); and, or, not
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr negative(const Node::ptr& arg0); std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//Node::ptr pad(); //std::shared_ptr<Node> pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//Node::ptr reduce(); //std::shared_ptr<Node> reduce();
// Node::ptr reduce_window(); // std::shared_ptr<Node> reduce_window();
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape); std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//Node::ptr reverse(); //std::shared_ptr<Node> reverse();
//Node::ptr rng(); //std::shared_ptr<Node> rng();
//Node::ptr select(); //std::shared_ptr<Node> select();
//Node::ptr select_scatter(); //std::shared_ptr<Node> select_scatter();
//Node::ptr slice(); //std::shared_ptr<Node> slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
//Node::ptr transpose(); //std::shared_ptr<Node> transpose();
//Node::ptr while(); //std::shared_ptr<Node> while();
} }
/// Op nodes are nodes whose value is the result of some operation /// Op nodes are nodes whose value is the result of some operation
...@@ -67,7 +67,7 @@ namespace ngraph ...@@ -67,7 +67,7 @@ namespace ngraph
class Op : public Node class Op : public Node
{ {
public: public:
Op(const std::vector<Node::ptr>& arguments) Op(const std::vector<std::shared_ptr<Node>>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
{ {
} }
...@@ -82,7 +82,7 @@ namespace ngraph ...@@ -82,7 +82,7 @@ namespace ngraph
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionOp"; }
protected: protected:
Node::ptr m_function; std::shared_ptr<Node> m_function;
}; };
/// The is an operation we handle directly, i.e. all type checking, etc. /// The is an operation we handle directly, i.e. all type checking, etc.
...@@ -96,7 +96,7 @@ namespace ngraph ...@@ -96,7 +96,7 @@ namespace ngraph
// TODO: Implement for each op. This enables graphs to be built for now. // TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinOp(const std::vector<Node::ptr>& args) BuiltinOp(const std::vector<std::shared_ptr<Node>>& args)
: Op(args) : Op(args)
{ {
} }
...@@ -105,7 +105,7 @@ namespace ngraph ...@@ -105,7 +105,7 @@ namespace ngraph
class AbsOp : public BuiltinOp class AbsOp : public BuiltinOp
{ {
public: public:
AbsOp(const Node::ptr& arg0) AbsOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -117,7 +117,7 @@ namespace ngraph ...@@ -117,7 +117,7 @@ namespace ngraph
class AddOp : public BuiltinOp class AddOp : public BuiltinOp
{ {
public: public:
AddOp(const Node::ptr& arg0, const Node::ptr& arg1) AddOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -128,7 +128,7 @@ namespace ngraph ...@@ -128,7 +128,7 @@ namespace ngraph
class CeilingOp : public BuiltinOp class CeilingOp : public BuiltinOp
{ {
public: public:
CeilingOp(const Node::ptr& arg0, const Node::ptr& arg1) CeilingOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -140,7 +140,7 @@ namespace ngraph ...@@ -140,7 +140,7 @@ namespace ngraph
class DivideOp : public BuiltinOp class DivideOp : public BuiltinOp
{ {
public: public:
DivideOp(const Node::ptr& arg0, const Node::ptr& arg1) DivideOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -152,7 +152,7 @@ namespace ngraph ...@@ -152,7 +152,7 @@ namespace ngraph
class EqualOp : public BuiltinOp class EqualOp : public BuiltinOp
{ {
public: public:
EqualOp(const Node::ptr& arg0, const Node::ptr& arg1) EqualOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -164,7 +164,7 @@ namespace ngraph ...@@ -164,7 +164,7 @@ namespace ngraph
class ExpOp : public BuiltinOp class ExpOp : public BuiltinOp
{ {
public: public:
ExpOp(const Node::ptr& arg0) ExpOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -176,7 +176,7 @@ namespace ngraph ...@@ -176,7 +176,7 @@ namespace ngraph
class FloorOp : public BuiltinOp class FloorOp : public BuiltinOp
{ {
public: public:
FloorOp(const Node::ptr& arg0, const Node::ptr& arg1) FloorOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -188,7 +188,7 @@ namespace ngraph ...@@ -188,7 +188,7 @@ namespace ngraph
class GreaterOp : public BuiltinOp class GreaterOp : public BuiltinOp
{ {
public: public:
GreaterOp(const Node::ptr& arg0, const Node::ptr& arg1) GreaterOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -200,7 +200,7 @@ namespace ngraph ...@@ -200,7 +200,7 @@ namespace ngraph
class LessOp : public BuiltinOp class LessOp : public BuiltinOp
{ {
public: public:
LessOp(const Node::ptr& arg0, const Node::ptr& arg1) LessOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -212,7 +212,7 @@ namespace ngraph ...@@ -212,7 +212,7 @@ namespace ngraph
class LogOp : public BuiltinOp class LogOp : public BuiltinOp
{ {
public: public:
LogOp(const Node::ptr& arg0) LogOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -224,7 +224,7 @@ namespace ngraph ...@@ -224,7 +224,7 @@ namespace ngraph
class MaximumOp : public BuiltinOp class MaximumOp : public BuiltinOp
{ {
public: public:
MaximumOp(const Node::ptr& arg0, const Node::ptr& arg1) MaximumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -236,7 +236,7 @@ namespace ngraph ...@@ -236,7 +236,7 @@ namespace ngraph
class MinimumOp : public BuiltinOp class MinimumOp : public BuiltinOp
{ {
public: public:
MinimumOp(const Node::ptr& arg0, const Node::ptr& arg1) MinimumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -248,7 +248,7 @@ namespace ngraph ...@@ -248,7 +248,7 @@ namespace ngraph
class MultiplyOp : public BuiltinOp class MultiplyOp : public BuiltinOp
{ {
public: public:
MultiplyOp(const Node::ptr& arg0, const Node::ptr& arg1) MultiplyOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -260,7 +260,7 @@ namespace ngraph ...@@ -260,7 +260,7 @@ namespace ngraph
class NegativeOp : public BuiltinOp class NegativeOp : public BuiltinOp
{ {
public: public:
NegativeOp(const Node::ptr& arg0) NegativeOp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -272,7 +272,7 @@ namespace ngraph ...@@ -272,7 +272,7 @@ namespace ngraph
class PowerOp : public BuiltinOp class PowerOp : public BuiltinOp
{ {
public: public:
PowerOp(const Node::ptr& arg0, const Node::ptr& arg1) PowerOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -284,7 +284,7 @@ namespace ngraph ...@@ -284,7 +284,7 @@ namespace ngraph
class RemainderOp : public BuiltinOp class RemainderOp : public BuiltinOp
{ {
public: public:
RemainderOp(const Node::ptr& arg0, const Node::ptr& arg1) RemainderOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -296,7 +296,7 @@ namespace ngraph ...@@ -296,7 +296,7 @@ namespace ngraph
class ReshapeOp : public BuiltinOp class ReshapeOp : public BuiltinOp
{ {
public: public:
ReshapeOp(const Node::ptr& arg0, const Shape& shape) ReshapeOp(const std::shared_ptr<Node>& arg0, const Shape& shape)
: BuiltinOp({arg0}) : BuiltinOp({arg0})
, m_shape(shape) , m_shape(shape)
{ {
...@@ -311,7 +311,7 @@ namespace ngraph ...@@ -311,7 +311,7 @@ namespace ngraph
class SubtractOp : public BuiltinOp class SubtractOp : public BuiltinOp
{ {
public: public:
SubtractOp(const Node::ptr& arg0, const Node::ptr& arg1) SubtractOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
/// ///
BroadcastOp(const Node::ptr& arg, const Shape& shape, const AxisSet& broadcast_axes) BroadcastOp(const std::shared_ptr<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
namespace op namespace op
{ {
Node::ptr broadcast(const Node::ptr& tensor, std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape, const Shape& shape,
AxisSet&& broadcast_axes); AxisSet&& broadcast_axes);
} }
......
...@@ -18,7 +18,7 @@ namespace ngraph ...@@ -18,7 +18,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
Node::ptr concatenate(const Nodes& args); std::shared_ptr<Node> concatenate(const Nodes& args);
} }
class ConcatOp : public BuiltinOp class ConcatOp : public BuiltinOp
......
...@@ -20,7 +20,7 @@ namespace ngraph ...@@ -20,7 +20,7 @@ namespace ngraph
class ConvertOp : public BuiltinOp class ConvertOp : public BuiltinOp
{ {
public: public:
ConvertOp(const Node::ptr& arg, const ngraph::element::Type& element_type) ConvertOp(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg}) : BuiltinOp({arg})
, m_element_type(element_type) , m_element_type(element_type)
{ {
...@@ -35,6 +35,6 @@ namespace ngraph ...@@ -35,6 +35,6 @@ namespace ngraph
namespace op namespace op
{ {
std::shared_ptr<ngraph::ConvertOp> convert(const Node::ptr& arg, const ngraph::element::Type& element_type); std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
} }
} }
...@@ -20,7 +20,7 @@ namespace ngraph ...@@ -20,7 +20,7 @@ namespace ngraph
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const Node::ptr& arg0, const Node::ptr& arg1) DotOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -31,6 +31,6 @@ namespace ngraph ...@@ -31,6 +31,6 @@ namespace ngraph
namespace op namespace op
{ {
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1); std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
} }
} }
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
void assign_function(Function* function, size_t index); void assign_function(Function* function, size_t index);
public: public:
Parameter(const ValueType::ptr& value_type); Parameter(const std::shared_ptr<ValueType>& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape); Parameter(const ngraph::element::Type element_type, const Shape& shape);
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
...@@ -51,7 +51,7 @@ namespace ngraph ...@@ -51,7 +51,7 @@ namespace ngraph
namespace op namespace op
{ {
/// Factory for frameworks /// Factory for frameworks
std::shared_ptr<ngraph::Parameter> parameter(const ValueType::ptr& value_type = nullptr); std::shared_ptr<ngraph::Parameter> parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests /// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type, std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape); const Shape& shape);
......
...@@ -18,7 +18,7 @@ namespace ngraph ...@@ -18,7 +18,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
Node::ptr tuple(const Nodes& args); std::shared_ptr<Node> tuple(const Nodes& args);
} }
class TupleOp : public BuiltinOp class TupleOp : public BuiltinOp
......
...@@ -31,23 +31,15 @@ namespace ngraph ...@@ -31,23 +31,15 @@ namespace ngraph
class ValueType class ValueType
{ {
public: public:
/**
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {} virtual ~ValueType() {}
virtual bool operator==(const ValueType::ptr& that) const = 0; virtual bool operator==(const std::shared_ptr<ValueType>& that) const = 0;
bool operator!=(const ValueType::ptr& that) const { return !(*this == that); } bool operator!=(const std::shared_ptr<ValueType>& that) const { return !(*this == that); }
}; };
/// Describes a tensor view; an element type and a shape. /// Describes a tensor view; an element type and a shape.
class TensorViewType : public ValueType class TensorViewType : public ValueType
{ {
public: public:
// Preferred handle
using ptr = std::shared_ptr<TensorViewType>;
/// /param element_type The type of the tensor elements. /// /param element_type The type of the tensor elements.
/// /param shape The shape of the tensor. /// /param shape The shape of the tensor.
TensorViewType(const element::Type& element_type, const Shape& shape) TensorViewType(const element::Type& element_type, const Shape& shape)
...@@ -59,7 +51,7 @@ namespace ngraph ...@@ -59,7 +51,7 @@ namespace ngraph
const element::Type& get_element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override; virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
protected: protected:
const element::Type& m_element_type; const element::Type& m_element_type;
...@@ -70,24 +62,22 @@ namespace ngraph ...@@ -70,24 +62,22 @@ namespace ngraph
class TupleType : public ValueType class TupleType : public ValueType
{ {
public: public:
using ptr = std::shared_ptr<ValueType>;
/// Construct empty tuple and add value types later. /// Construct empty tuple and add value types later.
TupleType() {} TupleType() {}
/// @param element_types A vector of types for the tuple elements /// @param element_types A vector of types for the tuple elements
TupleType(const std::vector<ValueType::ptr>& element_types) TupleType(const std::vector<std::shared_ptr<ValueType>>& element_types)
: m_element_types(element_types) : m_element_types(element_types)
{ {
} }
const std::vector<ValueType::ptr> get_element_types() const { return m_element_types; } const std::vector<std::shared_ptr<ValueType>> get_element_types() const { return m_element_types; }
std::vector<ValueType::ptr> set_element_types() { return m_element_types; } std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override; virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
protected: protected:
std::vector<ValueType::ptr> m_element_types; std::vector<std::shared_ptr<ValueType>> m_element_types;
}; };
/** /**
...@@ -96,7 +86,7 @@ namespace ngraph ...@@ -96,7 +86,7 @@ namespace ngraph
class TypedValueMixin class TypedValueMixin
{ {
public: public:
TypedValueMixin(const ValueType::ptr& value_type = nullptr) TypedValueMixin(const std::shared_ptr<ValueType>& value_type = nullptr)
: m_value_type(value_type) : m_value_type(value_type)
{ {
} }
...@@ -105,7 +95,7 @@ namespace ngraph ...@@ -105,7 +95,7 @@ namespace ngraph
** Set the type ** Set the type
** /param type The new type ** /param type The new type
**/ **/
void set_value_type(const ValueType::ptr& value_type) { m_value_type = value_type; } void set_value_type(const std::shared_ptr<ValueType>& value_type) { m_value_type = value_type; }
/** /**
** Set the type to be a tensor view type ** Set the type to be a tensor view type
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
...@@ -119,12 +109,12 @@ namespace ngraph ...@@ -119,12 +109,12 @@ namespace ngraph
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
ValueType::ptr get_value_type() { return m_value_type; } std::shared_ptr<ValueType> get_value_type() { return m_value_type; }
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
const ValueType::ptr get_value_type() const { return m_value_type; } const std::shared_ptr<ValueType> get_value_type() const { return m_value_type; }
protected: protected:
ValueType::ptr m_value_type; std::shared_ptr<ValueType> m_value_type;
}; };
} }
...@@ -21,7 +21,7 @@ using namespace ngraph; ...@@ -21,7 +21,7 @@ using namespace ngraph;
/// @param shape The shape of the result /// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor, std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape, const Shape& shape,
AxisSet&& broadcast_axes) AxisSet&& broadcast_axes)
{ {
......
...@@ -24,7 +24,7 @@ void ConcatOp::propagate_types() ...@@ -24,7 +24,7 @@ void ConcatOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
Node::ptr op::concatenate(const std::vector<Node::ptr>& args) std::shared_ptr<Node> op::concatenate(const std::vector<std::shared_ptr<Node>>& args)
{ {
return make_shared<ConcatOp>(args); return make_shared<ConcatOp>(args);
} }
...@@ -24,7 +24,7 @@ void ConvertOp::propagate_types() ...@@ -24,7 +24,7 @@ void ConvertOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
shared_ptr<ConvertOp> op::convert(const Node::ptr& arg, const element::Type& element_type) shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg, const element::Type& element_type)
{ {
return make_shared<ConvertOp>(arg, element_type); return make_shared<ConvertOp>(arg, element_type);
} }
...@@ -20,7 +20,7 @@ using namespace std; ...@@ -20,7 +20,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DotOp>(arg0, arg1); return make_shared<DotOp>(arg0, arg1);
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const Node::ptr& result, Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters) const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
...@@ -30,13 +30,13 @@ Function::Function(const Node::ptr& result ...@@ -30,13 +30,13 @@ Function::Function(const Node::ptr& result
} }
} }
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters) const initializer_list<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
} }
shared_ptr<Function> ngraph::op::function(const Node::ptr& result, shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters) const vector<shared_ptr<Parameter>>& parameters)
{ {
return make_shared<Function>(result, parameters); return make_shared<Function>(result, parameters);
......
...@@ -27,17 +27,17 @@ std::string ngraph::Op::get_node_id() const ...@@ -27,17 +27,17 @@ std::string ngraph::Op::get_node_id() const
return ss.str(); return ss.str();
} }
Node::ptr ngraph::op::abs(const Node::ptr& arg) std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
{ {
return make_shared<AbsOp>(arg); return make_shared<AbsOp>(arg);
} }
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<AddOp>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<CeilingOp>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
...@@ -45,61 +45,61 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -45,61 +45,61 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert', // 'convert',
// 'convolution', // 'convolution',
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<DivideOp>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
Node::ptr ngraph::op::exp(const Node::ptr& arg0) std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
{ {
return make_shared<ExpOp>(arg0); return make_shared<ExpOp>(arg0);
} }
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<FloorOp>(arg0, arg1); return make_shared<FloorOp>(arg0, arg1);
} }
Node::ptr ngraph::op::log(const Node::ptr& arg0) std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
{ {
return make_shared<LogOp>(arg0); return make_shared<LogOp>(arg0);
} }
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MaximumOp>(arg0, arg1); return make_shared<MaximumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MinimumOp>(arg0, arg1); return make_shared<MinimumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<MultiplyOp>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
Node::ptr ngraph::op::negative(const Node::ptr& arg0) std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
{ {
return make_shared<NegativeOp>(arg0); return make_shared<NegativeOp>(arg0);
} }
// 'pad', // 'pad',
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<PowerOp>(arg0, arg1); return make_shared<PowerOp>(arg0, arg1);
} }
//'reduce', //'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<RemainderOp>(arg0, arg1); return make_shared<RemainderOp>(arg0, arg1);
} }
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
{ {
return make_shared<ReshapeOp>(arg0, shape); return make_shared<ReshapeOp>(arg0, shape);
} }
...@@ -109,7 +109,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) ...@@ -109,7 +109,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select', // 'select',
//'slice', //'slice',
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
{ {
return make_shared<SubtractOp>(arg0, arg1); return make_shared<SubtractOp>(arg0, arg1);
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
Parameter::Parameter(const ValueType::ptr& value_type) Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node({}, value_type) : Node({}, value_type)
, m_function(nullptr) , m_function(nullptr)
, m_index(0) , m_index(0)
...@@ -45,7 +45,7 @@ void Parameter::propagate_types() ...@@ -45,7 +45,7 @@ void Parameter::propagate_types()
{ {
} }
shared_ptr<Parameter> ngraph::op::parameter(const ValueType::ptr& value_type) shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{ {
return make_shared<Parameter>(value_type); return make_shared<Parameter>(value_type);
} }
......
...@@ -24,7 +24,7 @@ void TupleOp::propagate_types() ...@@ -24,7 +24,7 @@ void TupleOp::propagate_types()
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
Node::ptr op::tuple(const std::vector<Node::ptr>& args) std::shared_ptr<Node> op::tuple(const std::vector<std::shared_ptr<Node>>& args)
{ {
return make_shared<TupleOp>(args); return make_shared<TupleOp>(args);
} }
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
bool TensorViewType::operator==(const ValueType::ptr& that) const bool TensorViewType::operator==(const std::shared_ptr<ValueType>& that) const
{ {
auto that_tvt = dynamic_pointer_cast<TensorViewType>(that); auto that_tvt = dynamic_pointer_cast<TensorViewType>(that);
if (nullptr == that_tvt) if (nullptr == that_tvt)
...@@ -37,7 +37,7 @@ bool TensorViewType::operator==(const ValueType::ptr& that) const ...@@ -37,7 +37,7 @@ bool TensorViewType::operator==(const ValueType::ptr& that) const
return true; return true;
} }
bool TupleType::operator==(const ValueType::ptr& that) const bool TupleType::operator==(const std::shared_ptr<ValueType>& that) const
{ {
auto that_tvt = dynamic_pointer_cast<TupleType>(that); auto that_tvt = dynamic_pointer_cast<TupleType>(that);
if (nullptr == that_tvt) if (nullptr == that_tvt)
......
...@@ -36,7 +36,7 @@ TEST(top_sort, basic) ...@@ -36,7 +36,7 @@ TEST(top_sort, basic)
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
auto t1 = op::add(arg0, arg1); auto t1 = op::add(arg0, arg1);
ASSERT_NE(nullptr, t1); ASSERT_NE(nullptr, t1);
Node::ptr r0 = op::add(t0, t1); auto r0 = op::add(t0, t1);
ASSERT_NE(nullptr, r0); ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, {arg0, arg1}); auto f0 = op::function(r0, {arg0, arg1});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment