Commit 36e36e7f authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #78 from NervanaSystems/cyphers/morenames

De-use and cleanup op names.
parents 973b3a0e c7ef13f5
...@@ -22,14 +22,13 @@ ...@@ -22,14 +22,13 @@
namespace ngraph namespace ngraph
{ {
class Node; class Node;
namespace op {
class Parameter; class Parameter;
class ValueType;
template <typename T, typename... A> /// A list of parameters
std::shared_ptr<T> node(A&&... args) using Parameters = std::vector<std::shared_ptr<Parameter>>;
{
return std::make_shared<T>(args...);
} }
class ValueType;
/// Zero or more value types /// Zero or more value types
using ValueTypes = std::vector<std::shared_ptr<ValueType>>; using ValueTypes = std::vector<std::shared_ptr<ValueType>>;
...@@ -42,7 +41,4 @@ namespace ngraph ...@@ -42,7 +41,4 @@ namespace ngraph
/// A set of axes, for example, reduction axes /// A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>; using AxisSet = std::set<size_t>;
/// A list of parameters
using Parameters = std::vector<std::shared_ptr<Parameter>>;
} }
...@@ -26,10 +26,10 @@ namespace ngraph ...@@ -26,10 +26,10 @@ namespace ngraph
{ {
public: public:
Function(const std::shared_ptr<Node>& result, Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters); const std::vector<std::shared_ptr<op::Parameter>>& parameters);
std::shared_ptr<Node> get_result() { return m_result; } std::shared_ptr<Node> get_result() { return m_result; }
const std::vector<std::shared_ptr<Parameter>> get_parameters() const const std::vector<std::shared_ptr<op::Parameter>> get_parameters() const
{ {
return m_parameters; return m_parameters;
} }
...@@ -37,17 +37,7 @@ namespace ngraph ...@@ -37,17 +37,7 @@ namespace ngraph
protected: protected:
std::shared_ptr<Node> m_result; std::shared_ptr<Node> m_result;
std::vector<std::shared_ptr<ngraph::Parameter>> m_parameters; std::vector<std::shared_ptr<ngraph::op::Parameter>> m_parameters;
std::string m_name; std::string m_name;
}; };
namespace op
{
std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result,
const std::initializer_list<std::shared_ptr<Parameter>>& parameters);
std::shared_ptr<Function>
function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<Parameter>>& parameters);
}
} }
...@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const ...@@ -37,7 +37,7 @@ bool ngraph::Node::is_op() const
bool ngraph::Node::is_parameter() const bool ngraph::Node::is_parameter() const
{ {
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr; return dynamic_cast<const ngraph::op::Parameter*>(this) != nullptr;
} }
namespace ngraph namespace ngraph
......
This diff is collapsed.
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
namespace ngraph namespace ngraph
{ {
class BroadcastOp : public BuiltinOp namespace op
{
class Broadcast : public Builtin
{ {
public: public:
/// ///
...@@ -25,27 +27,21 @@ namespace ngraph ...@@ -25,27 +27,21 @@ namespace ngraph
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast. /// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg. /// the remaining axes in shape must be the same as the shape of arg.
/// ///
BroadcastOp(const std::shared_ptr<Node>& arg, Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape, const Shape& shape,
const AxisSet& broadcast_axes) const AxisSet& broadcast_axes)
: BuiltinOp({arg}) : Builtin({arg})
, m_shape(shape) , m_shape(shape)
, m_broadcast_axes(broadcast_axes) , m_broadcast_axes(broadcast_axes)
{ {
} }
virtual std::string get_op_class_name() const override { return "broadcast"; } virtual std::string get_op_class_name() const override { return "Broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
AxisSet m_broadcast_axes; AxisSet m_broadcast_axes;
}; };
namespace op
{
std::shared_ptr<Node> broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes);
} }
} }
...@@ -18,18 +18,16 @@ namespace ngraph ...@@ -18,18 +18,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
std::shared_ptr<Node> concatenate(const Nodes& args); class Concat : public Builtin
}
class ConcatOp : public BuiltinOp
{ {
public: public:
ConcatOp(const Nodes& args) Concat(const Nodes& args)
: BuiltinOp(args) : Builtin(args)
{ {
} }
virtual std::string get_op_class_name() const override { return "concatenate"; } virtual std::string get_op_class_name() const override { return "Concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
}
} }
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
namespace ngraph namespace ngraph
{ {
namespace op
{
// Defines methods to all constant scalars // Defines methods to all constant scalars
class ScalarConstantBase : public Node class ScalarConstantBase : public Node
{ {
...@@ -70,4 +72,5 @@ namespace ngraph ...@@ -70,4 +72,5 @@ namespace ngraph
using UInt8ScalarConstant = ScalarConstant<element::UInt8>; using UInt8ScalarConstant = ScalarConstant<element::UInt8>;
using UInt32ScalarConstant = ScalarConstant<element::UInt32>; using UInt32ScalarConstant = ScalarConstant<element::UInt32>;
using UInt64ScalarConstant = ScalarConstant<element::UInt64>; using UInt64ScalarConstant = ScalarConstant<element::UInt64>;
}
} }
...@@ -16,25 +16,22 @@ ...@@ -16,25 +16,22 @@
namespace ngraph namespace ngraph
{ {
class ConvertOp : public BuiltinOp namespace op
{
class Convert : public Builtin
{ {
public: public:
ConvertOp(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type) Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type)
: BuiltinOp({arg}) : Builtin({arg})
, m_element_type(element_type) , m_element_type(element_type)
{ {
} }
virtual std::string get_op_class_name() const override { return "convert"; } virtual std::string get_op_class_name() const override { return "Convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
}; };
namespace op
{
std::shared_ptr<ngraph::ConvertOp> convert(const std::shared_ptr<Node>& arg,
const ngraph::element::Type& element_type);
} }
} }
...@@ -16,22 +16,19 @@ ...@@ -16,22 +16,19 @@
namespace ngraph namespace ngraph
{ {
class DotOp : public BuiltinOp namespace op
{
class Dot : public Builtin
{ {
public: public:
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. /// TODO: Semantics of arg0 and arg1 axes wrt reduction.
DotOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1) Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
: BuiltinOp({arg0, arg1}) : Builtin({arg0, arg1})
{ {
} }
virtual std::string get_op_class_name() const override { return "dot"; } virtual std::string get_op_class_name() const override { return "Dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
namespace op
{
std::shared_ptr<Node> dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
} }
} }
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ngraph namespace ngraph
{ {
class Function; class Function;
namespace op
{
/// ///
/// Parameters are nodes that represent the arguments that will be passed to user-defined functions. /// Parameters are nodes that represent the arguments that will be passed to user-defined functions.
/// Function creation requires a sequence of parameters. /// Function creation requires a sequence of parameters.
...@@ -28,7 +29,7 @@ namespace ngraph ...@@ -28,7 +29,7 @@ namespace ngraph
/// ///
class Parameter : public Node class Parameter : public Node
{ {
friend class Function; friend class ngraph::Function;
protected: protected:
// Called by the Function constructor to associate this parameter with the function. // Called by the Function constructor to associate this parameter with the function.
...@@ -47,14 +48,5 @@ namespace ngraph ...@@ -47,14 +48,5 @@ namespace ngraph
Function* m_function; Function* m_function;
size_t m_index; size_t m_index;
}; };
namespace op
{
/// Factory for frameworks
std::shared_ptr<ngraph::Parameter>
parameter(const std::shared_ptr<ValueType>& value_type = nullptr);
/// Convenience factory for tests
std::shared_ptr<ngraph::Parameter> parameter(const element::Type element_type,
const Shape& shape);
} }
} }
...@@ -18,18 +18,16 @@ namespace ngraph ...@@ -18,18 +18,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
std::shared_ptr<Node> tuple(const Nodes& args); class Tuple : public Builtin
}
class TupleOp : public BuiltinOp
{ {
public: public:
TupleOp(const Nodes& args) Tuple(const Nodes& args)
: BuiltinOp(args) : Builtin(args)
{ {
} }
virtual std::string get_op_class_name() const override { return "tuple"; } virtual std::string get_op_class_name() const override { return "Tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
}
} }
...@@ -15,20 +15,9 @@ ...@@ -15,20 +15,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
/// @param tensor The tensor view to be broadcast. void Broadcast::propagate_types()
/// @param shape The shape of the result
/// @param broadcast_axes The axis positions (0-based) in the result that are being broadcast.
/// the remaining axes in shape must be the same as the shape of arg.
std::shared_ptr<Node> ngraph::op::broadcast(const std::shared_ptr<Node>& tensor,
const Shape& shape,
AxisSet&& broadcast_axes)
{
return make_shared<BroadcastOp>(tensor, shape, broadcast_axes);
}
void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->get_value_type(); auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type) if (nullptr == arg_type)
......
...@@ -17,14 +17,9 @@ ...@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
void ConcatOp::propagate_types() void Concat::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
std::shared_ptr<Node> op::concatenate(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<ConcatOp>(args);
}
...@@ -14,6 +14,6 @@ ...@@ -14,6 +14,6 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace ngraph; using namespace ngraph::op;
void ScalarConstantBase::propagate_types() {} void ScalarConstantBase::propagate_types() {}
...@@ -17,15 +17,9 @@ ...@@ -17,15 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
void ConvertOp::propagate_types() void Convert::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
shared_ptr<ConvertOp> op::convert(const std::shared_ptr<Node>& arg,
const element::Type& element_type)
{
return make_shared<ConvertOp>(arg, element_type);
}
...@@ -17,16 +17,9 @@ ...@@ -17,16 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
/// TODO: Semantics of arg0 and arg1 axes wrt reduction. void Dot::propagate_types()
std::shared_ptr<Node> ngraph::op::dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DotOp>(arg0, arg1);
}
void DotOp::propagate_types()
{ {
auto arg0_tensor_type = auto arg0_tensor_type =
dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type()); dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
......
...@@ -18,7 +18,7 @@ using namespace std; ...@@ -18,7 +18,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
Function::Function(const std::shared_ptr<Node>& result, Function::Function(const std::shared_ptr<Node>& result,
const std::vector<std::shared_ptr<ngraph::Parameter>>& parameters) const std::vector<std::shared_ptr<op::Parameter>>& parameters)
: m_result(result) : m_result(result)
, m_parameters(parameters) , m_parameters(parameters)
, m_name("Function") , m_name("Function")
...@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result ...@@ -29,15 +29,3 @@ Function::Function(const std::shared_ptr<Node>& result
parameter->assign_function(this, i++); parameter->assign_function(this, i++);
} }
} }
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const initializer_list<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
shared_ptr<Function> ngraph::op::function(const std::shared_ptr<Node>& result,
const vector<shared_ptr<Parameter>>& parameters)
{
return make_shared<Function>(result, parameters);
}
...@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const ...@@ -26,103 +26,3 @@ std::string ngraph::Op::get_node_id() const
ss << get_op_class_name() << "_" << m_instance_id; ss << get_op_class_name() << "_" << m_instance_id;
return ss.str(); return ss.str();
} }
std::shared_ptr<Node> ngraph::op::abs(const std::shared_ptr<Node>& arg)
{
return make_shared<AbsOp>(arg);
}
std::shared_ptr<Node> ngraph::op::add(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<AddOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::ceiling(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<CeilingOp>(arg0, arg1);
}
// 'convert',
// 'convolution',
std::shared_ptr<Node> ngraph::op::divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<DivideOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::exp(const std::shared_ptr<Node>& arg0)
{
return make_shared<ExpOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::floor(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<FloorOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::log(const std::shared_ptr<Node>& arg0)
{
return make_shared<LogOp>(arg0);
}
std::shared_ptr<Node> ngraph::op::maximum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MaximumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::minimum(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MinimumOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::multiply(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<MultiplyOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::negative(const std::shared_ptr<Node>& arg0)
{
return make_shared<NegativeOp>(arg0);
}
// 'pad',
std::shared_ptr<Node> ngraph::op::power(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<PowerOp>(arg0, arg1);
}
//'reduce',
std::shared_ptr<Node> ngraph::op::remainder(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<RemainderOp>(arg0, arg1);
}
std::shared_ptr<Node> ngraph::op::reshape(const std::shared_ptr<Node>& arg0, const Shape& shape)
{
return make_shared<ReshapeOp>(arg0, shape);
}
//'reverse',
//'rng',
// 'select',
//'slice',
std::shared_ptr<Node> ngraph::op::subtract(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
{
return make_shared<SubtractOp>(arg0, arg1);
}
// 'transpose',
// 'while'
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
Parameter::Parameter(const std::shared_ptr<ValueType>& value_type) Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
: Node(value_type) : Node(value_type)
...@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index) ...@@ -43,18 +43,7 @@ void Parameter::assign_function(Function* function, size_t index)
void Parameter::propagate_types() {} void Parameter::propagate_types() {}
shared_ptr<Parameter> ngraph::op::parameter(const std::shared_ptr<ValueType>& value_type) std::string ngraph::op::Parameter::get_node_id() const
{
return make_shared<Parameter>(value_type);
}
shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_type,
const Shape& shape)
{
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
std::string ngraph::Parameter::get_node_id() const
{ {
stringstream ss; stringstream ss;
ss << "parameter_" << m_instance_id; ss << "parameter_" << m_instance_id;
......
...@@ -17,14 +17,9 @@ ...@@ -17,14 +17,9 @@
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph::op;
void TupleOp::propagate_types() void Tuple::propagate_types()
{ {
throw ngraph_error("NIY"); throw ngraph_error("NIY");
} }
std::shared_ptr<Node> op::tuple(const std::vector<std::shared_ptr<Node>>& args)
{
return make_shared<TupleOp>(args);
}
...@@ -23,17 +23,17 @@ using namespace ngraph; ...@@ -23,17 +23,17 @@ using namespace ngraph;
TEST(build_graph, build_simple) TEST(build_graph, build_simple)
{ {
// Function with 4 parameters // Function with 4 parameters
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{7, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{7, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3}); auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32, 7}); auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto arg3 = node<Parameter>(element::Float32::element_type(), Shape{32, 7}); auto arg3 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 7});
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto broadcast_1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto b1 = make_shared<op::Broadcast>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0); auto dot = make_shared<op::Dot>(arg2, arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3}); auto cluster_0 = make_shared<Function>(dot, op::Parameters{arg0, arg1, arg2, arg3});
ASSERT_EQ(cluster_0->get_result(), dot); ASSERT_EQ(cluster_0->get_result(), dot);
} }
...@@ -59,15 +59,15 @@ TEST(build_graph, as_type) ...@@ -59,15 +59,15 @@ TEST(build_graph, as_type)
// Check node comparisons // Check node comparisons
TEST(build_graph, node_comparison) TEST(build_graph, node_comparison)
{ {
auto arg0 = node<Parameter>(element::Float32::element_type(), Shape{32, 3}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32, 3});
auto arg1 = node<Parameter>(element::Float32::element_type(), Shape{3}); auto arg1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3});
auto arg2 = node<Parameter>(element::Float32::element_type(), Shape{32}); auto arg2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{32});
auto dot = op::dot(arg0, arg1); auto dot = make_shared<op::Dot>(arg0, arg1);
auto add = op::add(dot, arg2); auto add = make_shared<op::Add>(dot, arg2);
auto parg = node<Parameter>(element::Float32::element_type(), Shape{}); auto parg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = node<DotOp>(parg, parg); auto pattern_dot = make_shared<op::Dot>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot)); ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented. // TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong. // Need to figure out what's wrong.
...@@ -78,20 +78,20 @@ TEST(build_graph, literal) ...@@ -78,20 +78,20 @@ TEST(build_graph, literal)
{ {
// float scalar from a float // float scalar from a float
//auto float0 = FloatScalarConstant::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<Float32ScalarConstant>(3.0); auto float0 = make_shared<op::Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0); ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type); ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0); auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0); ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0); ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = node<Float32ScalarConstant>(3); auto float1 = make_shared<op::Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type); ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0); auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3); ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
......
...@@ -23,7 +23,7 @@ using namespace ngraph; ...@@ -23,7 +23,7 @@ using namespace ngraph;
TEST(op, is_op) TEST(op, is_op)
{ {
auto arg0 = op::parameter(element::Float32::element_type(), {1}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter()); EXPECT_TRUE(arg0->is_parameter());
EXPECT_FALSE(arg0->is_op()); EXPECT_FALSE(arg0->is_op());
...@@ -31,9 +31,9 @@ TEST(op, is_op) ...@@ -31,9 +31,9 @@ TEST(op, is_op)
TEST(op, is_parameter) TEST(op, is_parameter)
{ {
auto arg0 = op::parameter(element::Float32::element_type(), {1}); auto arg0 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg0); ASSERT_NE(nullptr, arg0);
auto t0 = op::add(arg0, arg0); auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter()); EXPECT_FALSE(t0->is_parameter());
EXPECT_TRUE(t0->is_op()); EXPECT_TRUE(t0->is_op());
......
...@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes) ...@@ -58,30 +58,30 @@ static bool validate_list(const vector<Node*>& nodes)
TEST(topological_sort, basic) TEST(topological_sort, basic)
{ {
vector<shared_ptr<Parameter>> args; vector<shared_ptr<op::Parameter>> args;
for (int i = 0; i < 10; i++) for (int i = 0; i < 10; i++)
{ {
auto arg = op::parameter(element::Float32::element_type(), {1}); auto arg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1});
ASSERT_NE(nullptr, arg); ASSERT_NE(nullptr, arg);
args.push_back(arg); args.push_back(arg);
} }
auto t0 = op::add(args[0], args[1]); auto t0 = make_shared<op::Add>(args[0], args[1]);
ASSERT_NE(nullptr, t0); ASSERT_NE(nullptr, t0);
auto t1 = op::dot(t0, args[2]); auto t1 = make_shared<op::Dot>(t0, args[2]);
ASSERT_NE(nullptr, t1); ASSERT_NE(nullptr, t1);
auto t2 = op::multiply(t0, args[3]); auto t2 = make_shared<op::Multiply>(t0, args[3]);
ASSERT_NE(nullptr, t2); ASSERT_NE(nullptr, t2);
auto t3 = op::add(t1, args[4]); auto t3 = make_shared<op::Add>(t1, args[4]);
ASSERT_NE(nullptr, t2); ASSERT_NE(nullptr, t2);
auto t4 = op::add(t2, args[5]); auto t4 = make_shared<op::Add>(t2, args[5]);
ASSERT_NE(nullptr, t3); ASSERT_NE(nullptr, t3);
auto r0 = op::add(t3, t4); auto r0 = make_shared<op::Add>(t3, t4);
ASSERT_NE(nullptr, r0); ASSERT_NE(nullptr, r0);
auto f0 = op::function(r0, args); auto f0 = make_shared<Function>(r0, args);
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->get_arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment