Commit 5fe48a06 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Support construction of fused ops with partial shapes (#3770)

* Skip shape inference for fused ops in the presence of dynamic-shaped inputs

* Added unit test to check fused op construction with dynamic input shapes
parent 56976f0c
......@@ -29,6 +29,11 @@ op::Split::Split(const Output<Node>& data, const int axis, const size_t num_spli
, m_axis{axis}
, m_num_split{num_split}
{
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < num_split; i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
constructor_validate_and_infer_types();
}
......@@ -38,6 +43,11 @@ op::Split::Split(const Output<Node>& data, const int axis, const std::vector<siz
, m_axis{axis}
, m_splits{splits}
{
// Create dynamic-typed outputs. Actual shape/type will be computed during shape inference
for (size_t i = 0; i < splits.size(); i++)
{
set_output_type(i, element::dynamic, PartialShape::dynamic());
}
constructor_validate_and_infer_types();
}
......
......@@ -42,6 +42,23 @@ op::util::FusedOp::FusedOp(const std::string& node_type, const NodeVector& args)
void op::util::FusedOp::validate_and_infer_types()
{
// Bail out if any of the shapes are unknown since fused op decomposition
// typically requires fully-determined static types.
//
// In the absence of decomposition, we will not know how many outputs this
// fused op has, so we conservatively create and set a single output to
// facilitate downstream ops that would like to use this op as an argument.
// Multi-output fused ops (e.g., split) should create these outputs in their
// constructors instead.
for (size_t i = 0; i < get_input_size(); i++)
{
if (!get_input_partial_shape(i).is_static())
{
set_output_type(0, element::dynamic, PartialShape::dynamic());
return;
}
}
pre_validate_and_infer_types();
auto subgraph_outputs = decompose_op();
......
......@@ -167,6 +167,20 @@ TEST(build_graph, multi_output_split)
EXPECT_EQ(conv->get_shape(), (Shape{64, 128, 91, 131}));
}
TEST(build_graph, multi_output_split_dynamic)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto split = make_shared<op::Split>(data, 1, 2);
auto abs = make_shared<op::Abs>(split->output(1));
EXPECT_TRUE(abs->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
auto f = make_shared<Function>(abs, ParameterVector{data});
auto new_parameter = make_shared<op::Parameter>(element::f32, Shape{2, 4});
split->input(0).replace_source_output(new_parameter->output(0));
f->validate_nodes_and_infer_types();
EXPECT_EQ(abs->get_shape(), (Shape{2, 2}));
}
TEST(build_graph, function_revalidate_and_infer)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment