Commit 143fd0f2 authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by Scott Cyphers

MatMul and Reduces rank propagation (#3875)

parent 6fbed3b9
......@@ -37,6 +37,28 @@ op::MatMul::MatMul(const Output<Node>& A,
constructor_validate_and_infer_types();
}
void op::MatMul::pre_validate_and_infer_types()
{
element::Type result_et;
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)),
"Arguments do not have the same element type (arg0 element type: ",
get_input_element_type(0),
", arg1 element type: ",
get_input_element_type(1),
").");
const Rank& A_rank = get_input_partial_shape(0).rank();
const Rank& B_rank = get_input_partial_shape(1).rank();
if (A_rank.is_static() && B_rank.is_static())
{
Rank max_rank = int64_t(A_rank) > int64_t(B_rank) ? A_rank : B_rank;
set_output_type(0, result_et, PartialShape::dynamic(max_rank));
}
}
NodeVector op::MatMul::decompose_op() const
{
auto A = input_value(0);
......
......@@ -43,6 +43,8 @@ namespace ngraph
const bool& transpose_a = 0,
const bool& transpose_b = 0);
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
......
......@@ -162,25 +162,19 @@ op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis)
, m_axis(axis)
{
constructor_validate_and_infer_types();
const PartialShape& input_shape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this,
input_shape.rank().is_static(),
"Input node rank must be static (input_shape=",
input_shape,
").");
NODE_VALIDATION_CHECK(this,
axis < static_cast<size_t>(input_shape.rank()),
"Reduction axis (",
axis,
") is out of bounds (argument shape: ",
input_shape,
").");
}
void op::v1::Softmax::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
if (input_shape.rank().is_static())
NODE_VALIDATION_CHECK(this,
m_axis < static_cast<size_t>(input_shape.rank()),
"Reduction axis (",
m_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
if (input_shape.is_static())
set_output_type(0, get_input_element_type(0), input_shape.to_shape());
else
......
......@@ -38,6 +38,9 @@ void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
auto input_rank = input_shape.rank();
PartialShape result_shape{PartialShape::dynamic()};
if (input_rank.is_static())
result_shape = PartialShape::dynamic(input_rank);
if (input_rank.is_static() && reduction_axes_constant())
{
std::vector<Dimension> dims;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment