Commit 961b4e0a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #80 from NervanaSystems/cyphers/sizes

Fix/sanitize type comparisons, add test
parents 3c815ee4 27d4f16a
......@@ -30,6 +30,8 @@ namespace ngraph
{
class Type
{
Type(const Type&) = delete;
Type& operator=(const Type&) = delete;
public:
Type(size_t bitwidth, bool is_float, bool is_signed, const std::string& cname);
......@@ -75,6 +77,8 @@ namespace ngraph
template <typename T>
class TraitedType : public Type
{
TraitedType(const TraitedType&) = delete;
TraitedType& operator=(const TraitedType&) = delete;
protected:
TraitedType()
: Type(sizeof(T) * 8,
......
......@@ -30,6 +30,17 @@ ngraph::Node::Node(const std::vector<std::shared_ptr<Node>>& arguments,
}
}
void ngraph::Node::set_value_type_checked(const std::shared_ptr<ValueType>& value_type)
{
if (nullptr == m_value_type){
m_value_type = value_type;
} else {
if (*m_value_type != *value_type){
throw ngraph::ngraph_error("Setting value type to a different ValueType");
}
}
}
bool ngraph::Node::is_op() const
{
return dynamic_cast<const ngraph::Op*>(this) != nullptr;
......
......@@ -83,6 +83,12 @@ namespace ngraph
m_value_type = value_type;
}
// Set the value type if it has not already been set; otherwise, ensure that
// value_type agrees with the value type that was set.
// This is used when the framework specifies a value type for the value, and we
// independently compute what we thing the value type should be from the arguments.
void set_value_type_checked(const std::shared_ptr<ValueType>& value_type);
bool is_op() const;
bool is_parameter() const;
......
......@@ -37,8 +37,8 @@ namespace ngraph
void assign_function(Function* function, size_t index);
public:
Parameter(const std::shared_ptr<ValueType>& value_type);
Parameter(const ngraph::element::Type element_type, const Shape& shape);
Parameter(const std::shared_ptr<ValueType>& value_type=nullptr);
Parameter(const ngraph::element::Type& element_type, const Shape& shape);
std::string description() const override { return "Parameter"; }
virtual void propagate_types() override;
......
......@@ -32,8 +32,8 @@ namespace ngraph
{
public:
virtual ~ValueType() {}
virtual bool operator==(const std::shared_ptr<ValueType>& that) const = 0;
bool operator!=(const std::shared_ptr<ValueType>& that) const { return !(*this == that); }
virtual bool operator==(const ValueType& that) const = 0;
bool operator!=(const ValueType& that) const { return !(*this == that); }
};
/// Describes a tensor view; an element type and a shape.
......@@ -51,7 +51,7 @@ namespace ngraph
const element::Type& get_element_type() const { return m_element_type; }
const Shape& get_shape() const { return m_shape; }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
virtual bool operator==(const ValueType& that) const override;
protected:
const element::Type& m_element_type;
......@@ -77,7 +77,7 @@ namespace ngraph
}
std::vector<std::shared_ptr<ValueType>> set_element_types() { return m_element_types; }
virtual bool operator==(const std::shared_ptr<ValueType>& that) const override;
virtual bool operator==(const ValueType& that) const override;
protected:
std::vector<std::shared_ptr<ValueType>> m_element_types;
......
......@@ -38,7 +38,5 @@ void Broadcast::propagate_types()
{
throw ngraph_error("Broadcast arg, shape, and axes are incompatible");
}
// TODO If m_type is already set (by framework), this should verify that the type
// we expect is consistent with the type the framework expects.
m_value_type = make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape);
set_value_type_checked(make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), m_shape));
}
......@@ -56,5 +56,5 @@ void Dot::propagate_types()
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end());
m_value_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape));
}
......@@ -26,7 +26,7 @@ Parameter::Parameter(const std::shared_ptr<ValueType>& value_type)
{
}
Parameter::Parameter(const ngraph::element::Type element_type, const Shape& shape)
Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& shape)
: Parameter(make_shared<TensorViewType>(element_type, shape))
{
}
......
......@@ -19,9 +19,9 @@
using namespace std;
using namespace ngraph;
bool TensorViewType::operator==(const std::shared_ptr<ValueType>& that) const
bool TensorViewType::operator==(const ValueType& that) const
{
auto that_tvt = dynamic_pointer_cast<TensorViewType>(that);
auto that_tvt = dynamic_cast<const TensorViewType*>(&that);
if (nullptr == that_tvt)
{
return false;
......@@ -37,9 +37,9 @@ bool TensorViewType::operator==(const std::shared_ptr<ValueType>& that) const
return true;
}
bool TupleType::operator==(const std::shared_ptr<ValueType>& that) const
bool TupleType::operator==(const ValueType& that) const
{
auto that_tvt = dynamic_pointer_cast<TupleType>(that);
auto that_tvt = dynamic_cast<const TupleType*>(&that);
if (nullptr == that_tvt)
{
return false;
......
......@@ -81,7 +81,7 @@ TEST(build_graph, literal)
auto float0 = make_shared<op::Float32ScalarConstant>(3.0);
auto float_scalar_type = make_shared<TensorViewType>(element::Float32::element_type(), Shape{});
ASSERT_EQ(float0->get_value(), 3.0);
ASSERT_EQ(*float0->get_value_type(), float_scalar_type);
ASSERT_EQ(*float0->get_value_type(), *float_scalar_type);
auto d = make_shared<op::Dot>(float0, float0);
ASSERT_EQ(d->get_arguments().at(0), float0);
ASSERT_EQ(d->get_arguments().at(1), float0);
......@@ -89,13 +89,51 @@ TEST(build_graph, literal)
// float scalar from an int
auto float1 = make_shared<op::Float32ScalarConstant>(3);
ASSERT_EQ(float1->get_value(), 3);
ASSERT_EQ(*float1->get_value_type(), float_scalar_type);
ASSERT_EQ(*float1->get_value_type(), *float_scalar_type);
auto int32_0 = make_shared<op::Int32ScalarConstant>(3.0);
auto int32_scalar_type = make_shared<TensorViewType>(element::Int32::element_type(), Shape{});
ASSERT_EQ(int32_0->get_value(), 3);
ASSERT_EQ(*int32_0->get_value_type(), int32_scalar_type);
ASSERT_NE(*int32_0->get_value_type(), float_scalar_type);
ASSERT_EQ(*int32_0->get_value_type(), *int32_scalar_type);
ASSERT_NE(*int32_0->get_value_type(), *float_scalar_type);
}
TEST(build_graph, set_value_type_checked)
{
auto untyped_param = make_shared<op::Parameter>();
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
} catch(...){
FAIL() << "Setting value type for first time type failed.";
}
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
} catch(...){
FAIL() << "Setting value type to same type failed.";
}
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 5}));
FAIL() << "Setting value type to a different shape did not fail.";
} catch(const ngraph_error& error){
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
} catch(...){
FAIL() << "Setting value type to a different shape did not failed with incorrect error.";
}
try {
untyped_param->set_value_type_checked(make_shared<TensorViewType>(element::Int32::element_type(), Shape{4, 4}));
FAIL() << "Setting value type to a different element type did not fail.";
} catch(const ngraph_error& error){
EXPECT_EQ(error.what(), std::string("Setting value type to a different ValueType"));
} catch(...){
FAIL() << "Setting value type to a different element type did not failed with incorrect error.";
}
auto param = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4, 4});
try {
param->set_value_type_checked(make_shared<TensorViewType>(element::Float32::element_type(), Shape{4, 4}));
} catch(...){
FAIL() << "Setting value type to same type failed.";
}
}
// Check argument inverses
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment