Commit b0e1e076 authored by Scott Cyphers's avatar Scott Cyphers

Call -> Op, failing test for pattern match.

parent db6e3052
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <iostream>
#include "ngraph/type.hpp" #include "ngraph/type.hpp"
namespace ngraph namespace ngraph
...@@ -62,6 +64,15 @@ namespace ngraph ...@@ -62,6 +64,15 @@ namespace ngraph
std::string name() const { return m_name; } std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; } void name(const std::string& name) { m_name = name; }
/**
** Return true if this has the same implementing class as call. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
** TODO: typeids are Node*, doc says they should be the actual classes.
**/
bool has_same_op(const Node::ptr& node) { return typeid(this) == typeid(node.get()); }
protected: protected:
std::vector<Node::ptr> m_arguments; std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
......
This diff is collapsed.
...@@ -21,12 +21,12 @@ using namespace std; ...@@ -21,12 +21,12 @@ using namespace std;
Node::ptr ngraph::op::abs(const Node::ptr& arg) Node::ptr ngraph::op::abs(const Node::ptr& arg)
{ {
return make_shared<AbsCall>(arg); return make_shared<AbsOp>(arg);
} }
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<AddCall>(arg0, arg1); return make_shared<AddOp>(arg0, arg1);
} }
/** /**
...@@ -39,10 +39,10 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor, ...@@ -39,10 +39,10 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape, const Shape& shape,
const vector<size_t>& broadcast_axes) const vector<size_t>& broadcast_axes)
{ {
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes); return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
} }
void BroadcastCall::propagate_types() void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->type(); auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type) if (nullptr == arg_type)
...@@ -70,7 +70,7 @@ void BroadcastCall::propagate_types() ...@@ -70,7 +70,7 @@ void BroadcastCall::propagate_types()
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<CeilingCall>(arg0, arg1); return make_shared<CeilingOp>(arg0, arg1);
} }
// 'concatenate', // 'concatenate',
...@@ -80,16 +80,16 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -80,16 +80,16 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DivideCall>(arg0, arg1); return make_shared<DivideOp>(arg0, arg1);
} }
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<DotCall>(arg0, arg1); return make_shared<DotOp>(arg0, arg1);
} }
void DotCall::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>(); auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>();
auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>(); auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>();
...@@ -129,37 +129,37 @@ void DotCall::propagate_types() ...@@ -129,37 +129,37 @@ void DotCall::propagate_types()
Node::ptr ngraph::op::exponential(const Node::ptr& arg0) Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{ {
return make_shared<ExponentialCall>(arg0); return make_shared<ExponentialOp>(arg0);
} }
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<FloorCall>(arg0, arg1); return make_shared<FloorOp>(arg0, arg1);
} }
Node::ptr ngraph::op::log(const Node::ptr& arg0) Node::ptr ngraph::op::log(const Node::ptr& arg0)
{ {
return make_shared<LogCall>(arg0); return make_shared<LogOp>(arg0);
} }
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MaximumCall>(arg0, arg1); return make_shared<MaximumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MinimumCall>(arg0, arg1); return make_shared<MinimumOp>(arg0, arg1);
} }
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<MultiplyCall>(arg0, arg1); return make_shared<MultiplyOp>(arg0, arg1);
} }
Node::ptr ngraph::op::negate(const Node::ptr& arg0) Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{ {
return make_shared<NegateCall>(arg0); return make_shared<NegateOp>(arg0);
} }
// 'pad', // 'pad',
...@@ -167,19 +167,19 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0) ...@@ -167,19 +167,19 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<PowerCall>(arg0, arg1); return make_shared<PowerOp>(arg0, arg1);
} }
//'reduce', //'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<RemainderCall>(arg0, arg1); return make_shared<RemainderOp>(arg0, arg1);
} }
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{ {
return make_shared<ReshapeCall>(arg0, shape); return make_shared<ReshapeOp>(arg0, shape);
} }
//'reverse', //'reverse',
...@@ -189,7 +189,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape) ...@@ -189,7 +189,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1) Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{ {
return make_shared<SubtractCall>(arg0, arg1); return make_shared<SubtractOp>(arg0, arg1);
} }
// 'transpose', // 'transpose',
......
...@@ -60,3 +60,34 @@ TEST(build_graph, as_type) ...@@ -60,3 +60,34 @@ TEST(build_graph, as_type)
TupleType* tp_tp = tp_vt->as<TupleType*>(); TupleType* tp_tp = tp_vt->as<TupleType*>();
ASSERT_EQ(tp_vt.get(), tp_tp); ASSERT_EQ(tp_vt.get(), tp_tp);
} }
// Check Call comparisons
TEST(DISABLED_build_graph, call_comparison)
{
auto fun = make_shared<Function>(3);
fun->parameter(0)->type(element::float32_t, {32, 3});
fun->parameter(1)->type(element::float32_t, {3});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto pattern = make_shared<Function>(1);
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->has_same_op(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
ASSERT_FALSE(pattern_dot->has_same_op(add));
}
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment