Commit a9a3ae79 authored by baojun's avatar baojun Committed by Scott Cyphers

Fix layernorm flatten issue (#4032)

* fix layernorm flatten issue

* update ut

* checkout output val

* fix style

* apply tolerance
parent 075665ce
......@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) c
}
}
void op::LayerNorm::validate_and_infer_types()
void op::LayerNorm::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
......@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new
}
}
void op::LayerNormBackprop::validate_and_infer_types()
void op::LayerNormBackprop::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
......
......@@ -56,7 +56,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override;
void validate_and_infer_types() override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -121,7 +121,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override;
void validate_and_infer_types() override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -318,6 +318,7 @@ random_uniform_dynamic_shapes
layer_norm_affine_stats
layer_norm_bprop_affine_stats
layer_norm_bprop_affine
layer_norm_bprop_4d_input
# Another fused op decomposition pass required after the downgrade pass
model_split_equal_parts_default
......
......@@ -194,3 +194,60 @@ NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_affine)
EXPECT_TRUE(test::all_close_f(exp_d_scale, read_vector<float>(d_scale)));
EXPECT_TRUE(test::all_close_f(exp_d_bias, read_vector<float>(d_bias)));
}
NGRAPH_TEST(${BACKEND_NAME}, layer_norm_bprop_4d_input)
{
auto p_data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
auto p_delta = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
auto p_mean = make_shared<op::Parameter>(element::f32, Shape{2});
auto p_variance = make_shared<op::Parameter>(element::f32, Shape{2});
auto p_scale = make_shared<op::Parameter>(element::f32, Shape{60});
auto lnb = make_shared<op::LayerNormBackprop>(p_data, p_delta, p_mean, p_variance, p_scale);
auto output_data = lnb->output(0);
auto output_scale = lnb->output(1);
auto output_bias = lnb->output(2);
// flatten output_scale
auto output_scale_shape = output_scale.get_shape();
auto flattened_output_scale = make_shared<op::Reshape>(
output_scale, get_default_order(output_scale_shape), Shape{shape_size(output_scale_shape)});
auto f = make_shared<Function>(OutputVector{output_data, flattened_output_scale, output_bias},
ParameterVector{p_data, p_delta, p_mean, p_variance, p_scale});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create tensors for input
auto data = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
auto delta = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
auto mean = backend->create_tensor(element::f32, Shape{2});
auto variance = backend->create_tensor(element::f32, Shape{2});
auto scale = backend->create_tensor(element::f32, Shape{60});
// Fill in input tensors
vector<float> d_input(2 * 3 * 4 * 5, 1);
copy_data(data, d_input);
vector<float> dt_input(2 * 3 * 4 * 5, 1);
copy_data(delta, dt_input);
vector<float> m_input(2, 1);
copy_data(mean, m_input);
vector<float> v_input(2, 1);
copy_data(variance, v_input);
vector<float> s_input(60, 1);
copy_data(scale, s_input);
// Create tensors for output
auto d_data = backend->create_tensor(element::f32, Shape{2, 3, 4, 5});
auto d_scale = backend->create_tensor(element::f32, Shape{60});
auto d_bias = backend->create_tensor(element::f32, Shape{60});
auto handle = backend->compile(f);
handle->call_with_validate({d_data, d_scale, d_bias}, {data, delta, mean, variance, scale});
vector<float> expected_data(120, 0);
vector<float> expected_scale(60, 0);
vector<float> expected_bias(60, 2);
EXPECT_TRUE(test::all_close_f(expected_data, read_vector<float>(d_data), 1e-6f, 1e-6f));
EXPECT_TRUE(test::all_close_f(expected_scale, read_vector<float>(d_scale), 1e-6f, 1e-6f));
EXPECT_TRUE(test::all_close_f(expected_bias, read_vector<float>(d_bias), 1e-6f, 1e-6f));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment