Commit f21db619 authored by Adam Procter, committed by Scott Cyphers

Add DynElimination for DynBroadcast (#3062)

* Add DynElimination for Broadcast

* Change silent bailouts for invalid shape/ETs to NGRAPH_CHECKs
parent 66f6331b
@@ -15,6 +15,8 @@
 //*****************************************************************************
 #include "dyn_elimination.hpp"
+#include "ngraph/op/broadcast.hpp"
+#include "ngraph/op/experimental/dyn_broadcast.hpp"
 #include "ngraph/op/experimental/dyn_slice.hpp"
 #include "ngraph/op/experimental/transpose.hpp"
 #include "ngraph/op/reshape.hpp"
@@ -30,6 +32,7 @@ pass::DynElimination::DynElimination()
     : GraphRewrite()
 {
     construct_transpose();
+    construct_broadcast();
     construct_dyn_reshape();
 }
@@ -57,10 +60,11 @@ void pass::DynElimination::construct_transpose()
         auto& data_shape = data_arg->get_output_shape(0);

-        // TODO(amprocte): These should be redundant if the graph is validated. Necessary?
-        if (perm_arg->get_element_type() != element::i64 ||
-            perm_arg->get_output_partial_shape(0).is_dynamic() ||
-            perm_arg->get_output_shape(0).size() != 1)
+        NGRAPH_CHECK(perm_arg->get_output_partial_shape(0).rank().compatible(1));
+        NGRAPH_CHECK(perm_arg->get_output_element_type(0).compatible(element::i64));
+
+        if (perm_arg->get_output_element_type(0).is_dynamic() ||
+            perm_arg->get_output_partial_shape(0).is_dynamic())
         {
             return false;
         }
@@ -79,6 +83,52 @@ void pass::DynElimination::construct_transpose()
    add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
}
void pass::DynElimination::construct_broadcast()
{
    auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
    auto shape_arg_label =
        make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
    auto axes_arg_label =
        make_shared<pattern::op::Label>(element::i64, Shape{0}, pattern::has_class<op::Constant>());

    auto dyn_broadcast =
        make_shared<op::DynBroadcast>(data_arg_label, shape_arg_label, axes_arg_label);

    auto dyn_broadcast_callback =
        [data_arg_label, shape_arg_label, axes_arg_label](pattern::Matcher& m) {
            auto pattern_map = m.get_pattern_map();

            auto data_arg = pattern_map[data_arg_label];
            auto shape_arg = static_pointer_cast<op::Constant>(pattern_map[shape_arg_label]);
            auto axes_arg = static_pointer_cast<op::Constant>(pattern_map[axes_arg_label]);

            // A validated graph should guarantee these; fail loudly rather than
            // silently declining the match if they do not hold.
            NGRAPH_CHECK(shape_arg->get_output_partial_shape(0).rank().compatible(1));
            NGRAPH_CHECK(shape_arg->get_output_element_type(0).compatible(element::i64));
            NGRAPH_CHECK(axes_arg->get_output_partial_shape(0).rank().compatible(1));
            NGRAPH_CHECK(axes_arg->get_output_element_type(0).compatible(element::i64));

            // Bail out (without rewriting) while the shape or axes are still dynamic.
            if (shape_arg->get_output_element_type(0).is_dynamic() ||
                shape_arg->get_output_partial_shape(0).is_dynamic() ||
                axes_arg->get_output_element_type(0).is_dynamic() ||
                axes_arg->get_output_partial_shape(0).is_dynamic())
            {
                return false;
            }

            auto shape = shape_arg->get_shape_val();
            auto axes = axes_arg->get_axis_vector_val();

            // Shape and axes are constant, so the DynBroadcast can be replaced
            // with an equivalent static Broadcast.
            auto replacement = std::make_shared<op::Broadcast>(data_arg, shape, axes);

            replace_node(m.get_match_root(), replacement);
            return true;
        };

    auto dyn_broadcast_matcher =
        make_shared<pattern::Matcher>(dyn_broadcast, "DynElimination.DynBroadcast");
    add_matcher(dyn_broadcast_matcher, dyn_broadcast_callback, all_pass_property_off);
}
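As a usage illustration (not part of this commit's diff): a minimal sketch of the rewrite the new matcher performs, assuming the standard pass::Manager driver from "ngraph/pass/manager.hpp"; the shapes and values here are hypothetical.

#include "ngraph/pass/manager.hpp"

// A DynBroadcast whose shape and axes inputs are already Constants.
auto data = make_shared<op::Parameter>(element::f32, Shape{2});
auto bc_shape = op::Constant::create(element::i64, Shape{2}, {3, 2});
auto bc_axes = op::Constant::create(element::i64, Shape{1}, {0});
auto dyn_bc = make_shared<op::DynBroadcast>(data, bc_shape, bc_axes);
auto f = make_shared<Function>(NodeVector{dyn_bc}, ParameterVector{data});

// Running the pass should replace the DynBroadcast with a static
// op::Broadcast(data, Shape{3, 2}, AxisSet{0}).
pass::Manager passes;
passes.register_pass<pass::DynElimination>();
passes.run_passes(f);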
//
// We eliminate DynSlice by converting it to a sequence of ops:
//
......
@@ -30,6 +30,7 @@ namespace ngraph
        private:
            void construct_transpose();
+           void construct_broadcast();
            void construct_dyn_reshape();
        };
    }
......
@@ -160,6 +160,55 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, transpose)
    }
}
NGRAPH_TEST(dynamic_${BACKEND_NAME}, broadcast)
{
    // Create a graph for
    //   f(x,shape:i32,axes:i32) = Broadcast(x,Convert<i64>(shape),Convert<i64>(axes)).
    auto x = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto shape = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
    auto axes = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});

    auto shape_i64 = make_shared<op::Convert>(shape, element::i64);
    auto axes_i64 = make_shared<op::Convert>(axes, element::i64);

    auto bc = make_shared<op::DynBroadcast>(x, shape_i64, axes_i64);

    auto f = make_shared<Function>(NodeVector{bc}, ParameterVector{x, shape, axes});

    auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
    auto ex = backend->compile(f);

    auto t_r = backend->create_dynamic_tensor(element::f32, PartialShape::dynamic());

    std::vector<Shape> x_shapes{Shape{}, Shape{}, Shape{2}, Shape{2}};
    std::vector<std::vector<int32_t>> shapes{{2, 2}, {2, 2, 2}, {3, 2}, {2, 3}};
    std::vector<std::vector<int32_t>> axeses{{0, 1}, {0, 1, 2}, {0}, {1}};
    std::vector<std::vector<float>> inputs{{6}, {7}, {10, 11}, {10, 11}};
    std::vector<Shape> expected_result_shapes{
        Shape{2, 2}, Shape{2, 2, 2}, Shape{3, 2}, Shape{2, 3}};
    std::vector<std::vector<float>> expected_results{
        {6, 6, 6, 6}, {7, 7, 7, 7, 7, 7, 7, 7}, {10, 11, 10, 11, 10, 11}, {10, 10, 10, 11, 11, 11}};

    for (size_t i = 0; i < x_shapes.size(); i++)
    {
        auto t_x = backend->create_tensor(element::f32, x_shapes[i]);
        auto t_shape = backend->create_tensor(element::i32, Shape{shapes[i].size()});
        auto t_axes = backend->create_tensor(element::i32, Shape{axeses[i].size()});

        copy_data(t_x, inputs[i]);
        copy_data(t_shape, shapes[i]);
        copy_data(t_axes, axeses[i]);

        ex->call_with_validate({t_r}, {t_x, t_shape, t_axes});

        ASSERT_EQ(t_r->get_shape(), expected_result_shapes[i]);

        auto results = read_vector<float>(t_r);

        ASSERT_TRUE(test::all_close_f(results, expected_results[i], MIN_FLOAT_TOLERANCE_BITS));
    }
}
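A note on the test harness above: the true passed as the second argument to runtime::Backend::create requests a backend with dynamic-shape support, which is what allows a single compiled executable to be called with four different shape/axes combinations; create_dynamic_tensor likewise yields a result tensor whose concrete shape is only fixed at call time.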
NGRAPH_TEST(dynamic_${BACKEND_NAME}, sum)
{
    // Create a graph for f(x,axes:int32) = Sum(x,Convert<int64>(axes)).
......