Commit b0e1e076 authored by Scott Cyphers's avatar Scott Cyphers

Call -> Op, failing test for pattern match.

parent db6e3052
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <iostream>
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
...@@ -62,6 +64,15 @@ namespace ngraph ...@@ -62,6 +64,15 @@ namespace ngraph
std::string name() const { return m_name; } std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; } void name(const std::string& name) { m_name = name; }
/**
** Return true if this has the same implementing class as call. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
** TODO: typeids are Node*, doc says they should be the actual classes.
**/
bool has_same_op(const Node::ptr& node) { return typeid(this) == typeid(node.get()); }
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
......
...@@ -65,32 +65,40 @@ namespace ngraph ...@@ -65,32 +65,40 @@ namespace ngraph
} }
/** /**
** Call nodes are nodes whose value is the result of some operation, the op, ** Op nodes are nodes whose value is the result of some operation
** applied to its arguments. We use the op as a callable to construct the ** applied to its arguments. For calls to user functions, the op will
** call nodes. For calls to user functions, the op will be the user function. ** reference the user function.
**/ **/
class Call : public Node class Op : public Node
{ {
public: public:
Call(const std::vector<Node::ptr>& arguments) Op(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr) : Node(arguments, nullptr)
{ {
} }
};
/** /**
** Return true if this has the same implementing class as call. This ** A FunctionOp invokes a function on node arguments. In addition to the argument
** will be used by the pattern matcher when comparing a pattern ** we need to preserve the function.
** graph against the graph.
**/ **/
bool has_same_op(Call& call) { return typeid(this) == typeid(&call); } class FunctionOp : public Op
virtual std::string description() const override { return "Call"; } {
virtual std::string description() const override { return "FunctionOp"; }
protected:
Node::ptr m_function;
}; };
class BuiltinCall : public Call /**
** The is an operation we handle directly, i.e. all type checking, etc.
** are defined in C++ rather than in terms of ngraph operations.
**/
class BuiltinOp : public Op
{ {
public: public:
virtual std::string description() const override { return "BuiltinCall"; } virtual std::string description() const override { return "BuiltinOp"; }
/// Name of the builtin op, for debugging and logging. /// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0; virtual std::string op_name() const = 0;
...@@ -98,17 +106,17 @@ namespace ngraph ...@@ -98,17 +106,17 @@ namespace ngraph
virtual void propagate_types() override {} virtual void propagate_types() override {}
protected: protected:
BuiltinCall(const std::vector<Node::ptr>& args) BuiltinOp(const std::vector<Node::ptr>& args)
: Call(args) : Op(args)
{ {
} }
}; };
class AbsCall : public BuiltinCall class AbsOp : public BuiltinOp
{ {
public: public:
AbsCall(const Node::ptr& arg0) AbsOp(const Node::ptr& arg0)
: BuiltinCall({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -116,18 +124,18 @@ namespace ngraph ...@@ -116,18 +124,18 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class AddCall : public BuiltinCall class AddOp : public BuiltinOp
{ {
public: public:
AddCall(const Node::ptr& arg0, const Node::ptr& arg1) AddOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_name() const override { return "add"; } virtual std::string op_name() const override { return "add"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class BroadcastCall : public BuiltinCall class BroadcastOp : public BuiltinOp
{ {
public: public:
/** /**
...@@ -136,8 +144,8 @@ namespace ngraph ...@@ -136,8 +144,8 @@ namespace ngraph
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast. ** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg. ** the remaining axes in shape must be the same as the shape of arg.
**/ **/
BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes) BroadcastOp(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinCall({arg}) : BuiltinOp({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
...@@ -151,11 +159,11 @@ namespace ngraph ...@@ -151,11 +159,11 @@ namespace ngraph
std::vector<size_t> m_broadcast_axes; std::vector<size_t> m_broadcast_axes;
}; };
class CeilingCall : public BuiltinCall class CeilingOp : public BuiltinOp
{ {
public: public:
CeilingCall(const Node::ptr& arg0, const Node::ptr& arg1) CeilingOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -163,11 +171,11 @@ namespace ngraph ...@@ -163,11 +171,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class DivideCall : public BuiltinCall class DivideOp : public BuiltinOp
{ {
public: public:
DivideCall(const Node::ptr& arg0, const Node::ptr& arg1) DivideOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -175,12 +183,12 @@ namespace ngraph ...@@ -175,12 +183,12 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class DotCall : public BuiltinCall class DotOp : public BuiltinOp
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall(const Node::ptr& arg0, const Node::ptr& arg1) DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -188,11 +196,11 @@ namespace ngraph ...@@ -188,11 +196,11 @@ namespace ngraph
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
class EqualCall : public BuiltinCall class EqualOp : public BuiltinOp
{ {
public: public:
EqualCall(const Node::ptr& arg0, const Node::ptr& arg1) EqualOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -200,11 +208,11 @@ namespace ngraph ...@@ -200,11 +208,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ExponentialCall : public BuiltinCall class ExponentialOp : public BuiltinOp
{ {
public: public:
ExponentialCall(const Node::ptr& arg0) ExponentialOp(const Node::ptr& arg0)
: BuiltinCall({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -212,11 +220,11 @@ namespace ngraph ...@@ -212,11 +220,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class FloorCall : public BuiltinCall class FloorOp : public BuiltinOp
{ {
public: public:
FloorCall(const Node::ptr& arg0, const Node::ptr& arg1) FloorOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -224,11 +232,11 @@ namespace ngraph ...@@ -224,11 +232,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class GreaterCall : public BuiltinCall class GreaterOp : public BuiltinOp
{ {
public: public:
GreaterCall(const Node::ptr& arg0, const Node::ptr& arg1) GreaterOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -236,11 +244,11 @@ namespace ngraph ...@@ -236,11 +244,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class LessCall : public BuiltinCall class LessOp : public BuiltinOp
{ {
public: public:
LessCall(const Node::ptr& arg0, const Node::ptr& arg1) LessOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -248,11 +256,11 @@ namespace ngraph ...@@ -248,11 +256,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class LogCall : public BuiltinCall class LogOp : public BuiltinOp
{ {
public: public:
LogCall(const Node::ptr& arg0) LogOp(const Node::ptr& arg0)
: BuiltinCall({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -260,11 +268,11 @@ namespace ngraph ...@@ -260,11 +268,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MaximumCall : public BuiltinCall class MaximumOp : public BuiltinOp
{ {
public: public:
MaximumCall(const Node::ptr& arg0, const Node::ptr& arg1) MaximumOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -272,11 +280,11 @@ namespace ngraph ...@@ -272,11 +280,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MinimumCall : public BuiltinCall class MinimumOp : public BuiltinOp
{ {
public: public:
MinimumCall(const Node::ptr& arg0, const Node::ptr& arg1) MinimumOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -284,11 +292,11 @@ namespace ngraph ...@@ -284,11 +292,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class MultiplyCall : public BuiltinCall class MultiplyOp : public BuiltinOp
{ {
public: public:
MultiplyCall(const Node::ptr& arg0, const Node::ptr& arg1) MultiplyOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -296,11 +304,11 @@ namespace ngraph ...@@ -296,11 +304,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class NegateCall : public BuiltinCall class NegateOp : public BuiltinOp
{ {
public: public:
NegateCall(const Node::ptr& arg0) NegateOp(const Node::ptr& arg0)
: BuiltinCall({arg0}) : BuiltinOp({arg0})
{ {
} }
...@@ -308,11 +316,11 @@ namespace ngraph ...@@ -308,11 +316,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class PowerCall : public BuiltinCall class PowerOp : public BuiltinOp
{ {
public: public:
PowerCall(const Node::ptr& arg0, const Node::ptr& arg1) PowerOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -320,11 +328,11 @@ namespace ngraph ...@@ -320,11 +328,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class RemainderCall : public BuiltinCall class RemainderOp : public BuiltinOp
{ {
public: public:
RemainderCall(const Node::ptr& arg0, const Node::ptr& arg1) RemainderOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
...@@ -332,11 +340,11 @@ namespace ngraph ...@@ -332,11 +340,11 @@ namespace ngraph
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
class ReshapeCall : public BuiltinCall class ReshapeOp : public BuiltinOp
{ {
public: public:
ReshapeCall(const Node::ptr& arg0, const Shape& shape) ReshapeOp(const Node::ptr& arg0, const Shape& shape)
: BuiltinCall({arg0}) : BuiltinOp({arg0})
, m_shape(shape) , m_shape(shape)
{ {
} }
...@@ -347,11 +355,11 @@ namespace ngraph ...@@ -347,11 +355,11 @@ namespace ngraph
Shape m_shape; Shape m_shape;
}; };
class SubtractCall : public BuiltinCall class SubtractOp : public BuiltinOp
{ {
public: public:
SubtractCall(const Node::ptr& arg0, const Node::ptr& arg1) SubtractOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
......
...@@ -21,12 +21,12 @@ using namespace std; ...@@ -21,12 +21,12 @@ using namespace std;
Node::ptr ngraph::op::abs(const Node::ptr& arg) Node::ptr ngraph::op::abs(const Node::ptr& arg)
{ {
return make_shared<AbsCall>(arg); return make_shared<AbsOp>(arg);
} }
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<AddCall>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
/** /**
...@@ -39,10 +39,10 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor, ...@@ -39,10 +39,10 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape, const Shape& shape,
const vector<size_t>& broadcast_axes) const vector<size_t>& broadcast_axes)
{ {
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes); return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
} }
void BroadcastCall::propagate_types() void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->type(); auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type) if (nullptr == arg_type)
...@@ -70,7 +70,7 @@ void BroadcastCall::propagate_types() ...@@ -70,7 +70,7 @@ void BroadcastCall::propagate_types()
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<CeilingCall>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
// 'concatenate', // 'concatenate',
...@@ -80,16 +80,16 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -80,16 +80,16 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DivideCall>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DotCall>(arg0, arg1); return make_shared<DotOp>(arg0, arg1);
} }
void DotCall::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>(); auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>();
auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>(); auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>();
...@@ -129,37 +129,37 @@ void DotCall::propagate_types() ...@@ -129,37 +129,37 @@ void DotCall::propagate_types()
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{ {
return make_shared<ExponentialCall>(arg0); return make_shared<ExponentialOp>(arg0);
} }
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<FloorCall>(arg0, arg1); return make_shared<FloorOp>(arg0, arg1);
} }
Node::ptr ngraph::op::log(const Node::ptr& arg0) Node::ptr ngraph::op::log(const Node::ptr& arg0)
{ {
return make_shared<LogCall>(arg0); return make_shared<LogOp>(arg0);
} }
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MaximumCall>(arg0, arg1); return make_shared<MaximumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MinimumCall>(arg0, arg1); return make_shared<MinimumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MultiplyCall>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
Node::ptr ngraph::op::negate(const Node::ptr& arg0) Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{ {
return make_shared<NegateCall>(arg0); return make_shared<NegateOp>(arg0);
} }
// 'pad', // 'pad',
...@@ -167,19 +167,19 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0) ...@@ -167,19 +167,19 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<PowerCall>(arg0, arg1); return make_shared<PowerOp>(arg0, arg1);
} }
//'reduce', //'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<RemainderCall>(arg0, arg1); return make_shared<RemainderOp>(arg0, arg1);
} }
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{ {
return make_shared<ReshapeCall>(arg0, shape); return make_shared<ReshapeOp>(arg0, shape);
} }
//'reverse', //'reverse',
...@@ -189,7 +189,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) ...@@ -189,7 +189,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<SubtractCall>(arg0, arg1); return make_shared<SubtractOp>(arg0, arg1);
} }
// 'transpose', // 'transpose',
......
...@@ -60,3 +60,34 @@ TEST(build_graph, as_type) ...@@ -60,3 +60,34 @@ TEST(build_graph, as_type)
TupleType* tp_tp = tp_vt->as<TupleType*>(); TupleType* tp_tp = tp_vt->as<TupleType*>();
ASSERT_EQ(tp_vt.get(), tp_tp); ASSERT_EQ(tp_vt.get(), tp_tp);
} }
// Check Call comparisons
TEST(DISABLED_build_graph, call_comparison)
{
auto fun = make_shared<Function>(3);
fun->parameter(0)->type(element::float32_t, {32, 3});
fun->parameter(1)->type(element::float32_t, {3});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto pattern = make_shared<Function>(1);
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->has_same_op(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
ASSERT_FALSE(pattern_dot->has_same_op(add));
}
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment