Commit 0e4c227b authored by adstraw's avatar adstraw Committed by Scott Cyphers

fix bug in Broadcast axes shape checking (#768)

parent 57fdd798
......@@ -31,6 +31,10 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
Shape target_shape = m_shape;
for (auto i = m_broadcast_axes.rbegin(); i != m_broadcast_axes.rend(); ++i)
{
if (*i >= target_shape.size())
{
throw ngraph_error("Broadcast axis exceeds target shape rank");
}
target_shape.erase(target_shape.begin() + *i);
}
if (Shape{target_shape} != input.get_shape())
......
......@@ -1709,6 +1709,16 @@ TEST(${BACKEND_NAME}, broadcast_scalar_vector)
EXPECT_EQ((vector<float>{6, 6, 6, 6}), read_vector<float>(result));
}
TEST(${BACKEND_NAME}, broadcast_to_non_existent_axis)
{
Shape shape_a{};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_r{4};
ASSERT_THROW(auto f = make_shared<Function>(
make_shared<op::Broadcast>(A, shape_r, AxisSet{0, 1}), op::ParameterVector{A}),
ngraph_error);
}
TEST(${BACKEND_NAME}, broadcast_scalar_matrix)
{
Shape shape_a{};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment