Commit b0e1e076 authored by Scott Cyphers's avatar Scott Cyphers

Call -> Op, failing test for pattern match.

parent db6e3052
......@@ -18,6 +18,8 @@
#include <string>
#include <vector>
#include <iostream>
#include "ngraph/type.hpp"
namespace ngraph
......@@ -62,6 +64,15 @@ namespace ngraph
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
/**
 ** Return true if this node has the same implementing (dynamic) class
 ** as `node`. This is used by the pattern matcher when comparing a
 ** pattern graph against the graph.
 **/
// Dereference before applying typeid: `typeid(this)`/`typeid(node.get())`
// compare the static pointer type Node*, which is always equal. Applying
// typeid to a dereferenced polymorphic object yields its dynamic type.
bool has_same_op(const Node::ptr& node) { return typeid(*this) == typeid(*node); }
protected:
std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users;
......
......@@ -65,32 +65,40 @@ namespace ngraph
}
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
** call nodes. For calls to user functions, the op will be the user function.
** Op nodes are nodes whose value is the result of some operation
** applied to its arguments. For calls to user functions, the op will
** reference the user function.
**/
class Call : public Node
class Op : public Node
{
public:
Call(const std::vector<Node::ptr>& arguments)
Op(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
{
}
};
/**
** Return true if this has the same implementing class as call. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
**/
bool has_same_op(Call& call) { return typeid(this) == typeid(&call); }
virtual std::string description() const override { return "Call"; }
/**
** A FunctionOp invokes a function on node arguments. In addition to the
** arguments, we need to preserve the function.
**/
**/
class FunctionOp : public Op
{
virtual std::string description() const override { return "FunctionOp"; }
protected:
Node::ptr m_function;
};
class BuiltinCall : public Call
/**
** This is an operation we handle directly, i.e. all type checking, etc.
** are defined in C++ rather than in terms of ngraph operations.
**/
class BuiltinOp : public Op
{
public:
virtual std::string description() const override { return "BuiltinCall"; }
virtual std::string description() const override { return "BuiltinOp"; }
/// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
......@@ -98,17 +106,17 @@ namespace ngraph
virtual void propagate_types() override {}
protected:
BuiltinCall(const std::vector<Node::ptr>& args)
: Call(args)
BuiltinOp(const std::vector<Node::ptr>& args)
: Op(args)
{
}
};
class AbsCall : public BuiltinCall
class AbsOp : public BuiltinOp
{
public:
AbsCall(const Node::ptr& arg0)
: BuiltinCall({arg0})
AbsOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
......@@ -116,18 +124,18 @@ namespace ngraph
//virtual void propagate_types() override;
};
class AddCall : public BuiltinCall
class AddOp : public BuiltinOp
{
public:
AddCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
AddOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "add"; }
//virtual void propagate_types() override;
};
class BroadcastCall : public BuiltinCall
class BroadcastOp : public BuiltinOp
{
public:
/**
......@@ -136,8 +144,8 @@ namespace ngraph
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinCall({arg})
BroadcastOp(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
......@@ -151,11 +159,11 @@ namespace ngraph
std::vector<size_t> m_broadcast_axes;
};
class CeilingCall : public BuiltinCall
class CeilingOp : public BuiltinOp
{
public:
CeilingCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
CeilingOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -163,11 +171,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class DivideCall : public BuiltinCall
class DivideOp : public BuiltinOp
{
public:
DivideCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
DivideOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -175,12 +183,12 @@ namespace ngraph
//virtual void propagate_types() override;
};
class DotCall : public BuiltinCall
class DotOp : public BuiltinOp
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -188,11 +196,11 @@ namespace ngraph
virtual void propagate_types() override;
};
class EqualCall : public BuiltinCall
class EqualOp : public BuiltinOp
{
public:
EqualCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
EqualOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -200,11 +208,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class ExponentialCall : public BuiltinCall
class ExponentialOp : public BuiltinOp
{
public:
ExponentialCall(const Node::ptr& arg0)
: BuiltinCall({arg0})
ExponentialOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
......@@ -212,11 +220,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class FloorCall : public BuiltinCall
class FloorOp : public BuiltinOp
{
public:
FloorCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
FloorOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -224,11 +232,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class GreaterCall : public BuiltinCall
class GreaterOp : public BuiltinOp
{
public:
GreaterCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
GreaterOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -236,11 +244,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class LessCall : public BuiltinCall
class LessOp : public BuiltinOp
{
public:
LessCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
LessOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -248,11 +256,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class LogCall : public BuiltinCall
class LogOp : public BuiltinOp
{
public:
LogCall(const Node::ptr& arg0)
: BuiltinCall({arg0})
LogOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
......@@ -260,11 +268,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class MaximumCall : public BuiltinCall
class MaximumOp : public BuiltinOp
{
public:
MaximumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
MaximumOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -272,11 +280,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class MinimumCall : public BuiltinCall
class MinimumOp : public BuiltinOp
{
public:
MinimumCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
MinimumOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -284,11 +292,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class MultiplyCall : public BuiltinCall
class MultiplyOp : public BuiltinOp
{
public:
MultiplyCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
MultiplyOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -296,11 +304,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class NegateCall : public BuiltinCall
class NegateOp : public BuiltinOp
{
public:
NegateCall(const Node::ptr& arg0)
: BuiltinCall({arg0})
NegateOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
......@@ -308,11 +316,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class PowerCall : public BuiltinCall
class PowerOp : public BuiltinOp
{
public:
PowerCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
PowerOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -320,11 +328,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class RemainderCall : public BuiltinCall
class RemainderOp : public BuiltinOp
{
public:
RemainderCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
RemainderOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......@@ -332,11 +340,11 @@ namespace ngraph
//virtual void propagate_types() override;
};
class ReshapeCall : public BuiltinCall
class ReshapeOp : public BuiltinOp
{
public:
ReshapeCall(const Node::ptr& arg0, const Shape& shape)
: BuiltinCall({arg0})
ReshapeOp(const Node::ptr& arg0, const Shape& shape)
: BuiltinOp({arg0})
, m_shape(shape)
{
}
......@@ -347,11 +355,11 @@ namespace ngraph
Shape m_shape;
};
class SubtractCall : public BuiltinCall
class SubtractOp : public BuiltinOp
{
public:
SubtractCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall({arg0, arg1})
SubtractOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
......
......@@ -21,12 +21,12 @@ using namespace std;
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsCall>(arg);
return make_shared<AbsOp>(arg);
}
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<AddCall>(arg0, arg1);
return make_shared<AddOp>(arg0, arg1);
}
/**
......@@ -39,10 +39,10 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes);
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastCall::propagate_types()
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
......@@ -70,7 +70,7 @@ void BroadcastCall::propagate_types()
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<CeilingCall>(arg0, arg1);
return make_shared<CeilingOp>(arg0, arg1);
}
// 'concatenate',
......@@ -80,16 +80,16 @@ Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DivideCall>(arg0, arg1);
return make_shared<DivideOp>(arg0, arg1);
}
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotCall>(arg0, arg1);
return make_shared<DotOp>(arg0, arg1);
}
void DotCall::propagate_types()
void DotOp::propagate_types()
{
auto arg0_tensor_type = m_arguments.at(0)->type()->as<TensorViewType*>();
auto arg1_tensor_type = m_arguments.at(1)->type()->as<TensorViewType*>();
......@@ -129,37 +129,37 @@ void DotCall::propagate_types()
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialCall>(arg0);
return make_shared<ExponentialOp>(arg0);
}
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<FloorCall>(arg0, arg1);
return make_shared<FloorOp>(arg0, arg1);
}
Node::ptr ngraph::op::log(const Node::ptr& arg0)
{
return make_shared<LogCall>(arg0);
return make_shared<LogOp>(arg0);
}
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MaximumCall>(arg0, arg1);
return make_shared<MaximumOp>(arg0, arg1);
}
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MinimumCall>(arg0, arg1);
return make_shared<MinimumOp>(arg0, arg1);
}
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MultiplyCall>(arg0, arg1);
return make_shared<MultiplyOp>(arg0, arg1);
}
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{
return make_shared<NegateCall>(arg0);
return make_shared<NegateOp>(arg0);
}
// 'pad',
......@@ -167,19 +167,19 @@ Node::ptr ngraph::op::negate(const Node::ptr& arg0)
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<PowerCall>(arg0, arg1);
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<RemainderCall>(arg0, arg1);
return make_shared<RemainderOp>(arg0, arg1);
}
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{
return make_shared<ReshapeCall>(arg0, shape);
return make_shared<ReshapeOp>(arg0, shape);
}
//'reverse',
......@@ -189,7 +189,7 @@ Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<SubtractCall>(arg0, arg1);
return make_shared<SubtractOp>(arg0, arg1);
}
// 'transpose',
......
......@@ -60,3 +60,34 @@ TEST(build_graph, as_type)
TupleType* tp_tp = tp_vt->as<TupleType*>();
ASSERT_EQ(tp_vt.get(), tp_tp);
}
// Check Op comparisons
// Disabled: documents the currently-broken has_same_op() comparison.
// Builds a small graph (dot + add) and a one-parameter pattern graph,
// then checks that a pattern dot node matches a graph dot node but not
// an add node.
TEST(DISABLED_build_graph, call_comparison)
{
// Function with three typed parameters: a 32x3 matrix, a 3-vector, a 32-vector.
auto fun = make_shared<Function>(3);
fun->parameter(0)->type(element::float32_t, {32, 3});
fun->parameter(1)->type(element::float32_t, {3});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
// Graph under test: add(dot(arg0, arg1), arg2).
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
// Pattern graph: a dot applied to a single scalar-typed parameter.
auto pattern = make_shared<Function>(1);
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg);
// Same implementing class (DotOp) on both sides -> should match.
ASSERT_TRUE(pattern_dot->has_same_op(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
// DotOp vs AddOp -> should NOT match; fails today because has_same_op
// compares typeid of the pointers (always Node*), not the pointees.
ASSERT_FALSE(pattern_dot->has_same_op(add));
}
// Check argument inverses
TEST(build_graph, arg_inverse)
{
// TODO: placeholder — argument-inverse checks not yet implemented.
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment