Commit b839784e authored by Scott Cyphers's avatar Scott Cyphers

Use value_type instead of type to be consistent with STL

Use direct implementation of is_parameter, is_op
parent 92c4d314
...@@ -66,7 +66,7 @@ namespace ngraph ...@@ -66,7 +66,7 @@ namespace ngraph
public: public:
// This is the C++ type used to hold a value of this element type during compilation // This is the C++ type used to hold a value of this element type during compilation
using ctype = T; using type = T;
// This is a reference to an instance of this element type. // This is a reference to an instance of this element type.
static const U& element_type(){ static const U& element_type(){
static U t; static U t;
......
...@@ -29,16 +29,6 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type) ...@@ -29,16 +29,6 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
} }
} }
bool ngraph::Node::is_op() const
{
return dynamic_cast<const ngraph::Op*>(this) != nullptr;
}
bool ngraph::Node::is_parameter() const
{
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr;
}
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node) std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
{ {
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node); auto op_tmp = dynamic_cast<const ngraph::Op*>(&node);
......
...@@ -67,8 +67,8 @@ namespace ngraph ...@@ -67,8 +67,8 @@ namespace ngraph
return typeid(*this) == typeid(*node.get()); return typeid(*this) == typeid(*node.get());
} }
bool is_op() const; virtual bool is_op() const { return false; };
bool is_parameter() const; virtual bool is_parameter() const { return false; };
size_t instance_id() const { return m_instance_id; } size_t instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&); friend std::ostream& operator<<(std::ostream&, const Node&);
......
...@@ -75,7 +75,8 @@ namespace ngraph ...@@ -75,7 +75,8 @@ namespace ngraph
} }
virtual std::string op_class_name() const = 0; virtual std::string op_class_name() const = 0;
virtual std::string node_id() const; virtual std::string node_id() const override;
virtual bool is_op() const override { return true; }
}; };
/** /**
......
...@@ -41,15 +41,15 @@ namespace ngraph ...@@ -41,15 +41,15 @@ namespace ngraph
// The ngraph element type // The ngraph element type
using element_type = T; using element_type = T;
// The C++ type that holds the element type // The C++ type that holds the element type
using ctype = typename T::ctype; using type = typename T::type;
ScalarConstant(typename T::ctype value) ScalarConstant(typename T::type value)
: ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{})) : ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value) , m_value(value)
{ {
} }
virtual std::string description() const override { return "ConstantScalar"; } virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string node_id() const override virtual std::string node_id() const override
{ {
std::stringstream ss; std::stringstream ss;
...@@ -57,7 +57,7 @@ namespace ngraph ...@@ -57,7 +57,7 @@ namespace ngraph
return ss.str(); return ss.str();
} }
typename T::ctype value() const { return m_value; } typename T::type value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use // Make a constant from any value that can be converted to the C++ type we use
// to represent the values. // to represent the values.
...@@ -68,7 +68,7 @@ namespace ngraph ...@@ -68,7 +68,7 @@ namespace ngraph
} }
protected: protected:
typename T::ctype m_value; typename T::type m_value;
}; };
using FloatScalarConstant = ScalarConstant<element::Float>; using FloatScalarConstant = ScalarConstant<element::Float>;
......
...@@ -41,6 +41,7 @@ namespace ngraph ...@@ -41,6 +41,7 @@ namespace ngraph
std::string description() const override { return "Parameter"; } std::string description() const override { return "Parameter"; }
virtual void propagate_types() override; virtual void propagate_types() override;
virtual std::string node_id() const override; virtual std::string node_id() const override;
virtual bool is_parameter() const override { return true; };
protected: protected:
Function* m_function; Function* m_function;
......
...@@ -112,8 +112,8 @@ namespace ngraph ...@@ -112,8 +112,8 @@ namespace ngraph
class TypedValueMixin class TypedValueMixin
{ {
public: public:
TypedValueMixin(const ValueType::ptr& type = nullptr) TypedValueMixin(const ValueType::ptr& value_type = nullptr)
: m_type(type) : m_value_type(value_type)
{ {
} }
...@@ -121,26 +121,26 @@ namespace ngraph ...@@ -121,26 +121,26 @@ namespace ngraph
** Set the type ** Set the type
** /param type The new type ** /param type The new type
**/ **/
void type(const ValueType::ptr& type) { m_type = type; } void value_type(const ValueType::ptr& value_type) { m_value_type = value_type; }
/** /**
** Set the type to be a tensor view type ** Set the type to be a tensor view type
** /param element_type The type of the tensor elements ** /param element_type The type of the tensor elements
** /param shape The shape of the view ** /param shape The shape of the view
**/ **/
void type(const element::Type& element_type, const Shape& shape) void value_type(const element::Type& element_type, const Shape& shape)
{ {
m_type = std::make_shared<TensorViewType>(element_type, shape); m_value_type = std::make_shared<TensorViewType>(element_type, shape);
} }
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
ValueType::ptr type() { return m_type; } ValueType::ptr value_type() { return m_value_type; }
/** /**
** The type associated with this value. ** The type associated with this value.
**/ **/
const ValueType::ptr type() const { return m_type; } const ValueType::ptr value_type() const { return m_value_type; }
protected: protected:
ValueType::ptr m_type; ValueType::ptr m_value_type;
}; };
} }
...@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor, ...@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
void BroadcastOp::propagate_types() void BroadcastOp::propagate_types()
{ {
auto arg_type = m_arguments.at(0)->type(); auto arg_type = m_arguments.at(0)->value_type();
if (nullptr == arg_type) if (nullptr == arg_type)
{ {
throw ngraph_error("Argument to broadcast is missing type."); throw ngraph_error("Argument to broadcast is missing type.");
...@@ -53,5 +53,5 @@ void BroadcastOp::propagate_types() ...@@ -53,5 +53,5 @@ void BroadcastOp::propagate_types()
} }
// TODO If m_type is already set (by framework), this should verify that the type // TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects. // we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape); m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
} }
...@@ -27,8 +27,8 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1) ...@@ -27,8 +27,8 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void DotOp::propagate_types() void DotOp::propagate_types()
{ {
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type()); auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->value_type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type()); auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type) if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{ {
throw ngraph_error("Arguments to dot must be tensor views"); throw ngraph_error("Arguments to dot must be tensor views");
...@@ -60,5 +60,5 @@ void DotOp::propagate_types() ...@@ -60,5 +60,5 @@ void DotOp::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end()); copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end()); copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end()); copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape); m_value_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
} }
...@@ -93,7 +93,7 @@ TEST(build_graph, literal) ...@@ -93,7 +93,7 @@ TEST(build_graph, literal)
auto float0 = FloatScalarConstant::make(3.0); auto float0 = FloatScalarConstant::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{}); auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0); ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type); ASSERT_EQ(*float0->value_type(), float_scalar_type);
auto d = op::dot(float0, float0); auto d = op::dot(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0); ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0); ASSERT_EQ(d->arguments().at(1), float0);
...@@ -101,13 +101,13 @@ TEST(build_graph, literal) ...@@ -101,13 +101,13 @@ TEST(build_graph, literal)
// float scalar from an int // float scalar from an int
auto float1 = FloatScalarConstant::make(3); auto float1 = FloatScalarConstant::make(3);
ASSERT_EQ(float1->value(), 3); ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type); ASSERT_EQ(*float1->value_type(), float_scalar_type);
auto int32_0 = Int32ScalarConstant::make(3.0); auto int32_0 = Int32ScalarConstant::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{}); auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3); ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type); ASSERT_EQ(*int32_0->value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type); ASSERT_NE(*int32_0->value_type(), float_scalar_type);
} }
// Check argument inverses // Check argument inverses
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment