Unverified Commit cfff3f1b authored by Scott Cyphers, committed by GitHub

Merge pull request #3103 from NervanaSystems/aprocter/dyn-reshape

Add DynElimination for DynReshape
parents b02b0812 7512f0ef
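
In short: when a DynReshape's target-shape input is a Constant (so the output shape is statically known), DynElimination now rewrites it into a plain static op::Reshape. A minimal before/after sketch, assuming the same nGraph test-style API used in the diff below rather than quoting it verbatim:

    auto data  = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
    auto shape = make_shared<op::Constant>(element::i64, Shape{3}, vector<int64_t>{0, 6, -1});
    auto r = make_shared<op::DynReshape>(data, shape, /*zero_flag=*/true); // infers {2, 6, 32}
    auto f = make_shared<Function>(r, ParameterVector{data});

    pass::Manager pm;
    pm.register_pass<pass::DynElimination>();
    pm.run_passes(f);
    // f now contains a static op::Reshape with an identity input-axis order and
    // output shape {2, 6, 32}; the DynReshape and the shape Constant are gone.
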
@@ -99,7 +99,8 @@ void op::DynReshape::validate_and_infer_types()
                if (out_shape_val[i] == 0 && m_zero_flag)
                {
                    // Copy input_shape[i] for zero values
-                   NGRAPH_CHECK(i < input_shape.size());
+                   NODE_VALIDATION_CHECK(
+                       this, i < input_shape.size(), "'0' dimension is out of range");
                    partial_shape[i] = Dimension(input_shape[i]);
                    output_elements *= input_shape[i];
                }
@@ -119,12 +120,21 @@ void op::DynReshape::validate_and_infer_types()
            // input elements
            if (output_elements == 0)
            {
-               NGRAPH_CHECK(input_elements == 0);
+               // TODO(amprocte): Decide if this is desired behavior here. (NumPy seems
+               // to fail.)
+               NODE_VALIDATION_CHECK(this,
+                                     input_elements == 0,
+                                     "Cannot infer '-1' dimension with zero-size output "
+                                     "dimension unless at least one input dimension is "
+                                     "also zero-size");
                partial_shape[negative_dim] = Dimension(0);
            }
            else
            {
-               NGRAPH_CHECK(input_elements % output_elements == 0);
+               NODE_VALIDATION_CHECK(
+                   this,
+                   input_elements % output_elements == 0,
+                   "Non-'-1' output dimensions do not evenly divide the input dimensions");
                partial_shape[negative_dim] = Dimension(input_elements / output_elements);
            }
        }
...
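
For reference, the arithmetic that the two new NODE_VALIDATION_CHECKs guard can be sketched as a standalone helper. This illustrates the same rules with the range and divisibility error handling elided; it is not a function from the nGraph sources:

    #include <cstdint>
    #include <vector>

    // Sketch: compute the concrete output shape from a DynReshape-style spec.
    // '0' (with zero_flag) copies the input dimension; '-1' is inferred from the
    // remaining element count.
    std::vector<int64_t> infer_output_shape(const std::vector<int64_t>& input_shape,
                                            std::vector<int64_t> spec,
                                            bool zero_flag)
    {
        int64_t input_elements = 1;
        for (int64_t d : input_shape)
            input_elements *= d;

        int64_t output_elements = 1;
        int64_t negative_dim = -1;
        for (size_t i = 0; i < spec.size(); i++)
        {
            if (spec[i] == 0 && zero_flag)
                spec[i] = input_shape[i]; // copy input_shape[i] for zero values
            if (spec[i] == -1)
                negative_dim = static_cast<int64_t>(i);
            else
                output_elements *= spec[i];
        }
        if (negative_dim >= 0)
            spec[negative_dim] = (output_elements == 0) ? 0 : input_elements / output_elements;
        return spec;
    }

With input {2, 4, 6, 8} and spec {0, 6, -1} under zero_flag, this yields {2, 6, 32}, the case the new unit test below checks; the checks added above reject specs whose known output dimensions do not evenly divide the input element count.
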
@@ -14,9 +14,12 @@
// limitations under the License.
//*****************************************************************************

+#include <numeric>
+
#include "dyn_elimination.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
+#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp"
@@ -34,6 +37,7 @@ pass::DynElimination::DynElimination()
{
    construct_transpose();
    construct_broadcast();
+   construct_dyn_slice();
    construct_dyn_reshape();
    construct_range();
}
@@ -367,7 +371,7 @@ static SlicePlan make_plan(const Shape& input_shape,
    return p;
}

-void pass::DynElimination::construct_dyn_reshape()
+void pass::DynElimination::construct_dyn_slice()
{
    auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
    auto begins_arg_label =
@@ -436,10 +440,53 @@ void pass::DynElimination::construct_dyn_reshape()
    };

    auto dyn_slice_matcher =
-       make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape");
+       make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynSlice");
    add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
}
void pass::DynElimination::construct_dyn_reshape()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto shape_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(data_arg_label, shape_arg_label);
auto dyn_reshape_callback = [data_arg_label, shape_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto shape_arg = static_pointer_cast<op::Constant>(pattern_map[shape_arg_label]);
auto dyn_reshape_node = static_pointer_cast<op::DynReshape>(m.get_match_root());
// TODO(amprocte): Can't handle the case where data rank is dynamic even if we know the
// output shape, because static Reshape requires an axis permutation (here an identity) to
// be given. See if we can come up with a workaround.
if (data_arg->get_output_partial_shape(0).rank().is_dynamic())
{
return false;
}
if (dyn_reshape_node->get_output_partial_shape(0).is_dynamic())
{
return false;
}
auto& result_shape = dyn_reshape_node->get_output_shape(0);
AxisVector perm(size_t(data_arg->get_output_partial_shape(0).rank()));
std::iota(perm.begin(), perm.end(), 0);
auto replacement = std::make_shared<op::Reshape>(data_arg, perm, result_shape);
replace_node(dyn_reshape_node, replacement);
return true;
};
auto dyn_reshape_matcher =
make_shared<pattern::Matcher>(dyn_reshape, "DynElimination.DynReshape");
add_matcher(dyn_reshape_matcher, dyn_reshape_callback, all_pass_property_off);
}
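
The identity permutation built with std::iota above is what lets a static op::Reshape stand in for DynReshape, since op::Reshape always takes an input axis order. A hypothetical rank-4 illustration of the equivalence the callback relies on (the names here are illustrative, not from the source):

    AxisVector identity_order{0, 1, 2, 3}; // what std::iota produces for a rank-4 input
    auto static_reshape = make_shared<op::Reshape>(data_arg, identity_order, Shape{2, 6, 32});
    // Behaves like DynReshape(data_arg, Constant{0, 6, -1}, /*zero_flag=*/true)
    // when data_arg has static shape {2, 4, 6, 8}: no transposition, only a reshape.
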
template <typename T>
std::shared_ptr<op::Constant>
    make_range_replacement_integral(const element::Type& et,
...
@@ -31,6 +31,7 @@ namespace ngraph
    private:
        void construct_transpose();
        void construct_broadcast();
+       void construct_dyn_slice();
        void construct_dyn_reshape();
        void construct_range();
    };
...
@@ -132,6 +132,30 @@ TEST(dyn_elimination, slice)
    ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2}));
}
TEST(dyn_elimination, reshape)
{
auto input_arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto shape_arg = make_shared<op::Constant>(element::i64, Shape{3}, vector<int64_t>{0, 6, -1});
auto r = make_shared<op::DynReshape>(input_arg, shape_arg, true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{2, 6, 32}));
auto f = make_shared<Function>(r, ParameterVector{input_arg});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 6, 32}));
}
TEST(dyn_elimination, range)
{
    auto constant_start = make_shared<op::Constant>(element::i64, Shape{}, vector<int64_t>{0});
...
@@ -365,3 +365,94 @@ NGRAPH_TEST(dynamic_${BACKEND_NAME}, range)
        ASSERT_EQ(results, test.expected_result);
    }
}
NGRAPH_TEST(dynamic_${BACKEND_NAME}, reshape)
{
auto backend = runtime::Backend::create("${BACKEND_NAME}", true);
auto build_graph = [&backend](bool zero_flag) {
// Create a graph for f(x,shape) = DynReshape(x,shape,zero_flag=zero_flag).
auto x = make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
auto shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic(1));
auto dyn_reshape = make_shared<op::DynReshape>(x, shape, zero_flag);
EXPECT_TRUE(dyn_reshape->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
auto f = make_shared<Function>(NodeVector{dyn_reshape}, ParameterVector{x, shape});
auto ex = backend->compile(f);
return ex;
};
auto t_r = backend->create_dynamic_tensor(element::i32, PartialShape::dynamic());
auto ex_flag_off = build_graph(false);
auto ex_flag_on = build_graph(true);
std::vector<std::tuple<bool, Shape, std::vector<int32_t>, std::vector<int64_t>, Shape>> tests;
tests.emplace_back(make_tuple(
false, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{6}, Shape{6}));
tests.emplace_back(make_tuple(
true, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{6}, Shape{6}));
tests.emplace_back(make_tuple(
false, Shape{2, 3}, vector<int32_t>{1, 2, 3, 4, 5, 6}, vector<int64_t>{-1}, Shape{6}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{2, -1},
Shape{2, 3}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, -1},
Shape{3, 2}));
tests.emplace_back(make_tuple(false,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, 2, -1},
Shape{3, 2, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{3, 2, -1},
Shape{3, 2, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{0, 0, -1},
Shape{2, 3, 1}));
tests.emplace_back(make_tuple(true,
Shape{2, 3},
vector<int32_t>{1, 2, 3, 4, 5, 6},
vector<int64_t>{2, 0, -1},
Shape{2, 3, 1}));
tests.emplace_back(make_tuple(
true, Shape{0, 3, 4}, vector<int32_t>{}, vector<int64_t>{3, -1, 2}, Shape{3, 0, 2}));
for (auto& test : tests)
{
bool zero_flag = get<0>(test);
const Shape& in_shape = get<1>(test);
const std::vector<int32_t>& data = get<2>(test);
const std::vector<int64_t>& dims = get<3>(test);
const Shape& out_shape = get<4>(test);
auto t_x = backend->create_tensor(element::i32, in_shape);
auto t_shape = backend->create_tensor(element::i64, Shape{dims.size()});
copy_data(t_x, data);
copy_data(t_shape, dims);
auto ex = zero_flag ? ex_flag_on : ex_flag_off;
ex->call_with_validate({t_r}, {t_x, t_shape});
ASSERT_EQ(t_r->get_element_type(), element::i32);
ASSERT_EQ(t_r->get_shape(), out_shape);
auto results = read_vector<int32_t>(t_r);
ASSERT_EQ(results, data);
}
}