Commit 62f00d68 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Bug fixes to account for padded layouts and correct handling for resh… (#1480)

* Bug fixes to account for padded layouts and correct handling for reshape that transposes and changes shape

* Use nullptr instead of dummy buffer in memory primitive creation
parent 6679c233
......@@ -114,12 +114,9 @@ mkldnn::memory::desc
size_t MKLDNNEmitter::build_memory_primitive(const mkldnn::memory::desc& desc)
{
// The MKL-DNN C++ API forces proper initialization of a memory primitive
// with a non-null pointer (unlike the C API)
// Primitives are initialized at runtime so we use a known-invalid address here
// to bypass this check
return insert_primitive(
new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, reinterpret_cast<void*>(0x42)));
size_t index =
insert_primitive(new mkldnn::memory({desc, mkldnn_utils::global_cpu_engine}, nullptr));
return index;
}
mkldnn::memory::format MKLDNNEmitter::query_convolution_forward_weight_format(
......
......@@ -1074,9 +1074,15 @@ namespace ngraph
void CPULayout::LAYOUT_DECL(ngraph::op::Result)
{
auto result = static_cast<const ngraph::op::Result*>(node.get());
if (result->needs_default_layout() ||
mkldnn_utils::get_input_mkldnn_md(node.get(), 0).data.format ==
mkldnn_format_undef)
auto cpu_tvl = dynamic_pointer_cast<runtime::cpu::LayoutDescriptor>(
node->get_inputs()[0]
.get_output()
.get_tensor_view()
->get_tensor_view_layout());
if (result->needs_default_layout() || !cpu_tvl->is_mkldnn_layout() ||
cpu_tvl->get_size() * cpu_tvl->get_element_type().size() !=
cpu_tvl->get_allocated_size())
{
set_native_layouts(external_function, node, false);
}
......@@ -1093,7 +1099,9 @@ namespace ngraph
void CPULayout::LAYOUT_DECL(ngraph::op::Reshape)
{
auto reshape = static_cast<ngraph::op::Reshape*>(node.get());
if (reshape->get_is_transpose())
if (reshape->get_is_transpose() &&
reshape->get_output_shape().size() ==
reshape->get_argument(0)->get_shape().size())
{
auto axis_order = reshape->get_input_order();
auto tvl = node->get_inputs()[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment