Commit 36e36e7f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #78 from NervanaSystems/cyphers/morenames

De-use and cleanup op names.
parents 973b3a0e c7ef13f5
...@@ -22,14 +22,13 @@ ...@@ -22,14 +22,13 @@
namespace ngraph namespace ngraph
{ {
class Node; class Node;
namespace op {
class Parameter; class Parameter;
class ValueType;
template <typename T, typename... A> /// A list of parameters
std::shared_ptr<T> node(A&&... args) using Parameters = std::vector<std::shared_ptr<Parameter>>;
{
return std::make_shared<T>(args...);
} }
class ValueType;
/// Zero or more value types /// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>; using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
...@@ -42,7 +41,4 @@ namespace ngraph ...@@ -42,7 +41,4 @@ namespace ngraph
/// A set of axes, for example, reduction axes /// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>; using AxisSet = std::set<size_t>;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
} }
...@@ -26,10 +26,10 @@ namespace ngraph ...@@ -26,10 +26,10 @@ namespace ngraph
{ {
public: public:
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<op::Parameter>>& parameters);
std::shared_ptr<Node> get_result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<Parameter>> get_parameters() const const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const
{ {
return m_parameters; return m_parameters;
} }
...@@ -37,17 +37,7 @@ namespace ngraph ...@@ -37,17 +37,7 @@ namespace ngraph
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
namespace op
{
std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
}
} }
...@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const ...@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const
bool ngraph::Node::is_parameter() const bool ngraph::Node::is_parameter() const
{ {
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr; return dynamic_cast<const ngraph::op::Parameter*>(this) != nullptr;
} }
namespace ngraph namespace ngraph
......
...@@ -22,57 +22,6 @@ ...@@ -22,57 +22,6 @@
namespace ngraph namespace ngraph
{ {
namespace op
{
std::shared_ptr<Node> abs(const std::shared_ptr<Node>& arg);
std::shared_ptr<Node> add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> convert();
//std::shared_ptr<Node> convolution();
std::shared_ptr<Node> divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> exp(const std::shared_ptr<Node>& arg0);
std::shared_ptr<Node> floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> get_tuple_element();
std::shared_ptr<Node> greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> greater_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> less(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> less_equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> log(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> logical(); and, or, not
std::shared_ptr<Node> maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> negative(const std::shared_ptr<Node>& arg0);
//std::shared_ptr<Node> pad();
std::shared_ptr<Node> power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> reduce();
// std::shared_ptr<Node> reduce_window();
std::shared_ptr<Node> remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
std::shared_ptr<Node> reshape(const std::shared_ptr<Node>& arg0, const Shape& shape);
//std::shared_ptr<Node> reverse();
//std::shared_ptr<Node> rng();
//std::shared_ptr<Node> select();
//std::shared_ptr<Node> select_scatter();
//std::shared_ptr<Node> slice();
std::shared_ptr<Node> subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
//std::shared_ptr<Node> transpose();
//std::shared_ptr<Node> while();
}
/// Op nodes are nodes whose value is the result of some operation /// Op nodes are nodes whose value is the result of some operation
/// applied to its arguments. For calls to user functions, the op will /// applied to its arguments. For calls to user functions, the op will
/// reference the user function. /// reference the user function.
...@@ -93,11 +42,14 @@ namespace ngraph ...@@ -93,11 +42,14 @@ namespace ngraph
virtual std::string get_node_id() const override; virtual std::string get_node_id() const override;
}; };
/// A FunctionOp invokes a function on node arguments. In addition to the argument // TODO: These class definitions are to be moved into separate files in the op directory
namespace op
{
/// A Function invokes a function on node arguments. In addition to the argument
/// we need to preserve the function. /// we need to preserve the function.
class FunctionOp : public Op class FunctionCall : public Op
{ {
virtual std::string description() const override { return "FunctionOp"; } virtual std::string description() const override { return "FunctionCall"; }
protected: protected:
std::shared_ptr<Node> m_function; std::shared_ptr<Node> m_function;
...@@ -105,237 +57,238 @@ namespace ngraph ...@@ -105,237 +57,238 @@ namespace ngraph
/// The is an operation we handle directly, i.e. all type checking, etc. /// The is an operation we handle directly, i.e. all type checking, etc.
/// are defined in C++ rather than in terms of ngraph operations. /// are defined in C++ rather than in terms of ngraph operations.
class BuiltinOp : public Op class Builtin : public Op
{ {
public: public:
virtual std::string description() const override { return "BuiltinOp"; } virtual std::string description() const override { return "Builtin"; }
/// Name of the builtin op, for debugging and logging. /// Name of the builtin op, for debugging and logging.
// TODO: Implement for each op. This enables graphs to be built for now. // TODO: Implement for each op. This enables graphs to be built for now.
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinOp(const std::vector<std::shared_ptr<Node>>& args) Builtin(const std::vector<std::shared_ptr<Node>>& args)
: Op(args) : Op(args)
{ {
} }
}; };
class AbsOp : public BuiltinOp class Abs : public Builtin
{ {
public: public:
AbsOp(const std::shared_ptr<Node>& arg0) Abs(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : Builtin({arg0})
{ {
} }
virtual std::string get_op_class_name() const override { return "abs"; } virtual std::string get_op_class_name() const override { return "Abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class AddOp : public BuiltinOp class Add : public Builtin
{ {
public: public:
AddOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Add(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "add"; } virtual std::string get_op_class_name() const override { return "Add"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class CeilingOp : public BuiltinOp class Ceiling : public Builtin
{ {
public: public:
CeilingOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Ceiling(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "ceiling"; } virtual std::string get_op_class_name() const override { return "Ceiling"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class DivideOp : public BuiltinOp class Divide : public Builtin
{ {
public: public:
DivideOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Divide(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "divide"; } virtual std::string get_op_class_name() const override { return "Divide"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class EqualOp : public BuiltinOp class Equal : public Builtin
{ {
public: public:
EqualOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Equal(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "equal"; } virtual std::string get_op_class_name() const override { return "Equal"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ExpOp : public BuiltinOp class Exp : public Builtin
{ {
public: public:
ExpOp(const std::shared_ptr<Node>& arg0) Exp(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : Builtin({arg0})
{ {
} }
virtual std::string get_op_class_name() const override { return "exp"; } virtual std::string get_op_class_name() const override { return "Exp"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class FloorOp : public BuiltinOp class Floor : public Builtin
{ {
public: public:
FloorOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Floor(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "floor"; } virtual std::string get_op_class_name() const override { return "Floor"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class GreaterOp : public BuiltinOp class Greater : public Builtin
{ {
public: public:
GreaterOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Greater(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "greater"; } virtual std::string get_op_class_name() const override { return "Greater"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class LessOp : public BuiltinOp class Less : public Builtin
{ {
public: public:
LessOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Less(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "less"; } virtual std::string get_op_class_name() const override { return "Less"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class LogOp : public BuiltinOp class Log : public Builtin
{ {
public: public:
LogOp(const std::shared_ptr<Node>& arg0) Log(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : Builtin({arg0})
{ {
} }
virtual std::string get_op_class_name() const override { return "log"; } virtual std::string get_op_class_name() const override { return "Log"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MaximumOp : public BuiltinOp class Maximum : public Builtin
{ {
public: public:
MaximumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Maximum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "max"; } virtual std::string get_op_class_name() const override { return "Max"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MinimumOp : public BuiltinOp class Minimum : public Builtin
{ {
public: public:
MinimumOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Minimum(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "min"; } virtual std::string get_op_class_name() const override { return "Min"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MultiplyOp : public BuiltinOp class Multiply : public Builtin
{ {
public: public:
MultiplyOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Multiply(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "multiply"; } virtual std::string get_op_class_name() const override { return "Multiply"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class NegativeOp : public BuiltinOp class Negative : public Builtin
{ {
public: public:
NegativeOp(const std::shared_ptr<Node>& arg0) Negative(const std::shared_ptr<Node>& arg0)
: BuiltinOp({arg0}) : Builtin({arg0})
{ {
} }
virtual std::string get_op_class_name() const override { return "negative"; } virtual std::string get_op_class_name() const override { return "Negative"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class PowerOp : public BuiltinOp class Power : public Builtin
{ {
public: public:
PowerOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Power(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "power"; } virtual std::string get_op_class_name() const override { return "Power"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class RemainderOp : public BuiltinOp class Remainder : public Builtin
{ {
public: public:
RemainderOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Remainder(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "remainder"; } virtual std::string get_op_class_name() const override { return "Remainder"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ReshapeOp : public BuiltinOp class Reshape : public Builtin
{ {
public: public:
ReshapeOp(const std::shared_ptr<Node>& arg0, const Shape& shape) Reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
: BuiltinOp({arg0}) : Builtin({arg0})
, m_shape(shape) , m_shape(shape)
{ {
} }
virtual std::string get_op_class_name() const override { return "reshape"; } virtual std::string get_op_class_name() const override { return "Reshape"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
}; };
class SubtractOp : public BuiltinOp class Subtract : public Builtin
{ {
public: public:
SubtractOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Subtract(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "subtract"; } virtual std::string get_op_class_name() const override { return "Subtract"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
}
} }
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
namespace ngraph namespace ngraph
{ {
class BroadcastOp : public BuiltinOp namespace op
{
class Broadcast : public Builtin
{ {
public: public:
/// ///
...@@ -25,27 +27,21 @@ namespace ngraph ...@@ -25,27 +27,21 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
/// ///
BroadcastOp(const std::shared_ptr<Node>& arg, Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : Builtin({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
} }
virtual std::string get_op_class_name() const override { return "broadcast"; } virtual std::string get_op_class_name() const override { return "Broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
AxisSet m_broadcast_axes; AxisSet m_broadcast_axes;
}; };
namespace op
{
std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes);
} }
} }
...@@ -18,18 +18,16 @@ namespace ngraph ...@@ -18,18 +18,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
std::shared_ptr<Node> concatenate(const Nodes& args); class Concat : public Builtin
}
class ConcatOp : public BuiltinOp
{ {
public: public:
ConcatOp(const Nodes& args) Concat(const Nodes& args)
: BuiltinOp(args) : Builtin(args)
{ {
} }
virtual std::string get_op_class_name() const override { return "concatenate"; } virtual std::string get_op_class_name() const override { return "Concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
}
} }
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
namespace ngraph namespace ngraph
{ {
namespace op
{
// Defines methods to all constant scalars // Defines methods to all constant scalars
class ScalarConstantBase : public Node class ScalarConstantBase : public Node
{ {
...@@ -70,4 +72,5 @@ namespace ngraph ...@@ -70,4 +72,5 @@ namespace ngraph
using UInt8ScalarConstant = ScalarConstant<element::UInt8>; using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>; using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>; using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
}
} }
...@@ -16,25 +16,22 @@ ...@@ -16,25 +16,22 @@
namespace ngraph namespace ngraph
{ {
class ConvertOp : public BuiltinOp namespace op
{
class Convert : public Builtin
{ {
public: public:
ConvertOp(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type) Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg}) : Builtin({arg})
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
virtual std::string get_op_class_name() const override { return "convert"; } virtual std::string get_op_class_name() const override { return "Convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
} }
} }
...@@ -16,22 +16,19 @@ ...@@ -16,22 +16,19 @@
namespace ngraph namespace ngraph
{ {
class DotOp : public BuiltinOp namespace op
{
class Dot : public Builtin
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "dot"; } virtual std::string get_op_class_name() const override { return "Dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
namespace op
{
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
} }
} }
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ngraph namespace ngraph
{ {
class Function; class Function;
namespace op
{
/// ///
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions. /// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
/// Function creation requires a sequence of parameters. /// Function creation requires a sequence of parameters.
...@@ -28,7 +29,7 @@ namespace ngraph ...@@ -28,7 +29,7 @@ namespace ngraph
/// ///
class Parameter : public Node class Parameter : public Node
{ {
friend class Function; friend class ngraph::Function;
protected: protected:
// Called by the Function constructor to associate this parameter with the function. // Called by the Function constructor to associate this parameter with the function.
...@@ -47,14 +48,5 @@ namespace ngraph ...@@ -47,14 +48,5 @@ namespace ngraph
Function* m_function; Function* m_function;
size_t m_index; size_t m_index;
}; };
namespace op
{
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape);
} }
} }
...@@ -18,18 +18,16 @@ namespace ngraph ...@@ -18,18 +18,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
std::shared_ptr<Node> tuple(const Nodes& args); class Tuple : public Builtin
}
class TupleOp : public BuiltinOp
{ {
public: public:
TupleOp(const Nodes& args) Tuple(const Nodes& args)
: BuiltinOp(args) : Builtin(args)
{ {
} }
virtual std::string get_op_class_name() const override { return "tuple"; } virtual std::string get_op_class_name() const override { return "Tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
}
} }
...@@ -15,20 +15,9 @@ ...@@ -15,20 +15,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
/// @param tensor The tensor view to be broadcast. void Broadcast::propagate_types()
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->get_value_type(); auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type) if (nullptr == arg_type)
......
...@@ -17,14 +17,9 @@ ...@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
void ConcatOp::propagate_types() void Concat::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
std::shared_ptr<Node> op::concatenate(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<ConcatOp>(args);
}
...@@ -14,6 +14,6 @@ ...@@ -14,6 +14,6 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace ngraph; using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {} void ScalarConstantBase::propagate_types() {}
...@@ -17,15 +17,9 @@ ...@@ -17,15 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
void ConvertOp::propagate_types() void Convert::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
...@@ -17,16 +17,9 @@ ...@@ -17,16 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. void Dot::propagate_types()
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{ {
auto arg0_tensor_type = auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type()); dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
......
...@@ -18,7 +18,7 @@ using namespace std; ...@@ -18,7 +18,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const std::shared_ptr<Node>& result, Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters) const std::vector<std::shared_ptr<op::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
, m_name("Function") , m_name("Function")
...@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result ...@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result
parameter->assign_function(this, i++); parameter->assign_function(this, i++);
} }
} }
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
...@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const ...@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const
ss << get_op_class_name() << "_" << m_instance_id; ss << get_op_class_name() << "_" << m_instance_id;
return ss.str(); return ss.str();
} }
std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
{
return make_shared<AbsOp>(arg);
}
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<AddOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
// 'convert',
// 'convolution',
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DivideOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
{
return make_shared<ExpOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<FloorOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
{
return make_shared<LogOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MaximumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MinimumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MultiplyOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
{
return make_shared<NegativeOp>(arg0);
}
// 'pad',
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<RemainderOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
{
return make_shared<ReshapeOp>(arg0, shape);
}
//'reverse',
//'rng',
// 'select',
//'slice',
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<SubtractOp>(arg0, arg1);
}
// 'transpose',
// 'while'
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
Parameter::Parameter(const std::shared_ptr<ValueType>& value_type) Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node(value_type) : Node(value_type)
...@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index)
void Parameter::propagate_types() {} void Parameter::propagate_types() {}
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type) std::string ngraph::op::Parameter::get_node_id() const
{
return make_shared<Parameter>(value_type);
}
shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type,
const Shape& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
std::string ngraph::Parameter::get_node_id() const
{ {
stringstream ss; stringstream ss;
ss << "parameter_" << m_instance_id; ss << "parameter_" << m_instance_id;
......
...@@ -17,14 +17,9 @@ ...@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
void TupleOp::propagate_types() void Tuple::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
std::shared_ptr<Node> op::tuple(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<TupleOp>(args);
}
...@@ -23,17 +23,17 @@ using namespace ngraph; ...@@ -23,17 +23,17 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3}); auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7}); auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7}); auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0); auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3}); auto cluster_0 = make_shared<Function>(dot, op::Parameters{arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->get_result(), dot); ASSERT_EQ(cluster_0->get_result(), dot);
} }
...@@ -59,15 +59,15 @@ TEST(build_graph, as_type) ...@@ -59,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3}); auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32}); auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1); auto dot = make_shared<op::Dot>(arg0, arg1);
auto add = op::add(dot, arg2); auto add = make_shared<op::Add>(dot, arg2);
auto parg = node<Parameter>(element::Float32::element_type(), Shape{}); auto parg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg); auto pattern_dot = make_shared<op::Dot>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong. // Need to figure out what's wrong.
...@@ -78,20 +78,20 @@ TEST(build_graph, literal) ...@@ -78,20 +78,20 @@ TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<Float32ScalarConstant>(3.0); auto float0 = make_shared<op::Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0); ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type); ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0); auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0); ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0); ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = node<Float32ScalarConstant>(3); auto float1 = make_shared<op::Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type); ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0); auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3); ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
......
...@@ -23,7 +23,7 @@ using namespace ngraph; ...@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op) TEST(op, is_op)
{ {
auto arg0 = op::parameter(element::Float32::element_type(), {1}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter()); EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op()); EXPECT_FALSE(arg0->is_op());
...@@ -31,9 +31,9 @@ TEST(op, is_op) ...@@ -31,9 +31,9 @@ TEST(op, is_op)
TEST(op, is_parameter) TEST(op, is_parameter)
{ {
auto arg0 = op::parameter(element::Float32::element_type(), {1}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0); auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter()); EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op()); EXPECT_TRUE(t0->is_op());
......
...@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic) TEST(topological_sort, basic)
{ {
vector<shared_ptr<Parameter>> args; vector<shared_ptr<op::Parameter>> args;
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
{ {
auto arg = op::parameter(element::Float32::element_type(), {1}); auto arg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg); ASSERT_NE(nullptr, arg);
args.push_back(arg); args.push_back(arg);
} }
auto t0 = op::add(args[0], args[1]); auto t0 = make_shared<op::Add>(args[0], args[1]);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
auto t1 = op::dot(t0, args[2]); auto t1 = make_shared<op::Dot>(t0, args[2]);
ASSERT_NE(nullptr, t1); ASSERT_NE(nullptr, t1);
auto t2 = op::multiply(t0, args[3]); auto t2 = make_shared<op::Multiply>(t0, args[3]);
ASSERT_NE(nullptr, t2); ASSERT_NE(nullptr, t2);
auto t3 = op::add(t1, args[4]); auto t3 = make_shared<op::Add>(t1, args[4]);
ASSERT_NE(nullptr, t2); ASSERT_NE(nullptr, t2);
auto t4 = op::add(t2, args[5]); auto t4 = make_shared<op::Add>(t2, args[5]);
ASSERT_NE(nullptr, t3); ASSERT_NE(nullptr, t3);
auto r0 = op::add(t3, t4); auto r0 = make_shared<op::Add>(t3, t4);
ASSERT_NE(nullptr, r0); ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, args); auto f0 = make_shared<Function>(r0, args);
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment