Commit b0e1e076 authored by Scott Cyphers

Call -> Op, failing test for pattern match.

parent db6e3052
@@ -18,6 +18,8 @@
#include <string>
#include <vector>
#include <iostream>
#include "ngraph/type.hpp"
namespace ngraph
@@ -62,6 +64,15 @@ namespace ngraph
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
/**
** Return true if this node has the same implementing class as `node`. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
** TODO: the typeids compare as Node*; the documentation says they should be the actual derived classes.
**/
bool has_same_op(const Node::ptr& node) { return typeid(this) == typeid(node.get()); }
protected:
std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users;
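A note on the TODO above, not part of the commit: applied to a pointer value, typeid yields the pointer's static type, so typeid(this) == typeid(node.get()) compares Node* against Node* and is always true. Comparing the dereferenced objects would use the dynamic (most-derived) type instead, provided Node is polymorphic. A minimal standalone C++ sketch, using hypothetical AddOp/DotOp stand-ins rather than the real ngraph classes:

#include <iostream>
#include <memory>
#include <typeinfo>

struct Node { virtual ~Node() = default; };  // polymorphic base, so typeid on *obj is dynamic
struct AddOp : Node {};
struct DotOp : Node {};

int main()
{
    std::shared_ptr<Node> a = std::make_shared<AddOp>();
    std::shared_ptr<Node> d = std::make_shared<DotOp>();

    // Pointers: typeid sees the static type Node* on both sides, so this is always true.
    std::cout << (typeid(a.get()) == typeid(d.get())) << "\n";  // prints 1

    // Dereferenced polymorphic objects: typeid sees AddOp vs DotOp.
    std::cout << (typeid(*a) == typeid(*d)) << "\n";            // prints 0
    return 0;
}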
@@ -21,12 +21,12 @@ using namespace std;
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsCall>(arg);
return make_shared<AbsOp>(arg);
}
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<AddCall>(arg0, arg1);
return make_shared<AddOp>(arg0, arg1);
}
/**
@@ -39,10 +39,10 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes);
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastCall::propagate_types()
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
@@ -70,7 +70,7 @@ void BroadcastCall::propagate_types()
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<CeilingCall>(arg0, arg1);
return make_shared<CeilingOp>(arg0, arg1);
}
// 'concatenate',
@@ -80,16 +80,16 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DivideCall>(arg0, arg1);
return make_shared<DivideOp>(arg0, arg1);
}
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotCall>(arg0, arg1);
return make_shared<DotOp>(arg0, arg1);
}
void DotCall::propagate_types()
void DotOp::propagate_types()
{
auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>();
auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>();
@@ -129,37 +129,37 @@ void DotCall::propagate_types()
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialCall>(arg0);
return make_shared<ExponentialOp>(arg0);
}
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<FloorCall>(arg0, arg1);
return make_shared<FloorOp>(arg0, arg1);
}
Node::ptr ngraph::op::log(const Node::ptr& arg0)
{
return make_shared<LogCall>(arg0);
return make_shared<LogOp>(arg0);
}
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MaximumCall>(arg0, arg1);
return make_shared<MaximumOp>(arg0, arg1);
}
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MinimumCall>(arg0, arg1);
return make_shared<MinimumOp>(arg0, arg1);
}
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MultiplyCall>(arg0, arg1);
return make_shared<MultiplyOp>(arg0, arg1);
}
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{
return make_shared<NegateCall>(arg0);
return make_shared<NegateOp>(arg0);
}
// 'pad',
@@ -167,19 +167,19 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<PowerCall>(arg0, arg1);
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<RemainderCall>(arg0, arg1);
return make_shared<RemainderOp>(arg0, arg1);
}
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{
return make_shared<ReshapeCall>(arg0, shape);
return make_shared<ReshapeOp>(arg0, shape);
}
//'reverse',
@@ -189,7 +189,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<SubtractCall>(arg0, arg1);
return make_shared<SubtractOp>(arg0, arg1);
}
// 'transpose',
@@ -60,3 +60,34 @@ TEST(build_graph, as_type)
TupleType* tp_tp = tp_vt->as<TupleType*>();
ASSERT_EQ(tp_vt.get(), tp_tp);
}
// Check Call comparisons
TEST(DISABLED_build_graph, call_comparison)
{
auto fun = make_shared<Function>(3);
fun->parameter(0)->type(element::float32_t, {32, 3});
fun->parameter(1)->type(element::float32_t, {3});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto pattern = make_shared<Function>(1);
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->has_same_op(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
ASSERT_FALSE(pattern_dot->has_same_op(add));
}
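For reference, a hypothetical one-line revision of has_same_op, not part of this commit, that would compare dynamic types and is the kind of change the TODO points at; it assumes Node has at least one virtual member so typeid dispatches at run time:

// Compare the most-derived classes of *this and *node rather than the
// static types of the pointers (hypothetical revision, not the commit's code).
bool has_same_op(const Node::ptr& node) { return typeid(*this) == typeid(*node); }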
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}