Commit ecce61f1 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

double check that args for dot_transpose are 2d (#996)

parent 175d51fa
......@@ -182,10 +182,20 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
}
auto arg0 = mdot->get_argument(0);
if (arg0->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Arg0 has the wrong shape. " << vector_to_string(arg0->get_shape());
return false;
}
auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)};
auto reshape0 = std::make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);
auto arg1 = mdot->get_argument(1);
if (arg1->get_shape().size() != 2)
{
NGRAPH_DEBUG << "Arg1 has the wrong shape. " << vector_to_string(arg1->get_shape());
return false;
}
auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)};
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment