Commit 7ed0fe7d authored by Scott Cyphers's avatar Scott Cyphers

Switch to get/set

parent c66a7469
...@@ -45,15 +45,15 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node) ...@@ -45,15 +45,15 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node); auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
if (op_tmp) if (op_tmp)
{ {
out << "Op(" << op_tmp->node_id() << ")"; out << "Op(" << op_tmp->get_node_id() << ")";
} }
else if (parameter_tmp) else if (parameter_tmp)
{ {
out << "Parameter(" << parameter_tmp->node_id() << ")"; out << "Parameter(" << parameter_tmp->get_node_id() << ")";
} }
else else
{ {
out << "Node(" << node.node_id() << ")"; out << "Node(" << node.get_node_id() << ")";
} }
return out; return out;
} }
...@@ -48,14 +48,14 @@ namespace ngraph ...@@ -48,14 +48,14 @@ namespace ngraph
/// Propagate types and check arguments for consistency /// Propagate types and check arguments for consistency
virtual void propagate_types() = 0; virtual void propagate_types() = 0;
const Nodes& arguments() const { return m_arguments; } const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; } const std::multiset<Node*>& users() const { return m_users; }
std::string name() const { return m_name; } std::string get_name() const { return m_name; }
void name(const std::string& name) { m_name = name; } void set_name(const std::string& name) { m_name = name; }
virtual std::string node_id() const = 0; virtual std::string get_node_id() const = 0;
/** /**
** Return true if this has the same implementing class as node. This ** Return true if this has the same implementing class as node. This
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
bool is_op() const; bool is_op() const;
bool is_parameter() const; bool is_parameter() const;
size_t instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
protected: protected:
......
...@@ -74,8 +74,8 @@ namespace ngraph ...@@ -74,8 +74,8 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const = 0; virtual std::string get_op_class_name() const = 0;
virtual std::string node_id() const override; virtual std::string get_node_id() const override;
}; };
/** /**
...@@ -116,7 +116,7 @@ namespace ngraph ...@@ -116,7 +116,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "abs"; } virtual std::string get_op_class_name() const override { return "abs"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -127,7 +127,7 @@ namespace ngraph ...@@ -127,7 +127,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1}) : BuiltinOp({arg0, arg1})
{ {
} }
virtual std::string op_class_name() const override { return "add"; } virtual std::string get_op_class_name() const override { return "add"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -139,7 +139,7 @@ namespace ngraph ...@@ -139,7 +139,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "ceiling"; } virtual std::string get_op_class_name() const override { return "ceiling"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -151,7 +151,7 @@ namespace ngraph ...@@ -151,7 +151,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "divide"; } virtual std::string get_op_class_name() const override { return "divide"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -163,7 +163,7 @@ namespace ngraph ...@@ -163,7 +163,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "equal"; } virtual std::string get_op_class_name() const override { return "equal"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -175,7 +175,7 @@ namespace ngraph ...@@ -175,7 +175,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "exp"; } virtual std::string get_op_class_name() const override { return "exp"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -187,7 +187,7 @@ namespace ngraph ...@@ -187,7 +187,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "floor"; } virtual std::string get_op_class_name() const override { return "floor"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -199,7 +199,7 @@ namespace ngraph ...@@ -199,7 +199,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "greater"; } virtual std::string get_op_class_name() const override { return "greater"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -211,7 +211,7 @@ namespace ngraph ...@@ -211,7 +211,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "less"; } virtual std::string get_op_class_name() const override { return "less"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -223,7 +223,7 @@ namespace ngraph ...@@ -223,7 +223,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "log"; } virtual std::string get_op_class_name() const override { return "log"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -235,7 +235,7 @@ namespace ngraph ...@@ -235,7 +235,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "max"; } virtual std::string get_op_class_name() const override { return "max"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -247,7 +247,7 @@ namespace ngraph ...@@ -247,7 +247,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "min"; } virtual std::string get_op_class_name() const override { return "min"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -259,7 +259,7 @@ namespace ngraph ...@@ -259,7 +259,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "multiply"; } virtual std::string get_op_class_name() const override { return "multiply"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -271,7 +271,7 @@ namespace ngraph ...@@ -271,7 +271,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "negative"; } virtual std::string get_op_class_name() const override { return "negative"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -283,7 +283,7 @@ namespace ngraph ...@@ -283,7 +283,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "power"; } virtual std::string get_op_class_name() const override { return "power"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -295,7 +295,7 @@ namespace ngraph ...@@ -295,7 +295,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "remainder"; } virtual std::string get_op_class_name() const override { return "remainder"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
...@@ -308,7 +308,7 @@ namespace ngraph ...@@ -308,7 +308,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "reshape"; } virtual std::string get_op_class_name() const override { return "reshape"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
protected: protected:
Shape m_shape; Shape m_shape;
...@@ -322,7 +322,7 @@ namespace ngraph ...@@ -322,7 +322,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "subtract"; } virtual std::string get_op_class_name() const override { return "subtract"; }
//virtual void propagate_types() override; //virtual void propagate_types() override;
}; };
} }
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "broadcast"; } virtual std::string get_op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "concatenate"; } virtual std::string get_op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
...@@ -50,14 +50,14 @@ namespace ngraph ...@@ -50,14 +50,14 @@ namespace ngraph
} }
virtual std::string description() const override { return "ScalarConstant"; } virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string node_id() const override virtual std::string get_node_id() const override
{ {
std::stringstream ss; std::stringstream ss;
ss << description() << "_" /* << node_id() */; ss << description() << "_" /* << node_id() */;
return ss.str(); return ss.str();
} }
typename T::type value() const { return m_value; } typename T::type get_value() const { return m_value; }
protected: protected:
typename T::type m_value; typename T::type m_value;
......
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "convert"; } virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override; virtual void propagate_types() override;
protected: protected:
const ngraph::element::Type& m_element_type; const ngraph::element::Type& m_element_type;
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "dot"; } virtual std::string get_op_class_name() const override { return "dot"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
......
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string node_id() const override; virtual std::string get_node_id() const override;
protected: protected:
Function* m_function; Function* m_function;
......
...@@ -29,7 +29,7 @@ namespace ngraph ...@@ -29,7 +29,7 @@ namespace ngraph
{ {
} }
virtual std::string op_class_name() const override { return "tuple"; } virtual std::string get_op_class_name() const override { return "tuple"; }
virtual void propagate_types() override; virtual void propagate_types() override;
}; };
} }
...@@ -64,8 +64,8 @@ namespace ngraph ...@@ -64,8 +64,8 @@ namespace ngraph
{ {
} }
const element::Type& element_type() const { return m_element_type; } const element::Type& get_element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; } const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override; virtual bool operator==(const ValueType::ptr& that) const override;
...@@ -97,8 +97,8 @@ namespace ngraph ...@@ -97,8 +97,8 @@ namespace ngraph
{ {
} }
const std::vector<ValueType::ptr> element_types() const { return m_element_types; } const std::vector<ValueType::ptr> get_element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; } std::vector<ValueType::ptr> set_element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override; virtual bool operator==(const ValueType::ptr& that) const override;
...@@ -121,13 +121,13 @@ namespace ngraph ...@@ -121,13 +121,13 @@ namespace ngraph
** Set the type ** Set the type
** /param type The new type ** /param type The new type
**/ **/
void value_type(const ValueType::ptr& value_type) { m_value_type = value_type; } void set_value_type(const ValueType::ptr& value_type) { m_value_type = value_type; }
/** /**
** Set the type to be a tensor view type ** Set the type to be a tensor view type
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
** /param shape The shape of the view ** /param shape The shape of the view
**/ **/
void value_type(const element::Type& element_type, const Shape& shape) void set_value_type(const element::Type& element_type, const Shape& shape)
{ {
m_value_type = std::make_shared<TensorViewType>(element_type, shape); m_value_type = std::make_shared<TensorViewType>(element_type, shape);
} }
...@@ -135,11 +135,11 @@ namespace ngraph ...@@ -135,11 +135,11 @@ namespace ngraph
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
ValueType::ptr value_type() { return m_value_type; } ValueType::ptr get_value_type() { return m_value_type; }
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
const ValueType::ptr value_type() const { return m_value_type; } const ValueType::ptr get_value_type() const { return m_value_type; }
protected: protected:
ValueType::ptr m_value_type; ValueType::ptr m_value_type;
}; };
......
...@@ -33,9 +33,9 @@ void Visualize::add(node_ptr p) ...@@ -33,9 +33,9 @@ void Visualize::add(node_ptr p)
// map<size_t, list<node_ptr>> dependent_nodes; // map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node) traverse_nodes(p, [&](node_ptr node)
{ {
for (auto arg : node->arguments()) for (auto arg : node->get_arguments())
{ {
m_ss << " " << arg->node_id() << " -> " << node->node_id() << ";\n"; m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
} }
}); });
} }
......
...@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor, ...@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
void BroadcastOp::propagate_types() void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->value_type(); auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type) if (nullptr == arg_type)
{ {
throw ngraph_error("Argument to broadcast is missing type."); throw ngraph_error("Argument to broadcast is missing type.");
...@@ -47,11 +47,11 @@ void BroadcastOp::propagate_types() ...@@ -47,11 +47,11 @@ void BroadcastOp::propagate_types()
{ {
target_shape.erase(target_shape.begin() + *i); target_shape.erase(target_shape.begin() + *i);
} }
if (Shape{target_shape} != arg_tensor_view_type->shape()) if (Shape{target_shape} != arg_tensor_view_type->get_shape())
{ {
throw ngraph_error("Broadcast arg, shape, and axes are incompatible"); throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
} }
// TODO If m_type is already set (by framework), this should verify that the type // TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects. // we expect is consistent with the type the framework expects.
m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape); m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape);
} }
...@@ -27,21 +27,21 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -27,21 +27,21 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void DotOp::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->value_type()); auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->value_type()); auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type) if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{ {
throw ngraph_error("Arguments to dot must be tensor views"); throw ngraph_error("Arguments to dot must be tensor views");
} }
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type()) if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{ {
throw ngraph_error("Arguments to dot must have the same element type"); throw ngraph_error("Arguments to dot must have the same element type");
} }
// Use NumPy semantics for now // Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis. // Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape(); vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape(); vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1; size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction; size_t arg1_reduction;
if (arg1_shape.size() > 1) if (arg1_shape.size() > 1)
...@@ -60,5 +60,5 @@ void DotOp::propagate_types() ...@@ -60,5 +60,5 @@ void DotOp::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end()); copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end()); copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end()); copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_value_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape); m_value_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
} }
...@@ -20,10 +20,10 @@ ...@@ -20,10 +20,10 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
std::string ngraph::Op::node_id() const std::string ngraph::Op::get_node_id() const
{ {
stringstream ss; stringstream ss;
ss << op_class_name() << "_" << m_instance_id; ss << get_op_class_name() << "_" << m_instance_id;
return ss.str(); return ss.str();
} }
......
...@@ -56,7 +56,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_ ...@@ -56,7 +56,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape)); return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
} }
std::string ngraph::Parameter::node_id() const std::string ngraph::Parameter::get_node_id() const
{ {
stringstream ss; stringstream ss;
ss << "parameter_" << m_instance_id; ss << "parameter_" << m_instance_id;
......
...@@ -26,11 +26,11 @@ bool TensorViewType::operator==(const ValueType::ptr& that) const ...@@ -26,11 +26,11 @@ bool TensorViewType::operator==(const ValueType::ptr& that) const
{ {
return false; return false;
} }
if (that_tvt->element_type() != m_element_type) if (that_tvt->get_element_type() != m_element_type)
{ {
return false; return false;
} }
if (that_tvt->shape() != m_shape) if (that_tvt->get_shape() != m_shape)
{ {
return false; return false;
} }
...@@ -44,5 +44,5 @@ bool TupleType::operator==(const ValueType::ptr& that) const ...@@ -44,5 +44,5 @@ bool TupleType::operator==(const ValueType::ptr& that) const
{ {
return false; return false;
} }
return that_tvt->element_types() == element_types(); return that_tvt->get_element_types() == get_element_types();
} }
...@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p, ...@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::set<size_t>& instances_seen) std::set<size_t>& instances_seen)
{ {
f(p); f(p);
for (auto arg : p->arguments()) for (auto arg : p->get_arguments())
{ {
if (instances_seen.find(arg->instance_id()) == instances_seen.end()) if (instances_seen.find(arg->get_instance_id()) == instances_seen.end())
{ {
instances_seen.insert(arg->instance_id()); instances_seen.insert(arg->get_instance_id());
traverse_nodes(arg, f, instances_seen); traverse_nodes(arg, f, instances_seen);
} }
} }
......
...@@ -30,8 +30,8 @@ TEST(build_graph, build_simple) ...@@ -30,8 +30,8 @@ TEST(build_graph, build_simple)
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0}); auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0); auto dot = node<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2); ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0); ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3}); auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
...@@ -80,22 +80,22 @@ TEST(build_graph, literal) ...@@ -80,22 +80,22 @@ TEST(build_graph, literal)
//auto float0 = FloatScalarConstant::make(3.0); //auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<FloatScalarConstant>(3.0); auto float0 = node<FloatScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0); ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->value_type(), float_scalar_type); ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0); auto d = node<DotOp>(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0); ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0); ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int // float scalar from an int
auto float1 = node<FloatScalarConstant>(3); auto float1 = node<FloatScalarConstant>(3);
ASSERT_EQ(float1->value(), 3); ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->value_type(), float_scalar_type); ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0); auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3); ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->value_type(), int32_scalar_type); ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->value_type(), float_scalar_type); ASSERT_NE(*int32_0->get_value_type(), float_scalar_type);
} }
// Check argument inverses // Check argument inverses
......
...@@ -42,7 +42,7 @@ TEST(top_sort, basic) ...@@ -42,7 +42,7 @@ TEST(top_sort, basic)
auto f0 = op::function(r0, {arg0, arg1}); auto f0 = op::function(r0, {arg0, arg1});
ASSERT_NE(nullptr, f0); ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size()); ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0); auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl; cout << "op_r0 name " << *r0 << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment