Commit 2c30e819 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #63 from NervanaSystems/cyphers/mnist

Two shape propagates/checks, bulk of ops.
parents 8ade5867 67163443
......@@ -39,6 +39,7 @@ namespace ngraph
}
bool operator==(const Type& other) const;
bool operator!=(const Type& other) const { return !(*this == other); }
private:
static std::map<std::string, Type> m_element_list;
......@@ -48,7 +49,7 @@ namespace ngraph
const std::string m_cname;
};
const Type float32_t= Type(32, true, true, "float");
const Type float32_t = Type(32, true, true, "float");
const Type int8_t = Type(8, false, true, "int8_t");
const Type int32_t = Type(32, false, true, "int32_t");
const Type int64_t = Type(64, false, true, "int64_t");
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <stdexcept>
namespace ngraph
{
/// Base error for ngraph runtime errors.
struct ngraph_error : std::runtime_error
{
explicit ngraph_error(const std::string& what_arg)
: std::runtime_error(what_arg)
{
}
explicit ngraph_error(const char* what_arg)
: std::runtime_error(what_arg)
{
}
};
}
......@@ -35,6 +35,8 @@ namespace ngraph
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
protected:
Function& m_function;
size_t m_index;
......@@ -59,7 +61,7 @@ namespace ngraph
/**
** A user-defined function.
**/
class Function : public Op
class Function
{
public:
Function(size_t n_parameters);
......@@ -68,7 +70,7 @@ namespace ngraph
Parameter::ptr parameter(size_t i) { return m_parameters[i]; }
std::string name() const override { return m_name; }
std::string name() const { return m_name; }
protected:
std::vector<Parameter::ptr> m_parameters;
......
......@@ -19,6 +19,7 @@
#pragma once
#include "ngraph/element_type.hpp"
#include "ngraph/except.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op.hpp"
......
......@@ -18,6 +18,8 @@
#include <string>
#include <vector>
#include <iostream>
#include "ngraph/type.hpp"
namespace ngraph
......@@ -53,17 +55,22 @@ namespace ngraph
virtual std::string description() const = 0;
/// Propagate types and check arguments for consistency
// virtual void propagate_types() = 0;
virtual void propagate_types() = 0;
const std::vector<Node::ptr> arguments() const { return m_arguments; }
std::vector<Node::ptr> arguments() { return m_arguments; }
const std::vector<Node::ptr>& arguments() const { return m_arguments; }
const std::multiset<Node*> users() const { return m_users; }
std::multiset<Node*> users() { return m_users; }
const std::multiset<Node*>& users() const { return m_users; }
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
/**
** Return true if this has the same implementing class as node. This
** will be used by the pattern matcher when comparing a pattern
** graph against the graph.
**/
bool is_same_op_type(const Node::ptr& node) const { return typeid(*this) == typeid(*node.get()); }
protected:
std::vector<Node::ptr> m_arguments;
std::multiset<Node*> m_users;
......
......@@ -21,82 +21,121 @@
namespace ngraph
{
/**
** Every instance of Op corresponds to a unique defined operation.
**/
class Op
namespace op
{
protected:
virtual ~Op() {}
Node::ptr abs(const Node::ptr& arg);
Node::ptr add(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
public:
virtual std::string name() const = 0;
};
//Node::ptr candidate();
Node::ptr ceiling(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr concatenate();
//Node::ptr constant();
//Node::ptr convert();
//Node::ptr convolution();
Node::ptr divide(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr dot(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr equal(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr exponential(const Node::ptr& arg0);
Node::ptr floor(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr get();
Node::ptr greater(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr less(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr log(const Node::ptr& arg0);
//Node::ptr logical();
Node::ptr maximum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr minimum(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr multiply(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr negate(const Node::ptr& arg0);
//Node::ptr pad();
Node::ptr power(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr reduce();
Node::ptr remainder(const Node::ptr& arg0, const Node::ptr& arg1);
Node::ptr reshape(const Node::ptr& arg0, const Shape& shape);
//Node::ptr reverse();
//Node::ptr rng();
//Node::ptr select();
//Node::ptr slice();
Node::ptr subtract(const Node::ptr& arg0, const Node::ptr& arg1);
//Node::ptr transpose();
//Node::ptr tuple();
//Node::ptr while();
}
/**
** Call nodes are nodes whose value is the result of some operation, the op,
** applied to its arguments. We use the op as a callable to construct the
** call nodes. For calls to user functions, the op will be the user function.
** Op nodes are nodes whose value is the result of some operation
** applied to its arguments. For calls to user functions, the op will
** reference the user function.
**/
class Call : public Node
class Op : public Node
{
public:
std::shared_ptr<Op> op() const { return m_op; }
Call(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& arguments)
Op(const std::vector<Node::ptr>& arguments)
: Node(arguments, nullptr)
, m_op(op)
{
}
};
virtual std::string description() const override { return m_op->name(); }
/**
** A FunctionOp invokes a function on node arguments. In addition to the argument
** we need to preserve the function.
**/
class FunctionOp : public Op
{
virtual std::string description() const override { return "FunctionOp"; }
protected:
std::shared_ptr<Op> m_op;
Node::ptr m_function;
};
/**
** There is exactly one instance of builtin op for each pre-defined operation. These
** are intended to be used when matching calls in different graphs; every FooCall
** will have the same op.
** The is an operation we handle directly, i.e. all type checking, etc.
** are defined in C++ rather than in terms of ngraph operations.
**/
class BuiltinOp : public Op
{
friend class Call;
public:
BuiltinOp(const std::string& name)
: m_name(name)
{
}
virtual std::string description() const override { return "BuiltinOp"; }
/// Name of the builtin op, for debugging and logging.
virtual std::string op_name() const = 0;
public:
std::string name() const override { return m_name; }
// TODO: Implement for each op
virtual void propagate_types() override {}
protected:
std::string m_name;
BuiltinOp(const std::vector<Node::ptr>& args)
: Op(args)
{
}
};
class BuiltinCall : public Call
class AbsOp : public BuiltinOp
{
public:
virtual std::string description() const override { return "BuiltinCall"; }
protected:
BuiltinCall(const std::shared_ptr<Op>& op, const std::vector<Node::ptr>& args)
: Call(op, args)
AbsOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_name() const override { return "abs"; }
//virtual void propagate_types() override;
};
namespace op
class AddOp : public BuiltinOp
{
public:
AddOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
std::shared_ptr<Node> broadcast(const Node::ptr& tensor,
const Shape& shape,
const std::vector<size_t>& broadcast_axes);
}
virtual std::string op_name() const override { return "add"; }
//virtual void propagate_types() override;
};
class BroadcastCall : public BuiltinCall
class BroadcastOp : public BuiltinOp
{
public:
/**
......@@ -105,34 +144,226 @@ namespace ngraph
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
BroadcastCall(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinCall(s_op, {arg})
BroadcastOp(const Node::ptr& arg, const Shape& shape, std::vector<size_t> broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
virtual std::string op_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
Shape m_shape;
std::vector<size_t> m_broadcast_axes;
};
protected:
static std::shared_ptr<BuiltinOp> s_op;
class CeilingOp : public BuiltinOp
{
public:
CeilingOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "ceiling"; }
//virtual void propagate_types() override;
};
namespace op
class DivideOp : public BuiltinOp
{
public:
DivideOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
std::shared_ptr<Node> dot(const Node::ptr& arg0, const Node::ptr& arg1);
}
class DotCall : public BuiltinCall
virtual std::string op_name() const override { return "divide"; }
//virtual void propagate_types() override;
};
class DotOp : public BuiltinOp
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotCall(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinCall(s_op, {arg0, arg1})
DotOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "dot"; }
virtual void propagate_types() override;
};
class EqualOp : public BuiltinOp
{
public:
EqualOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "equal"; }
//virtual void propagate_types() override;
};
class ExponentialOp : public BuiltinOp
{
public:
ExponentialOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_name() const override { return "exp"; }
//virtual void propagate_types() override;
};
class FloorOp : public BuiltinOp
{
public:
FloorOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "floor"; }
//virtual void propagate_types() override;
};
class GreaterOp : public BuiltinOp
{
public:
GreaterOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "greater"; }
//virtual void propagate_types() override;
};
class LessOp : public BuiltinOp
{
public:
LessOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "less"; }
//virtual void propagate_types() override;
};
class LogOp : public BuiltinOp
{
public:
LogOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_name() const override { return "log"; }
//virtual void propagate_types() override;
};
class MaximumOp : public BuiltinOp
{
public:
MaximumOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "max"; }
//virtual void propagate_types() override;
};
class MinimumOp : public BuiltinOp
{
public:
MinimumOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "min"; }
//virtual void propagate_types() override;
};
class MultiplyOp : public BuiltinOp
{
public:
MultiplyOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "multiply"; }
//virtual void propagate_types() override;
};
class NegateOp : public BuiltinOp
{
public:
NegateOp(const Node::ptr& arg0)
: BuiltinOp({arg0})
{
}
virtual std::string op_name() const override { return "negate"; }
//virtual void propagate_types() override;
};
class PowerOp : public BuiltinOp
{
public:
PowerOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "power"; }
//virtual void propagate_types() override;
};
class RemainderOp : public BuiltinOp
{
public:
RemainderOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "remainder"; }
//virtual void propagate_types() override;
};
class ReshapeOp : public BuiltinOp
{
public:
ReshapeOp(const Node::ptr& arg0, const Shape& shape)
: BuiltinOp({arg0})
, m_shape(shape)
{
}
virtual std::string op_name() const override { return "reshape"; }
//virtual void propagate_types() override;
protected:
static std::shared_ptr<BuiltinOp> s_op;
Shape m_shape;
};
class SubtractOp : public BuiltinOp
{
public:
SubtractOp(const Node::ptr& arg0, const Node::ptr& arg1)
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_name() const override { return "subtract"; }
//virtual void propagate_types() override;
};
}
......@@ -32,11 +32,19 @@ namespace ngraph
{
}
Shape(const std::vector<size_t>& sizes)
: m_sizes(sizes)
{
}
/**
** Conversion to a vector of sizes.
**/
operator const std::vector<size_t>&() const { return m_sizes; }
bool operator==(const Shape& shape) const { return m_sizes == shape.m_sizes; }
bool operator!=(const Shape& shape) const { return m_sizes != shape.m_sizes; }
protected:
std::vector<size_t> m_sizes;
};
......
......@@ -22,6 +22,9 @@
namespace ngraph
{
class TensorViewType;
class TupleType;
/**
** ValueType is
** TensorViewType
......@@ -34,6 +37,8 @@ namespace ngraph
** Preferred handle
**/
using ptr = std::shared_ptr<ValueType>;
virtual ~ValueType() {}
};
/**
......@@ -57,6 +62,9 @@ namespace ngraph
{
}
const element::Type& element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; }
protected:
const element::Type& m_element_type;
Shape m_shape;
......
......@@ -24,6 +24,14 @@ Parameter::Parameter(Function& function, size_t index)
{
}
void Parameter::propagate_types()
{
if (m_type == nullptr)
{
throw ngraph_error{"Unitialized parameter"};
}
}
Function::Function(size_t n_parameters)
: m_parameters(n_parameters)
, m_name("Function")
......
......@@ -12,12 +12,22 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "ngraph/ngraph.hpp"
using namespace ngraph;
using namespace std;
std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadcast");
Node::ptr ngraph::op::abs(const Node::ptr& arg)
{
return make_shared<AbsOp>(arg);
}
Node::ptr ngraph::op::add(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<AddOp>(arg0, arg1);
}
/**
** /param arg The tensor view to be broadcast.
......@@ -25,17 +35,163 @@ std::shared_ptr<BuiltinOp> BroadcastCall::s_op = make_shared<BuiltinOp>("broadca
** /param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
** the remaining axes in shape must be the same as the shape of arg.
**/
shared_ptr<Node> ngraph::op::broadcast(const Node::ptr& tensor,
Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
const Shape& shape,
const vector<size_t>& broadcast_axes)
{
return make_shared<BroadcastCall>(tensor, shape, broadcast_axes);
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to broadcast is not a tensor view");
}
vector<size_t> target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
std::shared_ptr<BuiltinOp> DotCall::s_op = make_shared<BuiltinOp>("dot");
Node::ptr ngraph::op::ceiling(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
// 'concatenate',
// 'constant',
// 'convert',
// 'convolution',
Node::ptr ngraph::op::divide(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DivideOp>(arg0, arg1);
}
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
shared_ptr<Node> ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
{
arg1_reduction = arg1_shape.size() - 2;
}
else
{
arg1_reduction = arg1_shape.size() - 1;
}
if (arg0_shape.at(arg0_reduction) != arg1_shape.at(arg1_reduction))
{
throw ngraph_error("Dot reduction axes not compatible");
}
vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
Node::ptr ngraph::op::exponential(const Node::ptr& arg0)
{
return make_shared<ExponentialOp>(arg0);
}
Node::ptr ngraph::op::floor(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<FloorOp>(arg0, arg1);
}
Node::ptr ngraph::op::log(const Node::ptr& arg0)
{
return make_shared<DotCall>(arg0, arg1);
return make_shared<LogOp>(arg0);
}
Node::ptr ngraph::op::maximum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MaximumOp>(arg0, arg1);
}
Node::ptr ngraph::op::minimum(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MinimumOp>(arg0, arg1);
}
Node::ptr ngraph::op::multiply(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<MultiplyOp>(arg0, arg1);
}
Node::ptr ngraph::op::negate(const Node::ptr& arg0)
{
return make_shared<NegateOp>(arg0);
}
// 'pad',
// 'parameter',
Node::ptr ngraph::op::power(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
Node::ptr ngraph::op::remainder(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<RemainderOp>(arg0, arg1);
}
Node::ptr ngraph::op::reshape(const Node::ptr& arg0, const Shape& shape)
{
return make_shared<ReshapeOp>(arg0, shape);
}
//'reverse',
//'rng',
// 'select',
//'slice',
Node::ptr ngraph::op::subtract(const Node::ptr& arg0, const Node::ptr& arg1)
{
return make_shared<SubtractOp>(arg0, arg1);
}
// 'transpose',
//'tuple',
// 'while'
......@@ -19,7 +19,7 @@
using namespace std;
using namespace ngraph;
TEST(DISABLED_graph, build_simple)
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto cluster_0 = make_shared<Function>(4);
......@@ -29,11 +29,9 @@ TEST(DISABLED_graph, build_simple)
cluster_0->parameter(2)->type(element::float32_t, {32, 7});
cluster_0->parameter(3)->type(element::float32_t, {32, 7});
auto arg3 = cluster_0->parameter(3);
// call broadcast op on arg3, broadcasting on axis 0.
auto broadcast_1 = op::broadcast(arg3, {10, 32, 7}, {0});
auto arg2 = cluster_0->parameter(2);
auto arg0 = cluster_0->parameter(0);
// call dot op
auto dot = op::dot(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
......@@ -42,3 +40,52 @@ TEST(DISABLED_graph, build_simple)
ASSERT_EQ(cluster_0->result()->value(), dot);
}
// Check upcasting from ValueType.
TEST(build_graph, as_type)
{
// Check upcasting a ValueType::ptr that is a TensorViewType to a TensorViewType and Tuple.
ValueType::ptr tv_vt = make_shared<TensorViewType>(element::float32_t, Shape{2, 3, 5});
auto tv_tv = dynamic_pointer_cast<TensorViewType>(tv_vt);
ASSERT_EQ(tv_vt, tv_tv);
auto tv_tp = dynamic_pointer_cast<TupleType>(tv_vt);
ASSERT_EQ(nullptr, tv_tp);
// Check upcasting a ValueType::ptr that is a TupleType to a TensorViewType and Tuple.
ValueType::ptr tp_vt = make_shared<TupleType>(vector<ValueType::ptr>{tv_vt, tv_vt});
auto tp_tv = dynamic_pointer_cast<TensorViewType>(tp_vt);
ASSERT_EQ(nullptr, tp_tv);
auto tp_tp = dynamic_pointer_cast<TupleType>(tp_vt);
ASSERT_EQ(tp_vt, tp_tp);
}
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto fun = make_shared<Function>(3);
fun->parameter(0)->type(element::float32_t, {32, 3});
fun->parameter(1)->type(element::float32_t, {3});
fun->parameter(2)->type(element::float32_t, {32});
auto arg0 = fun->parameter(0);
auto arg1 = fun->parameter(1);
auto arg2 = fun->parameter(2);
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto pattern = make_shared<Function>(1);
pattern->parameter(0)->type(element::float32_t, {});
auto parg = pattern->parameter(0);
auto pattern_dot = op::dot(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
ASSERT_FALSE(pattern_dot->is_same_op_type(add));
}
// Check argument inverses
TEST(build_graph, arg_inverse)
{
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment