Commit 1ac3e5c7 authored by Sang Ik Lee's avatar Sang Ik Lee Committed by Scott Cyphers

Bug Fix: incorrect shape validation logic. (#3898)

parent 6e405b81
...@@ -77,18 +77,16 @@ void op::ScatterNDAdd::validate_and_infer_types() ...@@ -77,18 +77,16 @@ void op::ScatterNDAdd::validate_and_infer_types()
bool compatible = true; bool compatible = true;
if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static()) if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
{ {
for (size_t i = 0; i < static_cast<size_t>(indices_shape.rank()) - 1; i++) size_t indices_rank = static_cast<size_t>(indices_shape.rank());
size_t updates_rank = static_cast<size_t>(updates_shape.rank());
for (size_t i = 0; i < indices_rank - 1; i++)
{ {
compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]); compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
} }
size_t j = size_t j = static_cast<size_t>(indices_shape[indices_rank - 1]);
static_cast<size_t>(indices_shape[static_cast<size_t>(indices_shape.rank()) - 1]); for (size_t i = indices_rank - 1; i < updates_rank; i++, j++)
for (size_t i = j; i < static_cast<size_t>(inputs_shape.rank()); i++)
{ {
compatible = compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]);
compatible &&
updates_shape[static_cast<size_t>(indices_shape.rank()) + i - 2].same_scheme(
inputs_shape[i]);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment