Commit 846f6bfe authored by Nishant Patel's avatar Nishant Patel Committed by Robert Kimball

Support dimshuffle/transpose with MKLDNN (#1129)

* Reshape 4d

* Support dimshuffles/transpose with MKLDNN

* Addressing PR Feedback

* Use Eigen for 3D dimshuffles
parent d861ba32
......@@ -1658,35 +1658,91 @@ namespace ngraph
writer << " );\n";
}
#else
if (args[0].get_element_type() == element::f32 && args[0].get_shape().size() == 3 &&
out[0].get_shape().size() == 3)
{
writer << "cpu::kernel::reshape_3d_3d_float32(" << args[0].get_name() << ", "
<< out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(reshape->get_input_order()) << "}, "
<< "{" << join(out[0].get_shape()) << "}"
<< ");\n";
}
else if (args[0].get_element_type() == element::f32 &&
args[0].get_shape().size() == 4 && out[0].get_shape().size() == 4)
if (runtime::cpu::mkldnn_utils::use_mkldnn_kernel(node))
{
writer << "cpu::kernel::reshape_4d_4d_float32(" << args[0].get_name() << ", "
<< out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(reshape->get_input_order()) << "}, "
<< "{" << join(out[0].get_shape()) << "}"
<< ");\n";
auto input_tvl = node->get_inputs()[0]
.get_output()
.get_tensor_view()
->get_tensor_view_layout();
auto input_cpu_tvl =
dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(input_tvl);
// Reorder input shape if needed
auto input_axis_order = input_cpu_tvl->get_axis_order();
Shape input_shape(input_axis_order.size());
for (size_t idx = 0; idx < input_axis_order.size(); idx++)
{
input_shape[idx] = args[0].get_shape()[input_axis_order[idx]];
}
auto output_tvl = node->get_output_tensor_view(0)->get_tensor_view_layout();
auto input_strides = input_tvl->get_strides();
auto output_strides = output_tvl->get_strides();
auto axis_order = reshape->get_input_order();
Strides new_output_strides(output_strides.size());
for (int i = 0; i < output_strides.size(); i++)
new_output_strides[axis_order[i]] = output_strides[i];
mkldnn::memory::data_type et = runtime::cpu::mkldnn_utils::get_mkldnn_data_type(
node->get_input_element_type(0));
mkldnn::memory::dims mkldnn_input_shape(input_shape.begin(), input_shape.end());
mkldnn::memory::dims mkldnn_input_strides(input_strides.begin(),
input_strides.end());
mkldnn::memory::dims mkldnn_output_strides(new_output_strides.begin(),
new_output_strides.end());
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
auto input_desc = mkldnn_emitter->build_blocked_memory_descriptor(
mkldnn_input_shape, mkldnn_input_strides, et);
auto result_desc = mkldnn_emitter->build_blocked_memory_descriptor(
mkldnn_input_shape, mkldnn_output_strides, et);
size_t reorder_index = mkldnn_emitter->build_reorder(input_desc, result_desc);
auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index);
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[0])
<< ", " << args[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::set_memory_ptr(ctx, " << to_string(deps[1])
<< ", " << out[0].get_name() << ");\n";
writer << "cpu::mkldnn_utils::mkldnn_invoke_primitive(ctx, "
<< to_string(reorder_index) << ");\n";
}
else
{
kernel::emit_reshape(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
out[0].get_name(),
args[0].get_shape(),
out[0].get_shape(),
reshape->get_input_order());
if (args[0].get_element_type() == element::f32 &&
args[0].get_shape().size() == 3 && out[0].get_shape().size() == 3)
{
writer << "cpu::kernel::reshape_3d_3d_float32(" << args[0].get_name()
<< ", " << out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(reshape->get_input_order()) << "}, "
<< "{" << join(out[0].get_shape()) << "}"
<< ");\n";
}
else if (args[0].get_element_type() == element::f32 &&
args[0].get_shape().size() == 4 && out[0].get_shape().size() == 4)
{
writer << "cpu::kernel::reshape_4d_4d_float32(" << args[0].get_name()
<< ", " << out[0].get_name() << ", "
<< "{" << join(args[0].get_shape()) << "}, "
<< "{" << join(reshape->get_input_order()) << "}, "
<< "{" << join(out[0].get_shape()) << "}"
<< ");\n";
}
else
{
kernel::emit_reshape(writer,
args[0].get_element_type().c_type_string(),
args[0].get_name(),
out[0].get_name(),
args[0].get_shape(),
out[0].get_shape(),
reshape->get_input_order());
}
}
#endif
writer.block_end();
......
......@@ -87,6 +87,31 @@ mkldnn::memory::desc MKLDNNEmitter::build_memory_descriptor(const Shape& shape,
fmt);
}
mkldnn::memory::desc
MKLDNNEmitter::build_blocked_memory_descriptor(const mkldnn::memory::dims& dim,
const mkldnn::memory::dims& strides,
mkldnn::memory::data_type dtype) const
{
mkldnn_memory_desc_t md;
md.primitive_kind = mkldnn_memory;
md.ndims = static_cast<int>(dim.size());
md.format = mkldnn_blocked;
md.data_type = mkldnn::memory::convert_to_c(dtype);
for (size_t i = 0; i < dim.size(); i++)
{
md.layout_desc.blocking.block_dims[i] = 1;
md.layout_desc.blocking.strides[1][i] = 1;
md.layout_desc.blocking.strides[0][i] = strides[i];
md.layout_desc.blocking.padding_dims[i] = dim[i];
md.layout_desc.blocking.offset_padding_to_data[i] = 0;
md.dims[i] = dim[i];
}
md.layout_desc.blocking.offset_padding = 0;
return mkldnn::memory::desc(md);
}
mkldnn::memory MKLDNNEmitter::build_memory_primitive(const TensorViewWrapper& tvw) const
{
return mkldnn::memory({build_memory_descriptor(tvw), mkldnn_utils::global_cpu_engine}, nullptr);
......
......@@ -65,6 +65,10 @@ namespace ngraph
mkldnn::memory::desc build_memory_descriptor(const Shape& shape,
const ngraph::element::Type& et,
mkldnn::memory::format fmt) const;
mkldnn::memory::desc
build_blocked_memory_descriptor(const mkldnn::memory::dims& dim,
const mkldnn::memory::dims& strides,
mkldnn::memory::data_type dtype) const;
mkldnn::memory build_memory_primitive(const TensorViewWrapper& tvw) const;
size_t build_memory_primitive(const mkldnn::memory::desc& desc);
......
......@@ -27,6 +27,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp"
......@@ -57,7 +58,8 @@ static const std::unordered_set<std::type_index> s_op_registry{
TI(ngraph::op::MaxPool),
TI(ngraph::op::MaxPoolBackprop),
TI(ngraph::op::Relu),
TI(ngraph::op::ReluBackprop)};
TI(ngraph::op::ReluBackprop),
TI(ngraph::op::Reshape)};
// Mapping from POD types to MKLDNN data types
static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map{
......
......@@ -32,6 +32,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
......@@ -548,6 +549,24 @@ namespace ngraph
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::Reshape)
{
auto reshape = static_cast<op::Reshape*>(node);
// Use Eigen for 3D
if (node->get_input_element_type(0) == element::f32 &&
node->get_input_shape(0).size() < TENSOR_MAX_DIMS &&
node->get_input_shape(0).size() > 3 &&
node->get_input_shape(0).size() == node->get_output_shape(0).size())
{
auto op_annotations =
std::make_shared<ngraph::runtime::cpu::CPUOpAnnotations>();
op_annotations->set_mkldnn_op(true);
reshape->set_op_annotations(op_annotations);
}
}
template <>
void CPUAssignment::ASSIGN_DECL(ngraph::op::BatchNorm)
{
......@@ -694,6 +713,7 @@ static const runtime::cpu::pass::AssignOpMap s_dispatcher{
&runtime::cpu::pass::CPUAssignment::assign<ngraph::op::SigmoidBackprop>},
{TI(ngraph::op::Lstm), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Lstm>},
{TI(ngraph::op::Rnn), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Rnn>},
{TI(ngraph::op::Reshape), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Reshape>},
{TI(ngraph::op::Softmax), &runtime::cpu::pass::CPUAssignment::assign<ngraph::op::Softmax>},
};
......
......@@ -2485,6 +2485,66 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_m2m_dim_change_transpose)
EXPECT_EQ((vector<float>{1, 3, 5, 2, 4, 6}), read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_3d_transpose)
{
vector<float> a_data(2 * 2 * 5);
for (int i = 0; i < 2 * 2 * 5; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 2, 5};
Shape shape_r{2, 5, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{0, 2, 1}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ((vector<float>{1., 6., 2., 7., 3., 8., 4., 9., 5., 10.,
11., 16., 12., 17., 13., 18., 14., 19., 15., 20.}),
read_vector<float>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, reshape_4d_transpose)
{
vector<float> a_data(2 * 2 * 5 * 5);
for (int i = 0; i < 2 * 2 * 5 * 5; i++)
{
a_data[i] = float(i + 1);
}
Shape shape_a{2, 2, 5, 5};
Shape shape_r{2, 5, 5, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto r = make_shared<op::Reshape>(A, AxisVector{0, 2, 3, 1}, shape_r);
auto f = make_shared<Function>(r, op::ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, a_data);
auto result = backend->create_tensor(element::f32, shape_r);
backend->call(f, {result}, {a});
EXPECT_EQ(
(vector<float>{1., 26., 2., 27., 3., 28., 4., 29., 5., 30., 6., 31., 7., 32., 8.,
33., 9., 34., 10., 35., 11., 36., 12., 37., 13., 38., 14., 39., 15., 40.,
16., 41., 17., 42., 18., 43., 19., 44., 20., 45., 21., 46., 22., 47., 23.,
48., 24., 49., 25., 50., 51., 76., 52., 77., 53., 78., 54., 79., 55., 80.,
56., 81., 57., 82., 58., 83., 59., 84., 60., 85., 61., 86., 62., 87., 63.,
88., 64., 89., 65., 90., 66., 91., 67., 92., 68., 93., 69., 94., 70., 95.,
71., 96., 72., 97., 73., 98., 74., 99., 75., 100.}),
read_vector<float>(result));
}
//
// Numpy:
//
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment