Commit 36e36e7f authored by Robert Kimball, committed by GitHub

Merge pull request #78 from NervanaSystems/cyphers/morenames

De-use and cleanup op names.
parents 973b3a0e c7ef13f5
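
In short, this commit moves the op classes into the ngraph::op namespace and drops the free factory helpers (op::parameter, op::broadcast, op::dot, op::function, the node<T> template, ...), so callers construct nodes directly with std::make_shared. A rough before/after sketch, with placeholder shapes, assuming the usual ngraph headers:

    #include "ngraph/ngraph.hpp"
    using namespace std;
    using namespace ngraph;

    // Old style (removed by this commit): free factories and the node<T> helper
    //   auto p  = op::parameter(element::Float32::element_type(), Shape{2, 3});
    //   auto bc = node<BroadcastOp>(p, Shape{4, 2, 3}, AxisSet{0});
    //   auto f  = op::function(bc, {p});

    // New style: op classes live in ngraph::op and are built with make_shared
    auto p = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2, 3});
    auto bc = make_shared<op::Broadcast>(p, Shape{4, 2, 3}, AxisSet{0});
    auto f = make_shared<Function>(bc, op::Parameters{p});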
......@@ -22,14 +22,13 @@
namespace ngraph
{
class Node;
class Parameter;
class ValueType;
template <typename T, typename... A>
std::shared_ptr<T> node(A&&... args)
{
return std::make_shared<T>(args...);
namespace op {
class Parameter;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
}
class ValueType;
/// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
......@@ -42,7 +41,4 @@ namespace ngraph
/// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
}
......@@ -26,28 +26,18 @@ namespace ngraph
{
public:
Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
const std::vector<std::shared_ptr<op::Parameter>>& parameters);
std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<Parameter>> get_parameters() const
const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const
{
return m_parameters;
}
std::string get_name() const { return m_name; }
protected:
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters;
std::string m_name;
std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name;
};
namespace op
{
std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
}
}
......@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const
bool ngraph::Node::is_parameter() const
{
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr;
return dynamic_cast<const ngraph::op::Parameter*>(this) != nullptr;
}
namespace ngraph
......
This diff is collapsed.
......@@ -16,36 +16,32 @@
namespace ngraph
{
class BroadcastOp : public BuiltinOp
namespace op
{
public:
///
/// @param arg The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
BroadcastOp(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: BuiltinOp({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
class Broadcast : public Builtin
{
}
virtual std::string get_op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override;
public:
///
/// @param arg The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
///
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Builtin({arg})
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
}
protected:
Shape m_shape;
AxisSet m_broadcast_axes;
};
virtual std::string get_op_class_name() const override { return "Broadcast"; }
virtual void propagate_types() override;
namespace op
{
std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes);
protected:
Shape m_shape;
AxisSet m_broadcast_axes;
};
}
}
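
As a hedged illustration of the broadcast_axes semantics documented above (the shapes here are made up for the example): the axis positions name the result axes that do not come from the argument, so broadcasting a {3} argument to a {4, 3} result uses AxisSet{0}, leaving the remaining axis to match the argument's shape.

    // Sketch: a {3} vector broadcast to a {4, 3} result; axis 0 is the broadcast axis.
    auto x = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
    auto y = make_shared<op::Broadcast>(x, Shape{4, 3}, AxisSet{0});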
......@@ -18,18 +18,16 @@ namespace ngraph
{
namespace op
{
std::shared_ptr<Node> concatenate(const Nodes& args);
}
class ConcatOp : public BuiltinOp
{
public:
ConcatOp(const Nodes& args)
: BuiltinOp(args)
class Concat : public Builtin
{
}
public:
Concat(const Nodes& args)
: Builtin(args)
{
}
virtual std::string get_op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override;
};
virtual std::string get_op_class_name() const override { return "Concatenate"; }
virtual void propagate_types() override;
};
}
}
......@@ -20,54 +20,57 @@
namespace ngraph
{
// Defines methods to all constant scalars
class ScalarConstantBase : public Node
namespace op
{
protected:
ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
// Defines methods to all constant scalars
class ScalarConstantBase : public Node
{
}
protected:
ScalarConstantBase(const std::shared_ptr<TensorViewType>& type)
: Node({}, type)
{
}
virtual void propagate_types() override;
};
virtual void propagate_types() override;
};
// Implement a constant scalar for each element type.
// The static make method takes a
template <typename T>
class ScalarConstant : public ScalarConstantBase
{
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using type = typename T::type;
ScalarConstant(typename T::type value)
: ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value)
// Implement a constant scalar for each element type.
// The static make method takes a
template <typename T>
class ScalarConstant : public ScalarConstantBase
{
}
public:
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using type = typename T::type;
virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string get_node_id() const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
}
ScalarConstant(typename T::type value)
: ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value)
{
}
virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string get_node_id() const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
}
typename T::type get_value() const { return m_value; }
typename T::type get_value() const { return m_value; }
protected:
typename T::type m_value;
};
protected:
typename T::type m_value;
};
using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
using Float32ScalarConstant = ScalarConstant<element::Float32>;
using Int8ScalarConstant = ScalarConstant<element::Int8>;
using Int32ScalarConstant = ScalarConstant<element::Int32>;
using Int64ScalarConstant = ScalarConstant<element::Int64>;
using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
}
}
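
For reference, a minimal usage sketch of these aliases (the values are arbitrary, and it assumes element::Float32::type is a C++ float): the template parameter is the ngraph element type, and get_value() returns the underlying C++ value.

    // Sketch: rank-0 (shape {}) constants for two element types.
    auto c0 = make_shared<op::Float32ScalarConstant>(3.0f);
    auto c1 = make_shared<op::Int32ScalarConstant>(3);
    float v = c0->get_value(); // 3.0f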
......@@ -16,25 +16,22 @@
namespace ngraph
{
class ConvertOp : public BuiltinOp
namespace op
{
public:
ConvertOp(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg})
, m_element_type(element_type)
class Convert : public Builtin
{
}
virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override;
public:
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: Builtin({arg})
, m_element_type(element_type)
{
}
protected:
const ngraph::element::Type& m_element_type;
};
virtual std::string get_op_class_name() const override { return "Convert"; }
virtual void propagate_types() override;
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
protected:
const ngraph::element::Type& m_element_type;
};
}
}
......@@ -16,22 +16,19 @@
namespace ngraph
{
class DotOp : public BuiltinOp
namespace op
{
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1})
class Dot : public Builtin
{
}
virtual std::string get_op_class_name() const override { return "dot"; }
virtual void propagate_types() override;
};
public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: Builtin({arg0, arg1})
{
}
namespace op
{
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
virtual std::string get_op_class_name() const override { return "Dot"; }
virtual void propagate_types() override;
};
}
}
......@@ -20,41 +20,33 @@
namespace ngraph
{
class Function;
///
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
/// Function creation requires a sequence of parameters.
/// Basic graph operations do not need parameters attached to a function.
///
class Parameter : public Node
{
friend class Function;
protected:
// Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function.
void assign_function(Function* function, size_t index);
public:
Parameter(const std::shared_ptr<ValueType>& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::string get_node_id() const override;
protected:
Function* m_function;
size_t m_index;
};
namespace op
{
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape);
///
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
/// Function creation requires a sequence of parameters.
/// Basic graph operations do not need parameters attached to a function.
///
class Parameter : public Node
{
friend class ngraph::Function;
protected:
// Called by the Function constructor to associate this parameter with the function.
// It is an error to try to associate a parameter with more than one function.
void assign_function(Function* function, size_t index);
public:
Parameter(const std::shared_ptr<ValueType>& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::string get_node_id() const override;
protected:
Function* m_function;
size_t m_index;
};
}
}
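
To make the comment above concrete, a small sketch with placeholder shapes: a Parameter can be built either from an element type plus shape or from an already-constructed value type, and it is bound to a Function only when the Function constructor calls assign_function.

    // Sketch: the two Parameter constructors (shape is a placeholder).
    auto p0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{8});
    auto vt = make_shared<TensorViewType>(element::Float32::element_type(), Shape{8});
    auto p1 = make_shared<op::Parameter>(vt);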
......@@ -18,18 +18,16 @@ namespace ngraph
{
namespace op
{
std::shared_ptr<Node> tuple(const Nodes& args);
}
class TupleOp : public BuiltinOp
{
public:
TupleOp(const Nodes& args)
: BuiltinOp(args)
class Tuple : public Builtin
{
}
public:
Tuple(const Nodes& args)
: Builtin(args)
{
}
virtual std::string get_op_class_name() const override { return "tuple"; }
virtual void propagate_types() override;
};
virtual std::string get_op_class_name() const override { return "Tuple"; }
virtual void propagate_types() override;
};
}
}
......@@ -15,20 +15,9 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
/// @param tensor The tensor view to be broadcast.
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
void Broadcast::propagate_types()
{
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
......
......@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void ConcatOp::propagate_types()
void Concat::propagate_types()
{
throw ngraph_error("NIY");
}
std::shared_ptr<Node> op::concatenate(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<ConcatOp>(args);
}
......@@ -14,6 +14,6 @@
#include "ngraph/ngraph.hpp"
using namespace ngraph;
using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {}
......@@ -17,15 +17,9 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void ConvertOp::propagate_types()
void Convert::propagate_types()
{
throw ngraph_error("NIY");
}
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
......@@ -17,16 +17,9 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction.
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
void Dot::propagate_types()
{
auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
......
......@@ -18,7 +18,7 @@ using namespace std;
using namespace ngraph;
Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters)
const std::vector<std::shared_ptr<op::Parameter>>& parameters)
: m_result(result)
, m_parameters(parameters)
, m_name("Function")
......@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result
parameter->assign_function(this, i++);
}
}
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
......@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const
ss << get_op_class_name() << "_" << m_instance_id;
return ss.str();
}
std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
{
return make_shared<AbsOp>(arg);
}
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<AddOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
// 'convert',
// 'convolution',
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DivideOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
{
return make_shared<ExpOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<FloorOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
{
return make_shared<LogOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MaximumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MinimumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MultiplyOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
{
return make_shared<NegativeOp>(arg0);
}
// 'pad',
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<RemainderOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
{
return make_shared<ReshapeOp>(arg0, shape);
}
//'reverse',
//'rng',
// 'select',
//'slice',
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<SubtractOp>(arg0, arg1);
}
// 'transpose',
// 'while'
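
The removed helpers map one-to-one onto op classes, so call sites change mechanically; a hedged sketch with placeholder operands:

    // Sketch: elementwise ops after the factory removal.
    auto x = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
    auto y = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
    auto z = make_shared<op::Multiply>(make_shared<op::Add>(x, y), y); // was op::multiply(op::add(x, y), y)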
......@@ -17,7 +17,7 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node(value_type)
......@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index)
void Parameter::propagate_types() {}
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type)
{
return make_shared<Parameter>(value_type);
}
shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type,
const Shape& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
std::string ngraph::Parameter::get_node_id() const
std::string ngraph::op::Parameter::get_node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
......
......@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
using namespace ngraph::op;
void TupleOp::propagate_types()
void Tuple::propagate_types()
{
throw ngraph_error("NIY");
}
std::shared_ptr<Node> op::tuple(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<TupleOp>(args);
}
......@@ -23,17 +23,17 @@ using namespace ngraph;
TEST(build_graph, build_simple)
{
// Function with 4 parameters
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0);
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
auto cluster_0 = make_shared<Function>(dot, op::Parameters{arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->get_result(), dot);
}
......@@ -59,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons
TEST(build_graph, node_comparison)
{
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32});
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1);
auto add = op::add(dot, arg2);
auto dot = make_shared<op::Dot>(arg0, arg1);
auto add = make_shared<op::Add>(dot, arg2);
auto parg = node<Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg);
auto parg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = make_shared<op::Dot>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
......@@ -78,20 +78,20 @@ TEST(build_graph, literal)
{
// float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<Float32ScalarConstant>(3.0);
auto float0 = make_shared<op::Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0);
auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int
auto float1 = node<Float32ScalarConstant>(3);
auto float1 = make_shared<op::Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
......
......@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op)
{
auto arg0 = op::parameter(element::Float32::element_type(), {1});
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op());
......@@ -31,9 +31,9 @@ TEST(op, is_op)
TEST(op, is_parameter)
{
auto arg0 = op::parameter(element::Float32::element_type(), {1});
auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0);
auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op());
......
......@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic)
{
vector<shared_ptr<Parameter>> args;
vector<shared_ptr<op::Parameter>> args;
for (int i = 0; i < 10; i++)
{
auto arg = op::parameter(element::Float32::element_type(), {1});
auto arg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg);
args.push_back(arg);
}
auto t0 = op::add(args[0], args[1]);
auto t0 = make_shared<op::Add>(args[0], args[1]);
ASSERT_NE(nullptr, t0);
auto t1 = op::dot(t0, args[2]);
auto t1 = make_shared<op::Dot>(t0, args[2]);
ASSERT_NE(nullptr, t1);
auto t2 = op::multiply(t0, args[3]);
auto t2 = make_shared<op::Multiply>(t0, args[3]);
ASSERT_NE(nullptr, t2);
auto t3 = op::add(t1, args[4]);
auto t3 = make_shared<op::Add>(t1, args[4]);
ASSERT_NE(nullptr, t2);
auto t4 = op::add(t2, args[5]);
auto t4 = make_shared<op::Add>(t2, args[5]);
ASSERT_NE(nullptr, t3);
auto r0 = op::add(t3, t4);
auto r0 = make_shared<op::Add>(t3, t4);
ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, args);
auto f0 = make_shared<Function>(r0, args);
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_arguments().size());
......