Commit 059a9653 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Make ConstantFolding work even when shapes are not originally static (#3372)

* WIP

* CHANGE_DYNAMIC_STATE

* Implement full type prop for DynBroadcast when inputs const/static; clean up pass properties

* Add a unit test for the late-constness thing

* Fix merge

* style
parent 56618491
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
using namespace std;
......@@ -62,9 +63,67 @@ void op::DynBroadcast::validate_and_infer_types()
"DynBroadcast axes rank must be 1, but has ",
axes_shape_rank);
PartialShape result_shape{PartialShape::dynamic()};
if (input(1).get_source_output().get_node_shared_ptr()->is_constant())
{
result_shape =
static_pointer_cast<op::Constant>(input(1).get_source_output().get_node_shared_ptr())
->get_shape_val();
}
bool axes_known = false;
AxisSet broadcast_axes;
if (input(2).get_source_output().get_node_shared_ptr()->is_constant())
{
axes_known = true;
broadcast_axes =
static_pointer_cast<op::Constant>(input(2).get_source_output().get_node_shared_ptr())
->get_axis_set_val();
}
PartialShape arg_shape = input(0).get_partial_shape();
if (result_shape.is_static() && axes_known && arg_shape.is_static())
{
for (auto axis : broadcast_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(result_shape.rank()),
"Broadcast axis index (",
axis,
") exceeds specified output shape rank ",
"(broadcast axes: ",
broadcast_axes,
", output shape: ",
result_shape,
").");
}
Shape required_input_shape = result_shape.to_shape();
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
{
required_input_shape.erase(required_input_shape.begin() + *i);
}
// TODO(amprocte): We can probably have a more helpful error message here.
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_CHECK(
this,
arg_shape.compatible(required_input_shape),
"Broadcast argument shape, specified output shape, and axes are incompatible ",
"(argument shape: ",
arg_shape,
", output shape: ",
result_shape,
", broadcast axes: ",
broadcast_axes,
").");
}
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
set_output_type(0, get_input_element_type(0), result_shape);
}
shared_ptr<Node> op::DynBroadcast::copy_with_new_args(const NodeVector& new_args) const
......
This diff is collapsed.
......@@ -1019,6 +1019,43 @@ TEST(constant_folding, constant_dyn_reshape)
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
{
    Shape shape_in{2, 4};
    Shape shape_shape{3};
    vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};

    // The target shape is produced by an elementwise Add of these two
    // constants, which evaluates to {2, 4, 1}. Consequently the DynReshape's
    // output shape is still dynamic when ConstantFolding starts; it only
    // becomes statically known once the Add itself has been folded.
    vector<int64_t> addend_a{1, 3, 0};
    vector<int64_t> addend_b{1, 1, 1};

    auto data_const = make_shared<op::Constant>(element::f32, shape_in, values_in);
    auto shape_const_a = make_shared<op::Constant>(element::i64, shape_shape, addend_a);
    auto shape_const_b = make_shared<op::Constant>(element::i64, shape_shape, addend_b);
    auto dyn_reshape =
        make_shared<op::DynReshape>(data_const, shape_const_a + shape_const_b);
    auto f = make_shared<Function>(dyn_reshape, ParameterVector{});

    // Sanity check: before any passes run, the reshape's shape is dynamic.
    ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic());

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // Both the Add and the DynReshape must have been folded away, leaving a
    // single Constant feeding the function result.
    ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded_const =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded_const);

    // Reshape is a pure relayout: the folded values equal the input values.
    auto values_out = folded_const->get_vector<float>();
    ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_transpose)
{
Shape shape_in{2, 4};
......@@ -1140,5 +1177,5 @@ TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
ASSERT_EQ(true, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment