Commit 62f00d68 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Bug fixes to account for padded layouts and correct handling for resh… (#1480)

* Bug fixes to account for padded layouts and correct handling for reshape that transposes and changes shape

* Use nullptr instead of dummy buffer in memory primitive creation
parent 6679c233
...@@ -114,12 +114,9 @@ mkldnn::memory::desc ...@@ -114,12 +114,9 @@ mkldnn::memory::desc
size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc) size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{ {
// The MKL-DNN C++ API forces proper initialization of a memory primitive size_t index =
// with a non-null pointer (unlike the C API) insert_primitive(new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr));
// Primitives are initialized at runtime so we use a known-invalid address here return index;
// to bypass this check
return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, reinterpret_cast<void*>(0x42)));
} }
mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format( mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format(
......
...@@ -1074,9 +1074,15 @@ namespace ngraph ...@@ -1074,9 +1074,15 @@ namespace ngraph
void CPULayout::LAYOUT_DECL(ngraph::op::Result) void CPULayout::LAYOUT_DECL(ngraph::op::Result)
{ {
auto result = static_cast<const ngraph::op::Result*>(node.get()); auto result = static_cast<const ngraph::op::Result*>(node.get());
if (result->needs_default_layout() || auto cpu_tvl = dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(
mkldnn_utils::get_input_mkldnn_md(node.get(), 0).data.format == node->get_inputs()[0]
mkldnn_format_undef) .get_output()
.get_tensor_view()
->get_tensor_view_layout());
if (result->needs_default_layout() || !cpu_tvl->is_mkldnn_layout() ||
cpu_tvl->get_size() * cpu_tvl->get_element_type().size() !=
cpu_tvl->get_allocated_size())
{ {
set_native_layouts(external_function, node, false); set_native_layouts(external_function, node, false);
} }
...@@ -1093,7 +1099,9 @@ namespace ngraph ...@@ -1093,7 +1099,9 @@ namespace ngraph
void CPULayout::LAYOUT_DECL(ngraph::op::Reshape) void CPULayout::LAYOUT_DECL(ngraph::op::Reshape)
{ {
auto reshape = static_cast<ngraph::op::Reshape*>(node.get()); auto reshape = static_cast<ngraph::op::Reshape*>(node.get());
if (reshape->get_is_transpose()) if (reshape->get_is_transpose() &&
reshape->get_output_shape().size() ==
reshape->get_argument(0)->get_shape().size())
{ {
auto axis_order = reshape->get_input_order(); auto axis_order = reshape->get_input_order();
auto tvl = node->get_inputs()[0] auto tvl = node->get_inputs()[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment