Commit 4782e060 authored by Chris Sullivan, committed by Robert Kimball

Fix bug in cpu_layout: explicitly handle transposed reshapes that change shape, add test for coverage. (#1621)
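In brief: the layout pass previously rotated layouts only for a transposed Reshape whose output rank matched its input rank, so a transpose that also changes shape (for example, 2x6 reshaped to 12 with axis order {1, 0}) fell into the non-transpose path and could be assigned an incorrect layout. The change below nests the rank check inside the transpose check, and the new reshape_transposed_shape_change test covers the shape-changing case.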

parent 84de3bf4
@@ -1193,51 +1193,60 @@ namespace ngraph
                 void CPULayout::LAYOUT_DECL(ngraph::op::Reshape)
                 {
                     auto reshape = static_cast<ngraph::op::Reshape*>(node.get());
-                    if (reshape->get_is_transpose() &&
-                        reshape->get_output_shape().size() ==
-                            reshape->get_argument(0)->get_shape().size())
+                    if (reshape->get_is_transpose())
                     {
-                        auto axis_order = reshape->get_input_order();
-                        auto tvl = node->get_inputs()[0]
-                                       .get_output()
-                                       .get_tensor_ptr()
-                                       ->get_tensor_layout();
-                        auto cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
-                        if (cpu_tvl && cpu_tvl->is_mkldnn_layout())
+                        if (reshape->get_output_shape().size() ==
+                            reshape->get_argument(0)->get_shape().size())
                         {
-                            // Rotate MKLDNN memory descriptor
-                            auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
-                            auto output_md = mkldnn_utils::rotate_blocked_md(input_md, axis_order);
-                            set_output_layouts(node, {output_md});
-                            auto op_annotations = reshape->get_op_annotations();
-                            if (op_annotations)
+                            auto axis_order = reshape->get_input_order();
+                            auto tvl = node->get_inputs()[0]
+                                           .get_output()
+                                           .get_tensor_ptr()
+                                           ->get_tensor_layout();
+                            auto cpu_tvl = dynamic_cast<runtime::cpu::LayoutDescriptor*>(tvl.get());
+                            if (cpu_tvl && cpu_tvl->is_mkldnn_layout())
                             {
-                                // pass-through
-                                op_annotations->add_in_place_oi_pair({0, 0, false});
+                                // Rotate MKLDNN memory descriptor
+                                auto input_md = mkldnn_utils::get_input_mkldnn_md(node.get(), 0);
+                                auto output_md =
+                                    mkldnn_utils::rotate_blocked_md(input_md, axis_order);
+                                set_output_layouts(node, {output_md});
+                                auto op_annotations = reshape->get_op_annotations();
+                                if (op_annotations)
+                                {
+                                    // pass-through
+                                    op_annotations->add_in_place_oi_pair({0, 0, false});
+                                }
+                                else
+                                {
+                                    op_annotations =
+                                        std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
+                                    // pass-through
+                                    op_annotations->add_in_place_oi_pair({0, 0, false});
+                                    reshape->set_op_annotations(op_annotations);
+                                }
                             }
                             else
                             {
-                                op_annotations =
-                                    std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
-                                // pass-through
-                                op_annotations->add_in_place_oi_pair({0, 0, false});
-                                reshape->set_op_annotations(op_annotations);
+                                auto input_strides = cpu_tvl->get_strides();
+                                Strides output_strides(input_strides.size());
+                                for (size_t i = 0; i < input_strides.size(); i++)
+                                {
+                                    output_strides[i] = input_strides[axis_order[i]];
+                                }
+                                set_native_layouts(external_function, node);
+                                auto output_tvl =
+                                    dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(
+                                        node->get_output_tensor_ptr()->get_tensor_layout());
+                                // TODO (jbobba): For now non-MKLDNN layouts are always in row-major format
+                                // Enable this once we support non row-major strided formats
+                                // output_tvl->set_strides(output_strides);
                             }
+                            return;
                         }
-                        }
-                        else
-                        {
-                            auto input_strides = cpu_tvl->get_strides();
-                            Strides output_strides(input_strides.size());
-                            for (size_t i = 0; i < input_strides.size(); i++)
-                            {
-                                output_strides[i] = input_strides[axis_order[i]];
-                            }
-                            set_native_layouts(external_function, node);
-                            auto output_tvl = dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(
-                                node->get_output_tensor_ptr()->get_tensor_layout());
-                            // TODO (jbobba): For now non-MKLDNN layouts are always in row-major format
-                            // Enable this once we support non row-major strided formats
-                            // output_tvl->set_strides(output_strides);
-                        }
                     }
                     else
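For intuition, the non-MKLDNN branch above permutes the input strides by the reshape's axis order: the i-th output stride is the stride of the input axis that moves into position i. A minimal standalone sketch of that permutation (rotate_strides is an illustrative name, not an ngraph API; ngraph's Strides is a vector of size_t):

    #include <cstddef>
    #include <vector>

    // Mirrors the stride-permutation loop in the hunk above.
    std::vector<std::size_t> rotate_strides(const std::vector<std::size_t>& input_strides,
                                            const std::vector<std::size_t>& axis_order)
    {
        std::vector<std::size_t> output_strides(input_strides.size());
        for (std::size_t i = 0; i < input_strides.size(); i++)
        {
            output_strides[i] = input_strides[axis_order[i]];
        }
        return output_strides;
    }

For a row-major 2x6 tensor the strides are {6, 1}; with axis order {1, 0} the permuted strides are {1, 6}, i.e. the 6x2 transposed view over the same buffer. Note that the commit keeps set_strides commented out (see the TODO above), so non-MKLDNN outputs remain row-major for now.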
@@ -3044,6 +3044,25 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_4d_no_transpose)
     EXPECT_EQ(a_data, read_vector<float>(result));
 }
 
+NGRAPH_TEST(${BACKEND_NAME}, reshape_transposed_shape_change)
+{
+    Shape shape_a{2, 6};
+    auto A = make_shared<op::Parameter>(element::f32, shape_a);
+    Shape shape_r{12};
+    auto r = make_shared<op::Reshape>(A, AxisVector{1, 0}, shape_r);
+    auto f = make_shared<Function>(r, op::ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::f32, shape_a);
+    copy_data(a, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    auto result = backend->create_tensor(element::f32, shape_r);
+
+    backend->call_with_validate(f, {result}, {a});
+    EXPECT_EQ((vector<float>{1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12}), read_vector<float>(result));
+}
+
 //
 // Numpy:
 //
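To see why the test expects {1, 7, 2, 8, ...}: the Reshape applies axis order {1, 0} first, turning the row-major 2x6 input into its 6x2 transpose, and then flattens that result row-major into 12 elements. A minimal sketch that reproduces the expected vector outside of ngraph (plain C++, illustrative only):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        const std::size_t R = 2, C = 6; // input shape {2, 6}
        std::vector<float> a{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
        std::vector<float> out(R * C);
        // Element (i, j) of the transposed C x R view is element (j, i) of the
        // row-major R x C input; writing it out row-major yields the flat result.
        for (std::size_t i = 0; i < C; i++)
        {
            for (std::size_t j = 0; j < R; j++)
            {
                out[i * R + j] = a[j * C + i];
            }
        }
        for (float v : out)
        {
            std::cout << v << ' '; // prints 1 7 2 8 3 9 4 10 5 11 6 12
        }
        std::cout << '\n';
        return 0;
    }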