Commit 1b026daa authored by Scott Cyphers's avatar Scott Cyphers

Remove Op -- use typeid.

parent 5f8bf07e
...@@ -61,7 +61,7 @@ namespace ngraph ...@@ -61,7 +61,7 @@ namespace ngraph
/** /**
** A user-defined function. ** A user-defined function.
**/ **/
class Function : public Op class Function
{ {
public: public:
Function(size_t n_parameters); Function(size_t n_parameters);
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
Parameter::ptr parameter(size_t i) { return m_parameters[i]; } Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const override { return m_name; } std::string name() const { return m_name; }
protected: protected:
std::vector<Parameter::ptr> m_parameters; std::vector<Parameter::ptr> m_parameters;
......
...@@ -64,18 +64,6 @@ namespace ngraph ...@@ -64,18 +64,6 @@ namespace ngraph
//Node::ptr while(); //Node::ptr while();
} }
/**
** Every instance of Op corresponds to a unique defined operation.
**/
class Op
{
protected:
virtual ~Op() {}
public:
virtual std::string name() const = 0;
};
/** /**
** Call nodes are nodes whose value is the result of some operation, the op, ** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the ** applied to its arguments. We use the op as a callable to construct the
...@@ -84,53 +72,34 @@ namespace ngraph ...@@ -84,53 +72,34 @@ namespace ngraph
class Call : public Node class Call : public Node
{ {
public: public:
const Op& op() const { return m_op; }
Call(const Op& op, const std::vector<Node::ptr>& arguments) Call(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
, m_op(op)
{
}
virtual std::string description() const override { return m_op.name(); }
protected:
const Op& m_op;
};
/**
** There is exactly one instance of builtin op for each pre-defined operation. These
** are intended to be used when matching calls in different graphs; every FooCall
** will have the same op.
**/
class BuiltinOp : public Op
{
friend class Call;
public:
BuiltinOp(const std::string& name)
: m_name(name)
{ {
} }
public: /**
std::string name() const override { return m_name; } ** Return true if this has the same implementing class as call. This
** will be used by the pattern matcher when comparing a pattern
protected: ** graph against the graph.
std::string m_name; **/
bool has_same_op(Call& call) { return typeid(this) == typeid(&call); }
virtual std::string description() const override { return "Call"; }
}; };
class BuiltinCall : public Call class BuiltinCall : public Call
{ {
public: public:
virtual std::string description() const override { return "BuiltinCall"; } virtual std::string description() const override { return "BuiltinCall"; }
/// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
// TODO: Implement for each op // TODO: Implement for each op
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinCall(const Op& op, const std::vector<Node::ptr>& args) BuiltinCall(const std::vector<Node::ptr>& args)
: Call(op, args) : Call(args)
{ {
} }
}; };
...@@ -139,24 +108,23 @@ namespace ngraph ...@@ -139,24 +108,23 @@ namespace ngraph
{ {
public: public:
AbsCall(const Node::ptr& arg0) AbsCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0}) : BuiltinCall({arg0})
{ {
} }
protected: virtual std::string op_name() const override { return "abs"; }
static BuiltinOp s_op; //virtual void propagate_types() override;
}; };
class AddCall : public BuiltinCall class AddCall : public BuiltinCall
{ {
public: public:
AddCall(const Node::ptr& arg0, const Node::ptr& arg1) AddCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
//virtual void propagate_types() override; virtual std::string op_name() const override { return "add"; }
protected: //virtual void propagate_types() override;
static BuiltinOp s_op;
}; };
class BroadcastCall : public BuiltinCall class BroadcastCall : public BuiltinCall
...@@ -169,43 +137,42 @@ namespace ngraph ...@@ -169,43 +137,42 @@ namespace ngraph
** the remaining axes in shape must be the same as the shape of arg. ** the remaining axes in shape must be the same as the shape of arg.
**/ **/
BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes) BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinCall(s_op, {arg}) : BuiltinCall({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
} }
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
std::vector<size_t> m_broadcast_axes; std::vector<size_t> m_broadcast_axes;
static BuiltinOp s_op;
}; };
class CeilingCall : public BuiltinCall class CeilingCall : public BuiltinCall
{ {
public: public:
CeilingCall(const Node::ptr& arg0, const Node::ptr& arg1) CeilingCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "ceiling"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class DivideCall : public BuiltinCall class DivideCall : public BuiltinCall
{ {
public: public:
DivideCall(const Node::ptr& arg0, const Node::ptr& arg1) DivideCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "divide"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class DotCall : public BuiltinCall class DotCall : public BuiltinCall
...@@ -213,183 +180,182 @@ namespace ngraph ...@@ -213,183 +180,182 @@ namespace ngraph
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall(const Node::ptr& arg0, const Node::ptr& arg1) DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class EqualCall : public BuiltinCall class EqualCall : public BuiltinCall
{ {
public: public:
EqualCall(const Node::ptr& arg0, const Node::ptr& arg1) EqualCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "equal"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class ExponentialCall : public BuiltinCall class ExponentialCall : public BuiltinCall
{ {
public: public:
ExponentialCall(const Node::ptr& arg0) ExponentialCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0}) : BuiltinCall({arg0})
{ {
} }
virtual std::string op_name() const override { return "exp"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class FloorCall : public BuiltinCall class FloorCall : public BuiltinCall
{ {
public: public:
FloorCall(const Node::ptr& arg0, const Node::ptr& arg1) FloorCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "floor"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class GreaterCall : public BuiltinCall class GreaterCall : public BuiltinCall
{ {
public: public:
GreaterCall(const Node::ptr& arg0, const Node::ptr& arg1) GreaterCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "greater"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class LessCall : public BuiltinCall class LessCall : public BuiltinCall
{ {
public: public:
LessCall(const Node::ptr& arg0, const Node::ptr& arg1) LessCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "less"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class LogCall : public BuiltinCall class LogCall : public BuiltinCall
{ {
public: public:
LogCall(const Node::ptr& arg0) LogCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0}) : BuiltinCall({arg0})
{ {
} }
virtual std::string op_name() const override { return "log"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class MaximumCall : public BuiltinCall class MaximumCall : public BuiltinCall
{ {
public: public:
MaximumCall(const Node::ptr& arg0, const Node::ptr& arg1) MaximumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "max"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class MinimumCall : public BuiltinCall class MinimumCall : public BuiltinCall
{ {
public: public:
MinimumCall(const Node::ptr& arg0, const Node::ptr& arg1) MinimumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "min"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class MultiplyCall : public BuiltinCall class MultiplyCall : public BuiltinCall
{ {
public: public:
MultiplyCall(const Node::ptr& arg0, const Node::ptr& arg1) MultiplyCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "multiply"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class NegateCall : public BuiltinCall class NegateCall : public BuiltinCall
{ {
public: public:
NegateCall(const Node::ptr& arg0) NegateCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0}) : BuiltinCall({arg0})
{ {
} }
virtual std::string op_name() const override { return "negate"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class PowerCall : public BuiltinCall class PowerCall : public BuiltinCall
{ {
public: public:
PowerCall(const Node::ptr& arg0, const Node::ptr& arg1) PowerCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "power"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class RemainderCall : public BuiltinCall class RemainderCall : public BuiltinCall
{ {
public: public:
RemainderCall(const Node::ptr& arg0, const Node::ptr& arg1) RemainderCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "remainder"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
class ReshapeCall : public BuiltinCall class ReshapeCall : public BuiltinCall
{ {
public: public:
ReshapeCall(const Node::ptr& arg0, const Shape& shape) ReshapeCall(const Node::ptr& arg0, const Shape& shape)
: BuiltinCall(s_op, {arg0}) : BuiltinCall({arg0})
, m_shape(shape) , m_shape(shape)
{ {
} }
virtual std::string op_name() const override { return "reshape"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
static BuiltinOp s_op;
}; };
class SubtractCall : public BuiltinCall class SubtractCall : public BuiltinCall
{ {
public: public:
SubtractCall(const Node::ptr& arg0, const Node::ptr& arg1) SubtractCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1}) : BuiltinCall({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "subtract"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
}; };
} }
...@@ -19,22 +19,16 @@ ...@@ -19,22 +19,16 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
BuiltinOp AbsCall::s_op = BuiltinOp("abs");
Node::ptr ngraph::op::abs(const Node::ptr& arg) Node::ptr ngraph::op::abs(const Node::ptr& arg)
{ {
return make_shared<AbsCall>(arg); return make_shared<AbsCall>(arg);
} }
BuiltinOp AddCall::s_op = BuiltinOp("add");
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<AddCall>(arg0, arg1); return make_shared<AddCall>(arg0, arg1);
} }
BuiltinOp BroadcastCall::s_op = BuiltinOp("broadcast");
/** /**
** /param arg The tensor view to be broadcast. ** /param arg The tensor view to be broadcast.
** /param shape The shape of the result ** /param shape The shape of the result
...@@ -74,8 +68,6 @@ void BroadcastCall::propagate_types() ...@@ -74,8 +68,6 @@ void BroadcastCall::propagate_types()
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape); m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
} }
BuiltinOp CeilingCall::s_op = BuiltinOp("ceiling");
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<CeilingCall>(arg0, arg1); return make_shared<CeilingCall>(arg0, arg1);
...@@ -86,15 +78,11 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -86,15 +78,11 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert', // 'convert',
// 'convolution', // 'convolution',
BuiltinOp DivideCall::s_op = BuiltinOp("divide");
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DivideCall>(arg0, arg1); return make_shared<DivideCall>(arg0, arg1);
} }
BuiltinOp DotCall::s_op = BuiltinOp("dot");
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
...@@ -139,50 +127,36 @@ void DotCall::propagate_types() ...@@ -139,50 +127,36 @@ void DotCall::propagate_types()
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape); m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
} }
BuiltinOp ExponentialCall::s_op = BuiltinOp("exponential");
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{ {
return make_shared<ExponentialCall>(arg0); return make_shared<ExponentialCall>(arg0);
} }
BuiltinOp FloorCall::s_op = BuiltinOp("floor");
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<FloorCall>(arg0, arg1); return make_shared<FloorCall>(arg0, arg1);
} }
BuiltinOp LogCall::s_op = BuiltinOp("log");
Node::ptr ngraph::op::log(const Node::ptr& arg0) Node::ptr ngraph::op::log(const Node::ptr& arg0)
{ {
return make_shared<LogCall>(arg0); return make_shared<LogCall>(arg0);
} }
BuiltinOp MaximumCall::s_op = BuiltinOp("maximum");
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MaximumCall>(arg0, arg1); return make_shared<MaximumCall>(arg0, arg1);
} }
BuiltinOp MinimumCall::s_op = BuiltinOp("minimum");
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MinimumCall>(arg0, arg1); return make_shared<MinimumCall>(arg0, arg1);
} }
BuiltinOp MultiplyCall::s_op = BuiltinOp("multiply");
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MultiplyCall>(arg0, arg1); return make_shared<MultiplyCall>(arg0, arg1);
} }
BuiltinOp NegateCall::s_op = BuiltinOp("negate");
Node::ptr ngraph::op::negate(const Node::ptr& arg0) Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{ {
return make_shared<NegateCall>(arg0); return make_shared<NegateCall>(arg0);
...@@ -191,8 +165,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0) ...@@ -191,8 +165,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
// 'pad', // 'pad',
// 'parameter', // 'parameter',
BuiltinOp PowerCall::s_op = BuiltinOp("power");
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<PowerCall>(arg0, arg1); return make_shared<PowerCall>(arg0, arg1);
...@@ -200,15 +172,11 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -200,15 +172,11 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
//'reduce', //'reduce',
BuiltinOp RemainderCall::s_op = BuiltinOp("remainder");
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<RemainderCall>(arg0, arg1); return make_shared<RemainderCall>(arg0, arg1);
} }
BuiltinOp ReshapeCall::s_op = BuiltinOp("reshape");
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{ {
return make_shared<ReshapeCall>(arg0, shape); return make_shared<ReshapeCall>(arg0, shape);
...@@ -219,8 +187,6 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) ...@@ -219,8 +187,6 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select', // 'select',
//'slice', //'slice',
BuiltinOp SubtractCall::s_op = BuiltinOp("subtract");
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<SubtractCall>(arg0, arg1); return make_shared<SubtractCall>(arg0, arg1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment