Commit b839784e authored by Scott Cyphers

Use value_type instead of type to be consistent with STL

Use direct implementation of is_parameter, is_op
parent 92c4d314
......@@ -66,7 +66,7 @@ namespace ngraph
public:
// This is the C++ type used to hold a value of this element type during compilation
using ctype = T;
using type = T;
// This is a reference to an instance of this element type.
static const U& element_type(){
static U t;
......
......@@ -29,16 +29,6 @@ ngraph::Node::Node(const std::vector<Node::ptr>& arguments, ValueType::ptr type)
}
}
bool ngraph::Node::is_op() const
{
return dynamic_cast<const ngraph::Op*>(this) != nullptr;
}
bool ngraph::Node::is_parameter() const
{
return dynamic_cast<const ngraph::Parameter*>(this) != nullptr;
}
std::ostream& ngraph::operator<<(std::ostream& out, const ngraph::Node& node)
{
auto op_tmp = dynamic_cast<const ngraph::Op*>(&node);
......
......@@ -67,8 +67,8 @@ namespace ngraph
return typeid(*this) == typeid(*node.get());
}
bool is_op() const;
bool is_parameter() const;
virtual bool is_op() const { return false; };
virtual bool is_parameter() const { return false; };
size_t instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
......
......@@ -75,7 +75,8 @@ namespace ngraph
}
virtual std::string op_class_name() const = 0;
virtual std::string node_id() const;
virtual std::string node_id() const override;
virtual bool is_op() const override { return true; }
};
/**
......
......@@ -41,15 +41,15 @@ namespace ngraph
// The ngraph element type
using element_type = T;
// The C++ type that holds the element type
using ctype = typename T::ctype;
using type = typename T::type;
ScalarConstant(typename T::ctype value)
ScalarConstant(typename T::type value)
: ScalarConstantBase(std::make_shared<TensorViewType>(T::element_type(), Shape{}))
, m_value(value)
{
}
virtual std::string description() const override { return "ConstantScalar"; }
virtual std::string description() const override { return "ScalarConstant"; }
virtual std::string node_id() const override
{
std::stringstream ss;
......@@ -57,7 +57,7 @@ namespace ngraph
return ss.str();
}
typename T::ctype value() const { return m_value; }
typename T::type value() const { return m_value; }
// Make a constant from any value that can be converted to the C++ type we use
// to represent the values.
......@@ -68,7 +68,7 @@ namespace ngraph
}
protected:
typename T::ctype m_value;
typename T::type m_value;
};
using FloatScalarConstant = ScalarConstant<element::Float>;
......
......@@ -41,6 +41,7 @@ namespace ngraph
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
virtual std::string node_id() const override;
virtual bool is_parameter() const override { return true; };
protected:
Function* m_function;
......
......@@ -112,8 +112,8 @@ namespace ngraph
class TypedValueMixin
{
public:
TypedValueMixin(const ValueType::ptr& type = nullptr)
: m_type(type)
TypedValueMixin(const ValueType::ptr& value_type = nullptr)
: m_value_type(value_type)
{
}
......@@ -121,26 +121,26 @@ namespace ngraph
** Set the type
** /param type The new type
**/
void type(const ValueType::ptr& type) { m_type = type; }
void value_type(const ValueType::ptr& value_type) { m_value_type = value_type; }
/**
** Set the type to be a tensor view type
** /param element_type The type of the tensor elements
** /param shape The shape of the view
**/
void type(const element::Type& element_type, const Shape& shape)
void value_type(const element::Type& element_type, const Shape& shape)
{
m_type = std::make_shared<TensorViewType>(element_type, shape);
m_value_type = std::make_shared<TensorViewType>(element_type, shape);
}
/**
** The type associated with this value.
**/
ValueType::ptr type() { return m_type; }
ValueType::ptr value_type() { return m_value_type; }
/**
** The type associated with this value.
**/
const ValueType::ptr type() const { return m_type; }
const ValueType::ptr value_type() const { return m_value_type; }
protected:
ValueType::ptr m_type;
ValueType::ptr m_value_type;
};
}
......@@ -32,7 +32,7 @@ Node::ptr ngraph::op::broadcast(const Node::ptr& tensor,
void BroadcastOp::propagate_types()
{
auto arg_type = m_arguments.at(0)->type();
auto arg_type = m_arguments.at(0)->value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to broadcast is missing type.");
......@@ -53,5 +53,5 @@ void BroadcastOp::propagate_types()
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->element_type(), m_shape);
}
......@@ -27,8 +27,8 @@ Node::ptr ngraph::op::dot(const Node::ptr& arg0, const Node::ptr& arg1)
void DotOp::propagate_types()
{
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->type());
auto arg0_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(0)->value_type());
auto arg1_tensor_type = dynamic_pointer_cast<TensorViewType>(m_arguments.at(1)->value_type());
if (nullptr == arg0_tensor_type || nullptr == arg1_tensor_type)
{
throw ngraph_error("Arguments to dot must be tensor views");
......@@ -60,5 +60,5 @@ void DotOp::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
m_value_type = make_shared<TensorViewType>(arg0_tensor_type->element_type(), result_shape);
}
......@@ -93,7 +93,7 @@ TEST(build_graph, literal)
auto float0 = FloatScalarConstant::make(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float::element_type(), Shape{});
ASSERT_EQ(float0->value(), 3.0);
ASSERT_EQ(*float0->type(), float_scalar_type);
ASSERT_EQ(*float0->value_type(), float_scalar_type);
auto d = op::dot(float0, float0);
ASSERT_EQ(d->arguments().at(0), float0);
ASSERT_EQ(d->arguments().at(1), float0);
......@@ -101,13 +101,13 @@ TEST(build_graph, literal)
// float scalar from an int
auto float1 = FloatScalarConstant::make(3);
ASSERT_EQ(float1->value(), 3);
ASSERT_EQ(*float1->type(), float_scalar_type);
ASSERT_EQ(*float1->value_type(), float_scalar_type);
auto int32_0 = Int32ScalarConstant::make(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->value(), 3);
ASSERT_EQ(*int32_0->type(), int32_scalar_type);
ASSERT_NE(*int32_0->type(), float_scalar_type);
ASSERT_EQ(*int32_0->value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->value_type(), float_scalar_type);
}
// Check argument inverses
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment