Commit beaf154e authored by Ilya Churaev's avatar Ilya Churaev Committed by Sang Ik Lee

Fix broadcast v1 reference (#3880)

* Added reproducer for issue with broadcast v1

* Make reference broadcast work with V1 broadcast
parent 66ce838c
......@@ -34,13 +34,31 @@ namespace ngraph
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
CoordinateTransform input_transform(in_shape);
// Remove all broadcast axes from in_shape
Shape adjusted_in_shape;
for (auto length : in_shape)
{
if (length != 1)
{
adjusted_in_shape.push_back(length);
}
}
// Remove 1s from out_shape
AxisSet adjusted_axes(broadcast_axes);
for (uint64_t axis = 0; axis < out_shape.size(); ++axis)
{
auto length = out_shape.at(axis);
if (length == 1)
{
adjusted_axes.insert(axis);
}
}
CoordinateTransform input_transform(adjusted_in_shape);
CoordinateTransform output_transform(out_shape);
for (const Coordinate& output_coord : output_transform)
{
Coordinate input_coord = reduce(output_coord, broadcast_axes);
Coordinate input_coord = reduce(output_coord, adjusted_axes);
out[output_transform.index(output_coord)] =
arg[input_transform.index(input_coord)];
}
......
......@@ -206,6 +206,30 @@ TEST(constant_folding, constant_broadcast_v1)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_broadcast_v1_with_target_shape)
{
vector<int32_t> values_in{1};
auto constant_in = make_shared<op::Constant>(element::i32, Shape{1, 1, 1, 1}, values_in);
vector<int64_t> shape_in{1, 3, 1, 1};
auto target_shape = make_shared<op::Constant>(element::i64, Shape{4}, shape_in);
auto broadcast_v1 = make_shared<op::v1::Broadcast>(constant_in, target_shape);
auto f = make_shared<Function>(broadcast_v1, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Broadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{1, 1, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_broadcast_v1_numpy)
{
vector<int32_t> values_in{0, 1};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment