Commit 02cbfc41 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

WIP

parent 7a569557
...@@ -171,6 +171,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND ...@@ -171,6 +171,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
runtime/cpu/cpu_tensor_view.cpp runtime/cpu/cpu_tensor_view.cpp
runtime/cpu/cpu_tensor_view_wrapper.cpp runtime/cpu/cpu_tensor_view_wrapper.cpp
runtime/cpu/cpu_layout_descriptor.cpp runtime/cpu/cpu_layout_descriptor.cpp
runtime/cpu/mkldnn_utils.cpp
runtime/cpu/ops/matmul_bias.cpp runtime/cpu/ops/matmul_bias.cpp
runtime/cpu/pass/cpu_fusion.cpp runtime/cpu/pass/cpu_fusion.cpp
runtime/cpu/pass/cpu_layout.cpp runtime/cpu/pass/cpu_layout.cpp
......
...@@ -39,7 +39,7 @@ namespace ngraph ...@@ -39,7 +39,7 @@ namespace ngraph
, axis_order(tv_axis_order) , axis_order(tv_axis_order)
, offset(0) , offset(0)
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape())) , size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, mkldnn_format(mkldnn_format_undef) , mkldnn_format(mkldnn::memory::format::format_undef)
{ {
auto shape = get_shape(); auto shape = get_shape();
size_t s = 1; size_t s = 1;
...@@ -93,8 +93,8 @@ namespace ngraph ...@@ -93,8 +93,8 @@ namespace ngraph
return false; return false;
//TODO: Numeric backend-specific properties //TODO: Numeric backend-specific properties
if (mkldnn_format != p_other->mkldnn_format) // if (mkldnn_format != p_other->mkldnn_format)
return false; // return false;
return true; return true;
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <mkldnn_types.h> #include <mkldnn.hpp>
#include "ngraph/common.hpp" #include "ngraph/common.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
...@@ -45,7 +45,8 @@ namespace ngraph ...@@ -45,7 +45,8 @@ namespace ngraph
const Strides& get_strides() const override { return strides; } const Strides& get_strides() const override { return strides; }
bool operator==(const TensorViewLayout& other) const override; bool operator==(const TensorViewLayout& other) const override;
mkldnn_memory_format_t get_mkldnn_format() const { return mkldnn_format; } void set_mkldnn_format(const mkldnn::memory::format& format) { mkldnn_format = format; }
mkldnn::memory::format get_mkldnn_format() const { return mkldnn_format; }
const AxisVector& get_axis_order() const { return axis_order; } const AxisVector& get_axis_order() const { return axis_order; }
static const AxisVector Native2DAxisOrder; static const AxisVector Native2DAxisOrder;
...@@ -60,7 +61,7 @@ namespace ngraph ...@@ -60,7 +61,7 @@ namespace ngraph
size_t size; size_t size;
// Numeric backend-specific fields // Numeric backend-specific fields
mkldnn_memory_format_t mkldnn_format; mkldnn::memory::format mkldnn_format;
}; };
typedef std::vector<std::shared_ptr<ngraph::runtime::cpu::LayoutDescriptor>> LayoutDescriptorPtrs; typedef std::vector<std::shared_ptr<ngraph::runtime::cpu::LayoutDescriptor>> LayoutDescriptorPtrs;
......
...@@ -15,9 +15,12 @@ ...@@ -15,9 +15,12 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <mkldnn.hpp>
#include "cpu_layout.hpp" #include "cpu_layout.hpp"
#include "ngraph/descriptor/output.hpp" #include "ngraph/descriptor/output.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
using namespace ngraph::runtime::cpu::pass; using namespace ngraph::runtime::cpu::pass;
...@@ -39,18 +42,37 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) ...@@ -39,18 +42,37 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
auto native_axis_order = ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank); auto native_axis_order = ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank);
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(
*tv, native_axis_order);
if (tensor.is_output() || tensor.is_input() || tensor.is_constant()) if (tensor.is_output() || tensor.is_input() || tensor.is_constant())
{ {
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>( // Set the MKLDNN format to native row-major variants
*tv, native_axis_order); layout->set_mkldnn_format(MKLDNN::CreateNativeDataFormat(*layout));
tv->set_tensor_view_layout(layout);
} }
else else
{ {
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>( if (ngraph::runtime::cpu::MKLDNN::IsMKLDNNOp(*node))
*tv, native_axis_order); {
tv->set_tensor_view_layout(layout); // TODO(jmenon): get_inputs is marked as to-be-deprecated
// but get_input_ops isn't a suitable API so this needs to be
// reworked
for (const descriptor::Input& input : node->get_inputs())
{
const auto& output = input.get_output();
auto output_tv = output.get_tensor_view();
auto output_tvl = output_tv->get_tensor_view_layout();
// TODO(jmenon): Propagate layout based on inputs
// TODO(jmenon): Insert layout conversions when needed
}
}
else
{
layout->set_mkldnn_format(mkldnn::memory::format::format_undef);
}
} }
tv->set_tensor_view_layout(layout);
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment