Commit 1ac3e5c7 authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

Bug Fix: incorrect shape validation logic. (#3898)

parent 6e405b81
......@@ -77,18 +77,16 @@ void op::ScatterNDAdd::validate_and_infer_types()
bool compatible = true;
if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
{
for (size_t i = 0; i < static_cast<size_t>(indices_shape.rank()) - 1; i++)
size_t indices_rank = static_cast<size_t>(indices_shape.rank());
size_t updates_rank = static_cast<size_t>(updates_shape.rank());
for (size_t i = 0; i < indices_rank - 1; i++)
{
compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
}
size_t j =
static_cast<size_t>(indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]);
for (size_t i = j; i < static_cast<size_t>(inputs_shape.rank()); i++)
size_t j = static_cast<size_t>(indices_shape[indices_rank - 1]);
for (size_t i = indices_rank - 1; i < updates_rank; i++, j++)
{
compatible =
compatible &&
updates_shape[static_cast<size_t>(indices_shape.rank()) + i - 2].same_scheme(
inputs_shape[i]);
compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment