Commit 7ed0fe7d authored by Scott Cyphers's avatar Scott Cyphers

Switch to get/set

parent c66a7469
......@@ -45,15 +45,15 @@ std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
auto parameter_tmp = dynamic_cast<const ngraph::Op*>(&node);
if (op_tmp)
{
out << "Op(" << op_tmp->node_id() << ")";
out << "Op(" << op_tmp->get_node_id() << ")";
}
else if (parameter_tmp)
{
out << "Parameter(" << parameter_tmp->node_id() << ")";
out << "Parameter(" << parameter_tmp->get_node_id() << ")";
}
else
{
out << "Node(" << node.node_id() << ")";
out << "Node(" << node.get_node_id() << ")";
}
return out;
}
......@@ -48,14 +48,14 @@ namespace ngraph
/// Propagate types and check arguments for consistency
virtual void propagate_types() = 0;
const Nodes& arguments() const { return m_arguments; }
const Nodes& get_arguments() const { return m_arguments; }
const std::multiset<Node*>& users() const { return m_users; }
std::string name() const { return m_name; }
void name(const std::string& name) { m_name = name; }
std::string get_name() const { return m_name; }
void set_name(const std::string& name) { m_name = name; }
virtual std::string node_id() const = 0;
virtual std::string get_node_id() const = 0;
/**
** Return true if this has the same implementing class as node. This
......@@ -70,7 +70,7 @@ namespace ngraph
bool is_op() const;
bool is_parameter() const;
size_t instance_id() const { return m_instance_id; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
protected:
......
......@@ -74,8 +74,8 @@ namespace ngraph
{
}
virtual std::string op_class_name() const = 0;
virtual std::string node_id() const override;
virtual std::string get_op_class_name() const = 0;
virtual std::string get_node_id() const override;
};
/**
......@@ -116,7 +116,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "abs"; }
virtual std::string get_op_class_name() const override { return "abs"; }
//virtual void propagate_types() override;
};
......@@ -127,7 +127,7 @@ namespace ngraph
: BuiltinOp({arg0, arg1})
{
}
virtual std::string op_class_name() const override { return "add"; }
virtual std::string get_op_class_name() const override { return "add"; }
//virtual void propagate_types() override;
};
......@@ -139,7 +139,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "ceiling"; }
virtual std::string get_op_class_name() const override { return "ceiling"; }
//virtual void propagate_types() override;
};
......@@ -151,7 +151,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "divide"; }
virtual std::string get_op_class_name() const override { return "divide"; }
//virtual void propagate_types() override;
};
......@@ -163,7 +163,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "equal"; }
virtual std::string get_op_class_name() const override { return "equal"; }
//virtual void propagate_types() override;
};
......@@ -175,7 +175,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "exp"; }
virtual std::string get_op_class_name() const override { return "exp"; }
//virtual void propagate_types() override;
};
......@@ -187,7 +187,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "floor"; }
virtual std::string get_op_class_name() const override { return "floor"; }
//virtual void propagate_types() override;
};
......@@ -199,7 +199,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "greater"; }
virtual std::string get_op_class_name() const override { return "greater"; }
//virtual void propagate_types() override;
};
......@@ -211,7 +211,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "less"; }
virtual std::string get_op_class_name() const override { return "less"; }
//virtual void propagate_types() override;
};
......@@ -223,7 +223,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "log"; }
virtual std::string get_op_class_name() const override { return "log"; }
//virtual void propagate_types() override;
};
......@@ -235,7 +235,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "max"; }
virtual std::string get_op_class_name() const override { return "max"; }
//virtual void propagate_types() override;
};
......@@ -247,7 +247,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "min"; }
virtual std::string get_op_class_name() const override { return "min"; }
//virtual void propagate_types() override;
};
......@@ -259,7 +259,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "multiply"; }
virtual std::string get_op_class_name() const override { return "multiply"; }
//virtual void propagate_types() override;
};
......@@ -271,7 +271,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "negative"; }
virtual std::string get_op_class_name() const override { return "negative"; }
//virtual void propagate_types() override;
};
......@@ -283,7 +283,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "power"; }
virtual std::string get_op_class_name() const override { return "power"; }
//virtual void propagate_types() override;
};
......@@ -295,7 +295,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "remainder"; }
virtual std::string get_op_class_name() const override { return "remainder"; }
//virtual void propagate_types() override;
};
......@@ -308,7 +308,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "reshape"; }
virtual std::string get_op_class_name() const override { return "reshape"; }
//virtual void propagate_types() override;
protected:
Shape m_shape;
......@@ -322,7 +322,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "subtract"; }
virtual std::string get_op_class_name() const override { return "subtract"; }
//virtual void propagate_types() override;
};
}
......@@ -32,7 +32,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "broadcast"; }
virtual std::string get_op_class_name() const override { return "broadcast"; }
virtual void propagate_types() override;
protected:
......
......@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "concatenate"; }
virtual std::string get_op_class_name() const override { return "concatenate"; }
virtual void propagate_types() override;
};
}
......@@ -50,14 +50,14 @@ namespace ngraph
}
virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string node_id() const override
virtual std::string get_node_id() const override
{
std::stringstream ss;
ss << description() << "_" /* << node_id() */;
return ss.str();
}
typename T::type value() const { return m_value; }
typename T::type get_value() const { return m_value; }
protected:
typename T::type m_value;
......
......@@ -26,7 +26,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "convert"; }
virtual std::string get_op_class_name() const override { return "convert"; }
virtual void propagate_types() override;
protected:
const ngraph::element::Type& m_element_type;
......
......@@ -25,7 +25,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "dot"; }
virtual std::string get_op_class_name() const override { return "dot"; }
virtual void propagate_types() override;
};
......
......@@ -41,7 +41,7 @@ namespace ngraph
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::string node_id() const override;
virtual std::string get_node_id() const override;
protected:
Function* m_function;
......
......@@ -29,7 +29,7 @@ namespace ngraph
{
}
virtual std::string op_class_name() const override { return "tuple"; }
virtual std::string get_op_class_name() const override { return "tuple"; }
virtual void propagate_types() override;
};
}
......@@ -64,8 +64,8 @@ namespace ngraph
{
}
const element::Type& element_type() const { return m_element_type; }
const Shape& shape() const { return m_shape; }
const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const ValueType::ptr& that) const override;
......@@ -97,8 +97,8 @@ namespace ngraph
{
}
const std::vector<ValueType::ptr> element_types() const { return m_element_types; }
std::vector<ValueType::ptr> element_types() { return m_element_types; }
const std::vector<ValueType::ptr> get_element_types() const { return m_element_types; }
std::vector<ValueType::ptr> set_element_types() { return m_element_types; }
virtual bool operator==(const ValueType::ptr& that) const override;
......@@ -121,13 +121,13 @@ namespace ngraph
** Set the type
** /param type The new type
**/
void value_type(const ValueType::ptr& value_type) { m_value_type = value_type; }
void set_value_type(const ValueType::ptr& value_type) { m_value_type = value_type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void value_type(const element::Type& element_type, const Shape& shape)
void set_value_type(const element::Type& element_type, const Shape& shape)
{
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
......@@ -135,11 +135,11 @@ namespace ngraph
/**
** The type associated with this value.
**/
ValueType::ptr value_type() { return m_value_type; }
ValueType::ptr get_value_type() { return m_value_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr value_type() const { return m_value_type; }
const ValueType::ptr get_value_type() const { return m_value_type; }
protected:
ValueType::ptr m_value_type;
};
......
......@@ -33,9 +33,9 @@ void Visualize::add(node_ptr p)
// map<size_t, list<node_ptr>> dependent_nodes;
traverse_nodes(p, [&](node_ptr node)
{
for (auto arg : node->arguments())
for (auto arg : node->get_arguments())
{
m_ss << " " << arg->node_id() << " -> " << node->node_id() << ";\n";
m_ss << " " << arg->get_node_id() << " -> " << node->get_node_id() << ";\n";
}
});
}
......
......@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->value_type();
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
......@@ -47,11 +47,11 @@ void BroadcastOp::propagate_types()
{
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != arg_tensor_view_type->shape())
if (Shape{target_shape} != arg_tensor_view_type->get_shape())
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape);
}
......@@ -27,21 +27,21 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->value_type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->value_type());
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->get_value_type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->get_value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
}
if (arg0_tensor_type->element_type() != arg1_tensor_type->element_type())
if (arg0_tensor_type->get_element_type() != arg1_tensor_type->get_element_type())
{
throw ngraph_error("Arguments to dot must have the same element type");
}
// Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis.
vector<size_t> arg0_shape = arg0_tensor_type->shape();
vector<size_t> arg1_shape = arg1_tensor_type->shape();
vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1;
size_t arg1_reduction;
if (arg1_shape.size() > 1)
......@@ -60,5 +60,5 @@ void DotOp::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_value_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
m_value_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
}
......@@ -20,10 +20,10 @@
using namespace ngraph;
using namespace std;
std::string ngraph::Op::node_id() const
std::string ngraph::Op::get_node_id() const
{
stringstream ss;
ss << op_class_name() << "_" << m_instance_id;
ss << get_op_class_name() << "_" << m_instance_id;
return ss.str();
}
......
......@@ -56,7 +56,7 @@ shared_ptr<Parameter> ngraph::op::parameter(const ngraph::element::Type element_
return make_shared<Parameter>(make_shared<TensorViewType>(element_type, shape));
}
std::string ngraph::Parameter::node_id() const
std::string ngraph::Parameter::get_node_id() const
{
stringstream ss;
ss << "parameter_" << m_instance_id;
......
......@@ -26,11 +26,11 @@ bool TensorViewType::operator==(const ValueType::ptr& that) const
{
return false;
}
if (that_tvt->element_type() != m_element_type)
if (that_tvt->get_element_type() != m_element_type)
{
return false;
}
if (that_tvt->shape() != m_shape)
if (that_tvt->get_shape() != m_shape)
{
return false;
}
......@@ -44,5 +44,5 @@ bool TupleType::operator==(const ValueType::ptr& that) const
{
return false;
}
return that_tvt->element_types() == element_types();
return that_tvt->get_element_types() == get_element_types();
}
......@@ -136,11 +136,11 @@ static void traverse_nodes(std::shared_ptr<ngraph::Node> p,
std::set<size_t>& instances_seen)
{
f(p);
for (auto arg : p->arguments())
for (auto arg : p->get_arguments())
{
if (instances_seen.find(arg->instance_id()) == instances_seen.end())
if (instances_seen.find(arg->get_instance_id()) == instances_seen.end())
{
instances_seen.insert(arg->instance_id());
instances_seen.insert(arg->get_instance_id());
traverse_nodes(arg, f, instances_seen);
}
}
......
......@@ -30,8 +30,8 @@ TEST(build_graph, build_simple)
auto broadcast_1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto b1 = node<BroadcastOp>(arg3, Shape{10, 32, 7}, AxisSet{0});
auto dot = node<DotOp>(arg2, arg0);
ASSERT_EQ(dot->arguments()[0], arg2);
ASSERT_EQ(dot->arguments()[1], arg0);
ASSERT_EQ(dot->get_arguments()[0], arg2);
ASSERT_EQ(dot->get_arguments()[1], arg0);
auto cluster_0 = op::function(dot, {arg0, arg1, arg2, arg3});
......@@ -80,22 +80,22 @@ TEST(build_graph, literal)
//auto float0 = FloatScalarConstant::make(3.0);
auto float0 = node<FloatScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->value_type(), float_scalar_type);
ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
auto d = node<DotOp>(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
// float scalar from an int
auto float1 = node<FloatScalarConstant>(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->value_type(), float_scalar_type);
ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
auto int32_0 = node<Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->value_type(), float_scalar_type);
ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->get_value_type(), float_scalar_type);
}
// Check argument inverses
......
......@@ -42,7 +42,7 @@ TEST(top_sort, basic)
auto f0 = op::function(r0, {arg0, arg1});
ASSERT_NE(nullptr, f0);
ASSERT_EQ(2, r0->arguments().size());
ASSERT_EQ(2, r0->get_arguments().size());
auto op_r0 = static_pointer_cast<Op>(r0);
cout << "op_r0 name " << *r0 << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment