Unverified Commit b3d2ff59 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

type_prop tests for batchnorm bprop (#601)

parent 4fc1a478
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/ops/batch_norm.hpp"
#include <memory> #include <memory>
using namespace std; using namespace std;
...@@ -57,6 +58,96 @@ TEST(type_prop, broadcast_deduce_incorrect) ...@@ -57,6 +58,96 @@ TEST(type_prop, broadcast_deduce_incorrect)
} }
} }
TEST(type_prop, batchnorm_backprop_4d_check)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{});
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 4});
try
{
auto bc =
make_shared<op::BatchNormBackprop>(0.001, dummy, dummy, param, dummy, dummy, dummy);
FAIL() << "Deduced type should disagree with c-tor arguments";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Input expected to be a 4D tensor"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_backprop_et_check)
{
auto dummy_f32 = make_shared<op::Parameter>(element::f32, Shape{3});
auto dummy_f64 = make_shared<op::Parameter>(element::f64, Shape{3});
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
try
{
auto bc = make_shared<op::BatchNormBackprop>(
0.001, dummy_f32, dummy_f64, param, dummy_f32, dummy_f32, dummy_f32);
FAIL() << "Deduced type should disagree with c-tor arguments";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("The element type of beta isn't equal to input data's type"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_backprop_shape_check)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{3});
auto dummy2 = make_shared<op::Parameter>(element::f32, Shape{4});
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
try
{
auto bc =
make_shared<op::BatchNormBackprop>(0.001, dummy, dummy2, param, dummy2, dummy2, dummy2);
FAIL() << "Deduced type should disagree with c-tor arguments";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(),
std::string("The shape of beta isn't equal to input channel's shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, batchnorm_backprop_delta_check)
{
auto dummy = make_shared<op::Parameter>(element::f32, Shape{3});
auto dummy2 = make_shared<op::Parameter>(element::f32, Shape{4});
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
auto delta = make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 3});
try
{
auto bc =
make_shared<op::BatchNormBackprop>(0.001, dummy, dummy, param, dummy, dummy, delta);
FAIL() << "Deduced type should disagree with c-tor arguments";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("delta shape is expected to be equal to input shape"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, concat_deduce) TEST(type_prop, concat_deduce)
{ {
// Deduce type // Deduce type
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment