Commit 1b026daa authored by Scott Cyphers's avatar Scott Cyphers

Remove Op -- use typeid.

parent 5f8bf07e
......@@ -61,7 +61,7 @@ namespace ngraph
/**
** A user-defined function.
**/
class Function : public Op
class Function
{
public:
Function(size_t n_parameters);
......@@ -70,7 +70,7 @@ namespace ngraph
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const override { return m_name; }
std::string name() const { return m_name; }
protected:
std::vector<Parameter::ptr> m_parameters;
......
......@@ -64,18 +64,6 @@ namespace ngraph
//Node::ptr while();
}
/**
** Every instance of Op corresponds to a unique defined operation.
**/
class Op
{
protected:
virtual ~Op() {}
public:
virtual std::string name() const = 0;
};
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
......@@ -84,53 +72,34 @@ namespace ngraph
class Call : public Node
{
public:
const Op& op() const { return m_op; }
Call(const Op& op, const std::vector<Node::ptr>& arguments)
Call(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
, m_op(op)
{
}
virtual std::string description() const override { return m_op.name(); }
protected:
const Op& m_op;
};
/**
** There is exactly one instance of builtin op for each pre-defined operation. These
** are intended to be used when matching calls in different graphs; every FooCall
** will have the same op.
**/
class BuiltinOp : public Op
{
friend class Call;
public:
BuiltinOp(const std::string& name)
: m_name(name)
{
}
public:
std::string name() const override { return m_name; }
protected:
std::string m_name;
/**
** Return true if this has the same implementing class as call. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
**/
bool has_same_op(Call& call) { return typeid(this) == typeid(&call); }
virtual std::string description() const override { return "Call"; }
};
class BuiltinCall : public Call
{
public:
virtual std::string description() const override { return "BuiltinCall"; }
/// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
// TODO: Implement for each op
virtual void propagate_types() override {}
protected:
BuiltinCall(const Op& op, const std::vector<Node::ptr>& args)
: Call(op, args)
BuiltinCall(const std::vector<Node::ptr>& args)
: Call(args)
{
}
};
......@@ -139,24 +108,23 @@ namespace ngraph
{
public:
AbsCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
: BuiltinCall({arg0})
{
}
protected:
static BuiltinOp s_op;
virtual std::string op_name() const override { return "abs"; }
//virtual void propagate_types() override;
};
class AddCall : public BuiltinCall
{
public:
AddCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
virtual std::string op_name() const override { return "add"; }
//virtual void propagate_types() override;
};
class BroadcastCall : public BuiltinCall
......@@ -169,43 +137,42 @@ namespace ngraph
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinCall(s_op, {arg})
: BuiltinCall({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
static BuiltinOp s_op;
};
class CeilingCall : public BuiltinCall
{
public:
CeilingCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "ceiling"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class DivideCall : public BuiltinCall
{
public:
DivideCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "divide"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class DotCall : public BuiltinCall
......@@ -213,183 +180,182 @@ namespace ngraph
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class EqualCall : public BuiltinCall
{
public:
EqualCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "equal"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class ExponentialCall : public BuiltinCall
{
public:
ExponentialCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
: BuiltinCall({arg0})
{
}
virtual std::string op_name() const override { return "exp"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class FloorCall : public BuiltinCall
{
public:
FloorCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "floor"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class GreaterCall : public BuiltinCall
{
public:
GreaterCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "greater"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class LessCall : public BuiltinCall
{
public:
LessCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "less"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class LogCall : public BuiltinCall
{
public:
LogCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
: BuiltinCall({arg0})
{
}
virtual std::string op_name() const override { return "log"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class MaximumCall : public BuiltinCall
{
public:
MaximumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "max"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class MinimumCall : public BuiltinCall
{
public:
MinimumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "min"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class MultiplyCall : public BuiltinCall
{
public:
MultiplyCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "multiply"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class NegateCall : public BuiltinCall
{
public:
NegateCall(const Node::ptr& arg0)
: BuiltinCall(s_op, {arg0})
: BuiltinCall({arg0})
{
}
virtual std::string op_name() const override { return "negate"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class PowerCall : public BuiltinCall
{
public:
PowerCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "power"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class RemainderCall : public BuiltinCall
{
public:
RemainderCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "remainder"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
class ReshapeCall : public BuiltinCall
{
public:
ReshapeCall(const Node::ptr& arg0, const Shape& shape)
: BuiltinCall(s_op, {arg0})
: BuiltinCall({arg0})
, m_shape(shape)
{
}
virtual std::string op_name() const override { return "reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
static BuiltinOp s_op;
};
class SubtractCall : public BuiltinCall
{
public:
SubtractCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
: BuiltinCall({arg0, arg1})
{
}
virtual std::string op_name() const override { return "subtract"; }
//virtual void propagate_types() override;
protected:
static BuiltinOp s_op;
};
}
......@@ -19,22 +19,16 @@
using namespace ngraph;
using namespace std;
BuiltinOp AbsCall::s_op = BuiltinOp("abs");
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsCall>(arg);
}
BuiltinOp AddCall::s_op = BuiltinOp("add");
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<AddCall>(arg0, arg1);
}
BuiltinOp BroadcastCall::s_op = BuiltinOp("broadcast");
/**
** /param arg The tensor view to be broadcast.
** /param shape The shape of the result
......@@ -74,8 +68,6 @@ void BroadcastCall::propagate_types()
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
BuiltinOp CeilingCall::s_op = BuiltinOp("ceiling");
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<CeilingCall>(arg0, arg1);
......@@ -86,15 +78,11 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
// 'convert',
// 'convolution',
BuiltinOp DivideCall::s_op = BuiltinOp("divide");
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DivideCall>(arg0, arg1);
}
BuiltinOp DotCall::s_op = BuiltinOp("dot");
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
......@@ -139,50 +127,36 @@ void DotCall::propagate_types()
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
BuiltinOp ExponentialCall::s_op = BuiltinOp("exponential");
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialCall>(arg0);
}
BuiltinOp FloorCall::s_op = BuiltinOp("floor");
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<FloorCall>(arg0, arg1);
}
BuiltinOp LogCall::s_op = BuiltinOp("log");
Node::ptr ngraph::op::log(const Node::ptr& arg0)
{
return make_shared<LogCall>(arg0);
}
BuiltinOp MaximumCall::s_op = BuiltinOp("maximum");
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MaximumCall>(arg0, arg1);
}
BuiltinOp MinimumCall::s_op = BuiltinOp("minimum");
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MinimumCall>(arg0, arg1);
}
BuiltinOp MultiplyCall::s_op = BuiltinOp("multiply");
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MultiplyCall>(arg0, arg1);
}
BuiltinOp NegateCall::s_op = BuiltinOp("negate");
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{
return make_shared<NegateCall>(arg0);
......@@ -191,8 +165,6 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
// 'pad',
// 'parameter',
BuiltinOp PowerCall::s_op = BuiltinOp("power");
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<PowerCall>(arg0, arg1);
......@@ -200,15 +172,11 @@ Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
//'reduce',
BuiltinOp RemainderCall::s_op = BuiltinOp("remainder");
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<RemainderCall>(arg0, arg1);
}
BuiltinOp ReshapeCall::s_op = BuiltinOp("reshape");
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{
return make_shared<ReshapeCall>(arg0, shape);
......@@ -219,8 +187,6 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
// 'select',
//'slice',
BuiltinOp SubtractCall::s_op = BuiltinOp("subtract");
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<SubtractCall>(arg0, arg1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment