Commit 08483fbd authored by Adam Procter's avatar Adam Procter Committed by Robert Kimball

Fix for AllReduce partial shape/type validation (#1913)

parent 759f79c0
......@@ -27,11 +27,6 @@ op::AllReduce::AllReduce(const shared_ptr<Node>& arg)
void op::AllReduce::validate_and_infer_types()
{
if (validate_punt_if_dynamic())
{
return;
}
NODE_VALIDATION_ASSERT(this,
get_input_element_type(0).is_dynamic() ||
get_input_element_type(0) == element::f32 ||
......@@ -39,7 +34,7 @@ void op::AllReduce::validate_and_infer_types()
<< "Only element types f32 and f64 are supported (argument element type: "
<< get_input_element_type(0) << ").";
set_output_type(0, get_input_element_type(0), get_input_shape(0));
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
shared_ptr<Node> op::AllReduce::copy_with_new_args(const NodeVector& new_args) const
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment