Commit 059a9653 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Make ConstantFolding work even when shapes are not originally static (#3372)

* WIP

* CHANGE_DYNAMIC_STATE

* Implement full type prop for DynBroadcast when inputs const/static; clean up pass properties

* Add a unit test for the late-constness thing

* Fix merge

* style
parent 56618491
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
using namespace std;
......@@ -62,9 +63,67 @@ void op::DynBroadcast::validate_and_infer_types()
"DynBroadcast axes rank must be 1, but has ",
axes_shape_rank);
PartialShape result_shape{PartialShape::dynamic()};
if (input(1).get_source_output().get_node_shared_ptr()->is_constant())
{
result_shape =
static_pointer_cast<op::Constant>(input(1).get_source_output().get_node_shared_ptr())
->get_shape_val();
}
bool axes_known = false;
AxisSet broadcast_axes;
if (input(2).get_source_output().get_node_shared_ptr()->is_constant())
{
axes_known = true;
broadcast_axes =
static_pointer_cast<op::Constant>(input(2).get_source_output().get_node_shared_ptr())
->get_axis_set_val();
}
PartialShape arg_shape = input(0).get_partial_shape();
if (result_shape.is_static() && axes_known && arg_shape.is_static())
{
for (auto axis : broadcast_axes)
{
NODE_VALIDATION_CHECK(this,
axis < size_t(result_shape.rank()),
"Broadcast axis index (",
axis,
") exceeds specified output shape rank ",
"(broadcast axes: ",
broadcast_axes,
", output shape: ",
result_shape,
").");
}
Shape required_input_shape = result_shape.to_shape();
for (auto i = broadcast_axes.rbegin(); i != broadcast_axes.rend(); ++i)
{
required_input_shape.erase(required_input_shape.begin() + *i);
}
// TODO(amprocte): We can probably have a more helpful error message here.
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
NODE_VALIDATION_CHECK(
this,
arg_shape.compatible(required_input_shape),
"Broadcast argument shape, specified output shape, and axes are incompatible ",
"(argument shape: ",
arg_shape,
", output shape: ",
result_shape,
", broadcast axes: ",
broadcast_axes,
").");
}
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
set_output_type(0, get_input_element_type(0), result_shape);
}
shared_ptr<Node> op::DynBroadcast::copy_with_new_args(const NodeVector& new_args) const
......
......@@ -115,6 +115,19 @@
using namespace std;
using namespace ngraph;
// Re-runs shape/type inference on `n` and reports whether every one of its
// outputs ended up fully static (i.e., neither the partial shape nor the
// element type of any output is dynamic).
static bool revalidate_and_ensure_static(shared_ptr<Node> n)
{
    n->revalidate_and_infer_types();

    bool fully_static = true;
    for (auto& output : n->outputs())
    {
        if (output.get_element_type().is_dynamic() || output.get_partial_shape().is_dynamic())
        {
            fully_static = false;
            break;
        }
    }
    return fully_static;
}
template <class T>
shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant,
shared_ptr<op::Reshape> reshape,
......@@ -159,6 +172,8 @@ void pass::ConstantFolding::construct_constant_reshape()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(reshape_match));
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
......@@ -227,7 +242,7 @@ void pass::ConstantFolding::construct_constant_reshape()
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
reshape_matcher, constant_reshape_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -289,6 +304,8 @@ void pass::ConstantFolding::construct_constant_pad()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto pad_match = static_pointer_cast<op::Pad>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(pad_match));
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
......@@ -353,7 +370,7 @@ void pass::ConstantFolding::construct_constant_pad()
};
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -397,6 +414,8 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto dyn_reshape_match = static_pointer_cast<op::DynReshape>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(dyn_reshape_match));
std::shared_ptr<Node> replacement;
auto type = dyn_reshape_match->get_element_type();
switch (type.get_type_enum())
......@@ -466,7 +485,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
auto dyn_reshape_matcher =
make_shared<pattern::Matcher>(dyn_reshape, "ConstantFolding.ConstantDynReshape");
this->add_matcher(
dyn_reshape_matcher, constant_dyn_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
dyn_reshape_matcher, constant_dyn_reshape_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -509,6 +528,8 @@ void pass::ConstantFolding::construct_constant_transpose()
static_pointer_cast<op::Constant>(pattern_map[constant_perm_label]);
auto transpose_match = static_pointer_cast<op::Transpose>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(transpose_match));
std::shared_ptr<Node> replacement;
auto type = transpose_match->get_element_type();
switch (type.get_type_enum())
......@@ -582,7 +603,7 @@ void pass::ConstantFolding::construct_constant_transpose()
auto transpose_matcher =
make_shared<pattern::Matcher>(transpose, "ConstantFolding.ConstantTranspose");
this->add_matcher(
transpose_matcher, constant_transpose_callback, PassProperty::REQUIRE_STATIC_SHAPE);
transpose_matcher, constant_transpose_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -630,6 +651,8 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto broadcast_match = static_pointer_cast<op::Broadcast>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(broadcast_match));
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
......@@ -699,7 +722,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
auto broadcast_matcher =
make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast");
this->add_matcher(
broadcast_matcher, constant_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
broadcast_matcher, constant_broadcast_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -747,6 +770,8 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
static_pointer_cast<op::Constant>(pattern_map[constant_axes_label]);
auto dyn_broadcast_match = static_pointer_cast<op::DynBroadcast>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(dyn_broadcast_match));
std::shared_ptr<Node> replacement;
auto type = dyn_broadcast_match->get_output_element_type(0);
switch (type.get_type_enum())
......@@ -820,7 +845,7 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
auto dyn_broadcast_matcher =
make_shared<pattern::Matcher>(dyn_broadcast, "ConstantFolding.ConstantDynBroadcast");
this->add_matcher(
dyn_broadcast_matcher, constant_dyn_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
dyn_broadcast_matcher, constant_dyn_broadcast_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class Tin, class Tout>
......@@ -1086,6 +1111,8 @@ void pass::ConstantFolding::construct_constant_binary()
return false;
}
NGRAPH_CHECK(revalidate_and_ensure_static(binary_match));
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
......@@ -1168,7 +1195,7 @@ void pass::ConstantFolding::construct_constant_binary()
auto reshape_matcher = make_shared<pattern::Matcher>(be, "ConstantFolding.ConstantBinary");
this->add_matcher(
reshape_matcher, constant_binary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
reshape_matcher, constant_binary_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
bool is_supported_unary_op(std::shared_ptr<Node> n)
......@@ -1281,6 +1308,8 @@ void pass::ConstantFolding::construct_constant_unary()
return false;
}
NGRAPH_CHECK(revalidate_and_ensure_static(unary_match));
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
......@@ -1348,7 +1377,7 @@ void pass::ConstantFolding::construct_constant_unary()
};
auto reshape_matcher = make_shared<pattern::Matcher>(ue, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class QUANT, class REAL>
......@@ -1390,11 +1419,13 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto dequant_match = pattern_map[dequant];
auto dequantize_op = dynamic_pointer_cast<op::Dequantize>(dequant_match);
auto scale = dynamic_pointer_cast<op::Constant>(
dequant_match->input(1).get_source_output().get_node_shared_ptr());
auto offset = dynamic_pointer_cast<op::Constant>(
dequant_match->input(2).get_source_output().get_node_shared_ptr());
NGRAPH_CHECK(revalidate_and_ensure_static(dequantize_op));
auto type = constant_match->get_element_type();
if (dequant_match->get_element_type() != element::f32)
......@@ -1423,7 +1454,7 @@ void pass::ConstantFolding::construct_constant_dequantize()
auto dequantize_matcher =
make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
this->add_matcher(
dequantize_matcher, constant_dequantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
dequantize_matcher, constant_dequantize_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class REAL, class QUANT>
......@@ -1467,6 +1498,9 @@ void pass::ConstantFolding::construct_constant_quantize()
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto quant_match = pattern_map[quant];
auto quantize_op = dynamic_pointer_cast<op::Quantize>(quant_match);
NGRAPH_CHECK(revalidate_and_ensure_static(quantize_op));
auto args = quant_match->get_arguments();
auto scale = static_pointer_cast<op::Constant>(
quant_match->input(1).get_source_output().get_node_shared_ptr());
......@@ -1501,7 +1535,7 @@ void pass::ConstantFolding::construct_constant_quantize()
auto quantize_matcher =
make_shared<pattern::Matcher>(quant, "ConstantFolding.ConstantQuantize");
this->add_matcher(
quantize_matcher, constant_quantize_callback, PassProperty::REQUIRE_STATIC_SHAPE);
quantize_matcher, constant_quantize_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
// Helper for mapping element::Types to runtime::reference::convert, which is templated in C++
......@@ -1646,6 +1680,8 @@ void pass::ConstantFolding::construct_constant_convert()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto convert_match = static_pointer_cast<op::Convert>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(convert_match));
replace_node(
m.get_match_root(),
fold_constant_convert(constant_match, convert_match->get_output_element_type(0)));
......@@ -1654,7 +1690,8 @@ void pass::ConstantFolding::construct_constant_convert()
auto convert_matcher =
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantConvert");
this->add_matcher(convert_matcher, constant_convert_callback, all_pass_property_off);
this->add_matcher(
convert_matcher, constant_convert_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
// ShapeOf is a bit of an odd duck: it doesn't matter if the input's value is
......@@ -1674,6 +1711,8 @@ void pass::ConstantFolding::construct_constant_shape_of()
if (arg_match->get_output_partial_shape(0).is_static())
{
NGRAPH_CHECK(revalidate_and_ensure_static(m.get_match_root()));
auto arg_shape = arg_match->get_output_shape(0);
auto replacement =
make_shared<op::Constant>(element::i64, Shape{arg_shape.size()}, arg_shape.data());
......@@ -1690,7 +1729,8 @@ void pass::ConstantFolding::construct_constant_shape_of()
auto shape_of_matcher =
make_shared<pattern::Matcher>(shape_of_op, "ConstantFolding.ConstantShapeOf");
this->add_matcher(shape_of_matcher, constant_shape_of_callback, all_pass_property_off);
this->add_matcher(
shape_of_matcher, constant_shape_of_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <typename T>
......@@ -1770,6 +1810,8 @@ void pass::ConstantFolding::construct_constant_reverse()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reverse_match = static_pointer_cast<op::Reverse>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(reverse_match));
replace_node(m.get_match_root(),
fold_constant_reverse(constant_match, reverse_match->get_reversed_axes()));
return true;
......@@ -1777,7 +1819,8 @@ void pass::ConstantFolding::construct_constant_reverse()
auto convert_matcher =
make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantReverse");
this->add_matcher(convert_matcher, constant_reverse_callback, all_pass_property_off);
this->add_matcher(
convert_matcher, constant_reverse_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <typename T>
......@@ -1903,6 +1946,8 @@ void pass::ConstantFolding::construct_constant_arithmetic_reduction()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto reduction_match = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
replace_node(reduction_match,
fold_constant_arithmetic_reduction(constant_match, reduction_match));
return true;
......@@ -1912,7 +1957,7 @@ void pass::ConstantFolding::construct_constant_arithmetic_reduction()
make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
this->add_matcher(arithmetic_reduction_matcher,
constant_arithmetic_reduction_callback,
all_pass_property_off);
PassProperty::CHANGE_DYNAMIC_STATE);
}
static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
......@@ -1973,6 +2018,8 @@ void pass::ConstantFolding::construct_constant_logical_reduction()
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto reduction_match = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(reduction_match));
replace_node(reduction_match,
fold_constant_logical_reduction(constant_match, reduction_match));
return true;
......@@ -1980,8 +2027,9 @@ void pass::ConstantFolding::construct_constant_logical_reduction()
auto logical_reduction_matcher =
make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantLogicalReduction");
this->add_matcher(
logical_reduction_matcher, constant_logical_reduction_callback, all_pass_property_off);
this->add_matcher(logical_reduction_matcher,
constant_logical_reduction_callback,
PassProperty::CHANGE_DYNAMIC_STATE);
}
template <typename T>
......@@ -2029,6 +2077,8 @@ void pass::ConstantFolding::construct_constant_concat()
return false;
}
NGRAPH_CHECK(revalidate_and_ensure_static(concat_node));
std::shared_ptr<op::Constant> replacement;
switch (concat_node->get_output_element_type(0).get_type_enum())
......@@ -2086,7 +2136,7 @@ void pass::ConstantFolding::construct_constant_concat()
auto concat_matcher =
make_shared<pattern::Matcher>(concat_op, "ConstantFolding.ConstantConcat");
this->add_matcher(concat_matcher, constant_concat_callback, all_pass_property_off);
this->add_matcher(concat_matcher, constant_concat_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
// "Inner" helper for fold_constant_gather, which has to switch on the indices
......@@ -2168,6 +2218,8 @@ void pass::ConstantFolding::construct_constant_gather()
auto indices = static_pointer_cast<op::Constant>(pattern_map[indices_label]);
auto gather = static_pointer_cast<op::Gather>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(gather));
std::shared_ptr<Node> replacement;
auto data_type = data->get_output_element_type(0);
auto indices_type = indices->get_output_element_type(0);
......@@ -2226,7 +2278,7 @@ void pass::ConstantFolding::construct_constant_gather()
auto gather_matcher =
make_shared<pattern::Matcher>(gather_op, "ConstantFolding.ConstantGather");
this->add_matcher(gather_matcher, constant_gather_callback, PassProperty::REQUIRE_STATIC_SHAPE);
this->add_matcher(gather_matcher, constant_gather_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -2263,6 +2315,8 @@ void pass::ConstantFolding::construct_constant_slice()
auto data_node = static_pointer_cast<op::Constant>(pattern_map[data_label]);
auto slice = static_pointer_cast<op::Slice>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(slice));
std::shared_ptr<op::Constant> replacement;
switch (slice->get_output_element_type(0).get_type_enum())
......@@ -2319,7 +2373,7 @@ void pass::ConstantFolding::construct_constant_slice()
};
auto slice_matcher = make_shared<pattern::Matcher>(slice_op, "ConstantFolding.ConstantSlice");
this->add_matcher(slice_matcher, constant_slice_callback, all_pass_property_off);
this->add_matcher(slice_matcher, constant_slice_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -2399,6 +2453,8 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
auto strides_node = static_pointer_cast<op::Constant>(pattern_map[strides_label]);
auto dyn_slice = static_pointer_cast<op::DynSlice>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(dyn_slice));
std::shared_ptr<op::Constant> replacement;
switch (dyn_slice->get_output_element_type(0).get_type_enum())
......@@ -2469,7 +2525,8 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
auto dyn_slice_matcher =
make_shared<pattern::Matcher>(dyn_slice_op, "ConstantFolding.ConstantDynSlice");
this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off);
this->add_matcher(
dyn_slice_matcher, constant_dyn_slice_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -2507,6 +2564,8 @@ void pass::ConstantFolding::construct_constant_range()
auto step_node = static_pointer_cast<op::Constant>(pattern_map[step_label]);
auto range = static_pointer_cast<op::Range>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(range));
std::shared_ptr<op::Constant> replacement;
switch (range->get_output_element_type(0).get_type_enum())
......@@ -2563,7 +2622,7 @@ void pass::ConstantFolding::construct_constant_range()
};
auto range_matcher = make_shared<pattern::Matcher>(range_op, "ConstantFolding.ConstantRange");
this->add_matcher(range_matcher, constant_range_callback, all_pass_property_off);
this->add_matcher(range_matcher, constant_range_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
template <class T>
......@@ -2605,6 +2664,8 @@ void pass::ConstantFolding::construct_constant_select()
auto f_node = static_pointer_cast<op::Constant>(pattern_map[f_label]);
auto select = static_pointer_cast<op::Select>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(select));
std::shared_ptr<op::Constant> replacement;
switch (select->get_output_element_type(0).get_type_enum())
......@@ -2662,5 +2723,5 @@ void pass::ConstantFolding::construct_constant_select()
auto select_matcher =
make_shared<pattern::Matcher>(select_op, "ConstantFolding.ConstantSelect");
this->add_matcher(select_matcher, constant_select_callback, all_pass_property_off);
this->add_matcher(select_matcher, constant_select_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
......@@ -1019,6 +1019,43 @@ TEST(constant_folding, constant_dyn_reshape)
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
{
    Shape data_shape{2, 4};
    vector<float> data_values{0, 1, 2, 3, 4, 5, 6, 7};
    Shape shape_of_shape{3};

    // The target shape is produced by elementwise addition of these two
    // vectors, i.e. {1,3,0} + {1,1,1} = {2,4,1}. This means that when
    // ConstantFolding starts, DynReshape does not yet have a static output
    // shape; only after the Add op is folded to a constant can DynReshape's
    // shape be inferred (and the DynReshape itself folded away).
    vector<int64_t> shape_addend_a{1, 3, 0};
    vector<int64_t> shape_addend_b{1, 1, 1};

    auto data_const = make_shared<op::Constant>(element::f32, data_shape, data_values);
    auto shape_const_a = make_shared<op::Constant>(element::i64, shape_of_shape, shape_addend_a);
    auto shape_const_b = make_shared<op::Constant>(element::i64, shape_of_shape, shape_addend_b);
    auto dyn_reshape = make_shared<op::DynReshape>(data_const, shape_const_a + shape_const_b);
    auto f = make_shared<Function>(dyn_reshape, ParameterVector{});

    // Sanity check: before folding, the output shape really is dynamic.
    ASSERT_TRUE(dyn_reshape->output(0).get_partial_shape().is_dynamic());

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // After folding, the DynReshape is gone and a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded_const =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded_const);

    // Reshape does not reorder elements, so the flat data is unchanged.
    auto folded_values = folded_const->get_vector<float>();
    ASSERT_TRUE(test::all_close_f(data_values, folded_values, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_transpose)
{
Shape shape_in{2, 4};
......@@ -1140,5 +1177,5 @@ TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
ASSERT_EQ(false, pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
ASSERT_EQ(false, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
ASSERT_EQ(true, pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment