Commit a29f6754 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

add batchnorm training type tests (#746)

parent 2e8c6286
......@@ -44,6 +44,7 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
}
auto et = input->get_element_type();
Shape channel_shape{m_bn_input_shape[1]};
const char* input_names[] = {"gamma", "beta"};
for (size_t i = 0; i < 2; i++)
......@@ -54,21 +55,13 @@ ngraph::op::BatchNorm::BatchNorm(double eps,
" isn't equal to input data's type";
throw ngraph_error(err_msg.c_str());
}
}
if ((gamma->get_shape().size() != 1) || (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma and beta shoud have rank 1");
}
if (gamma->get_shape().size() != beta->get_shape().size())
if (get_input_op(i)->get_shape() != channel_shape)
{
throw ngraph_error("gamma and beta rank does not match");
auto err_msg = std::string("The shape of ") + input_names[i] +
" isn't equal to input channel's shape";
throw ngraph_error(err_msg.c_str());
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
add_output(input->get_element_type(), m_bn_input_shape);
......
......@@ -58,6 +58,89 @@ TEST(type_prop, broadcast_deduce_incorrect)
}
}
TEST(type_prop, batchnorm_rank_less_than_2)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{1});
try
{
auto bc = make_shared<op::BatchNorm>(0.001, dummy, dummy, dummy);
FAIL() << "BatchNorm c-tor should throw for tensors whose rank is less than 2";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("input tensor to batchnorm must have tensor of at least rank 2"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_zero_channel_check)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{1, 0, 2, 3});
try
{
auto bc = make_shared<op::BatchNorm>(0.001, dummy, dummy, dummy);
FAIL() << "BatchNorm c-tor should throw for tensors w/ zero-dimension channels";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string(
"input tensor must have at least one channel axis for batch normalization"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_et_check)
{
auto dummy_f32 = make_shared<op::Parameter>(element::f32, Shape{3});
auto dummy_f64 = make_shared<op::Parameter>(element::f64, Shape{3});
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
try
{
auto bc = make_shared<op::BatchNorm>(0.001, dummy_f32, dummy_f64, param);
FAIL() << "BatchNorm c-tor should throw for different element types";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("The element type of beta isn't equal to input data's type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_shape_check)
{
auto dummy_3 = make_shared<op::Parameter>(element::f32, Shape{3});
auto dummy_4 = make_shared<op::Parameter>(element::f32, Shape{4});
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
try
{
auto bc = make_shared<op::BatchNorm>(0.001, dummy_4, dummy_3, param);
FAIL() << "BatchNorm c-tor should throw if gamma and beta shapes don't match";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("The shape of gamma isn't equal to input channel's shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_backprop_4d_check)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment