Commit acc5d2d2 authored by Adam Procter's avatar Adam Procter

Finish type propagation in dot, add unit tests

parent 8bc8579c
...@@ -35,7 +35,7 @@ void Dot::propagate_types() ...@@ -35,7 +35,7 @@ void Dot::propagate_types()
} }
// Use NumPy semantics for now // Use NumPy semantics for now
// Last axis of first arg reduces against second to last of second arg if more than one axis, else axis. // Last axis of first arg reduces against second to last of second arg if more than one axis, else the only axis.
vector<size_t> arg0_shape = arg0_tensor_type->get_shape(); vector<size_t> arg0_shape = arg0_tensor_type->get_shape();
vector<size_t> arg1_shape = arg1_tensor_type->get_shape(); vector<size_t> arg1_shape = arg1_tensor_type->get_shape();
size_t arg0_reduction = arg0_shape.size() - 1; size_t arg0_reduction = arg0_shape.size() - 1;
...@@ -52,9 +52,18 @@ void Dot::propagate_types() ...@@ -52,9 +52,18 @@ void Dot::propagate_types()
{ {
throw ngraph_error("Dot reduction axes not compatible"); throw ngraph_error("Dot reduction axes not compatible");
} }
vector<size_t> result_shape; vector<size_t> result_shape;
copy(arg0_shape.begin(), arg0_shape.begin() + arg1_reduction, result_shape.end()); result_shape.reserve(arg0_shape.size() + arg1_shape.size() - 2);
copy(arg1_shape.begin(), arg1_shape.begin() + arg1_reduction, result_shape.end());
copy(arg1_shape.begin() + arg1_reduction, arg1_shape.end(), result_shape.end()); for(auto i = 0; i < arg0_shape.size(); i++)
set_value_type_checked(make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape)); if(i != arg0_reduction)
result_shape.push_back(arg0_shape[i]);
for(auto i = 0; i < arg1_shape.size(); i++)
if(i != arg1_reduction)
result_shape.push_back(arg1_shape[i]);
auto result_type = make_shared<TensorViewType>(arg0_tensor_type->get_element_type(), result_shape);
set_value_type_checked(result_type);
} }
...@@ -25,6 +25,9 @@ void test_binary_bad_arguments_views(const shared_ptr<Node>& node); ...@@ -25,6 +25,9 @@ void test_binary_bad_arguments_views(const shared_ptr<Node>& node);
void test_binary_good_arguments(const shared_ptr<Node>& node); void test_binary_good_arguments(const shared_ptr<Node>& node);
void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y)); void test_binary(shared_ptr<Node>(f)(const shared_ptr<Node>& x, const shared_ptr<Node>& y));
//
// Tests for broadcast.
//
TEST(type_prop, broadcast_deduce) TEST(type_prop, broadcast_deduce)
{ {
// Deduce type // Deduce type
...@@ -91,6 +94,102 @@ TEST(type_prop, broadcast_bad_arguments) ...@@ -91,6 +94,102 @@ TEST(type_prop, broadcast_bad_arguments)
} }
} }
//
// Tests for dot product.
//
TEST(type_prop, dot_deduce_1d)
{
// Deduce type for 1D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{}));
}
TEST(type_prop, dot_deduce_2d)
{
// Deduce type for 2D arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{4,3}));
}
TEST(type_prop, dot_deduce_different_d)
{
// Deduce type for different-dimension arguments
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,8,4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1,2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2,8,4,1,3}));
}
TEST(type_prop, dot_deduce_different_d_correct)
{
// Deduced type matches explicitly set type
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{2,8,4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{1,2,3});
auto bc = make_shared<op::Dot>(param1, param2);
bc->set_value_type(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{2,8,4,1,3}));
bc->propagate_types();
auto bc_vt = bc->get_value_type();
ASSERT_EQ(*bc_vt, TensorViewType(element::Float32::element_type(), Shape{2,8,4,1,3}));
}
TEST(type_prop, dot_deduce_element_type_mismatch)
{
// Type deduction fails due to element type mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Int32::element_type(), Shape{2,5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Element type mismatch not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Arguments to dot must have the same element type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dot_deduce_reduction_axes_size_mismatch)
{
// Type deduction fails due to reduction axes size mismatch
auto param1 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{4,2});
auto param2 = make_shared<op::Parameter>(element::Float32::element_type(), Shape{3,5});
auto bc = make_shared<op::Dot>(param1, param2);
try
{
bc->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Dot reduction axes size mismatch not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Dot reduction axes not compatible"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
//
// Tests for binary elementwise ops.
//
void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node) void test_binary_bad_arguments_tuple(const shared_ptr<Node>& node)
{ {
try try
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment