Commit 2d5a886d authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

WIP

parent 48015415
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view.hpp"
#include "ngraph/runtime/external_function.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
......
...@@ -76,11 +76,13 @@ void runtime::cpu::CPU_CallFrame::call( ...@@ -76,11 +76,13 @@ void runtime::cpu::CPU_CallFrame::call(
} }
void runtime::cpu::CPU_CallFrame::propagate_layouts( void runtime::cpu::CPU_CallFrame::propagate_layouts(
const std::vector<std::shared_ptr<runtime::TensorView>>& tvs, const LayoutDescriptorPtrs& layouts) const const std::vector<std::shared_ptr<runtime::TensorView>>& tvs,
const LayoutDescriptorPtrs& layouts) const
{ {
if (layouts.size() != tvs.size()) if (layouts.size() != tvs.size())
{ {
throw ngraph_error("Error propagating layouts - tensor view and layout descriptor counts do not match"); throw ngraph_error(
"Error propagating layouts - tensor view and layout descriptor counts do not match");
} }
for (size_t i = 0; i < tvs.size(); i++) for (size_t i = 0; i < tvs.size(); i++)
{ {
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/runtime/call_frame.hpp" #include "ngraph/runtime/call_frame.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -832,12 +832,14 @@ shared_ptr<ngraph::runtime::CallFrame> runtime::cpu::CPU_ExternalFunction::make_ ...@@ -832,12 +832,14 @@ shared_ptr<ngraph::runtime::CallFrame> runtime::cpu::CPU_ExternalFunction::make_
m_compiled_function); m_compiled_function);
} }
const runtime::cpu::LayoutDescriptorPtrs& runtime::cpu::CPU_ExternalFunction::get_parameter_layout_descriptors() const runtime::cpu::LayoutDescriptorPtrs&
runtime::cpu::CPU_ExternalFunction::get_parameter_layout_descriptors()
{ {
return parameter_layout_descriptors; return parameter_layout_descriptors;
} }
const runtime::cpu::LayoutDescriptorPtrs& runtime::cpu::CPU_ExternalFunction::get_result_layout_descriptors() const runtime::cpu::LayoutDescriptorPtrs&
runtime::cpu::CPU_ExternalFunction::get_result_layout_descriptors()
{ {
return result_layout_descriptors; return result_layout_descriptors;
} }
......
...@@ -36,8 +36,7 @@ namespace ngraph ...@@ -36,8 +36,7 @@ namespace ngraph
public: public:
LayoutDescriptor(const ngraph::descriptor::TensorView& tv, LayoutDescriptor(const ngraph::descriptor::TensorView& tv,
const AxisVector& tv_axis_order); const AxisVector& tv_axis_order);
~LayoutDescriptor() { } ~LayoutDescriptor() {}
size_t get_size() override { return size; } size_t get_size() override { return size; }
size_t get_offset() const { return offset; } size_t get_offset() const { return offset; }
size_t get_index_offset(const std::vector<size_t>& indices) override; size_t get_index_offset(const std::vector<size_t>& indices) override;
...@@ -45,10 +44,12 @@ namespace ngraph ...@@ -45,10 +44,12 @@ namespace ngraph
const Strides& get_strides() const override { return strides; } const Strides& get_strides() const override { return strides; }
bool operator==(const TensorViewLayout& other) const override; bool operator==(const TensorViewLayout& other) const override;
void set_mkldnn_format(const mkldnn::memory::format& format) { mkldnn_format = format; } void set_mkldnn_format(const mkldnn::memory::format& format)
{
mkldnn_format = format;
}
mkldnn::memory::format get_mkldnn_format() const { return mkldnn_format; } mkldnn::memory::format get_mkldnn_format() const { return mkldnn_format; }
const AxisVector& get_axis_order() const { return axis_order; } const AxisVector& get_axis_order() const { return axis_order; }
static const AxisVector Native2DAxisOrder; static const AxisVector Native2DAxisOrder;
static const AxisVector Native4DAxisOrder; static const AxisVector Native4DAxisOrder;
static const AxisVector CHWNAxisOrder; static const AxisVector CHWNAxisOrder;
...@@ -64,7 +65,8 @@ namespace ngraph ...@@ -64,7 +65,8 @@ namespace ngraph
mkldnn::memory::format mkldnn_format; mkldnn::memory::format mkldnn_format;
}; };
typedef std::vector<std::shared_ptr<ngraph::runtime::cpu::LayoutDescriptor>> LayoutDescriptorPtrs; typedef std::vector<std::shared_ptr<ngraph::runtime::cpu::LayoutDescriptor>>
LayoutDescriptorPtrs;
} }
} }
} }
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include "ngraph/except.hpp" #include "cpu_tensor_view.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp" #include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp" #include "ngraph/descriptor/primary_tensor_view.hpp"
#include "ngraph/except.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp" #include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "cpu_tensor_view.hpp" #include "ngraph/shape.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -41,9 +41,8 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_ ...@@ -41,9 +41,8 @@ runtime::cpu::CPUTensorView::CPUTensorView(const ngraph::element::Type& element_
// TODO(jmenon): A fallback layout should not be needed but is required // TODO(jmenon): A fallback layout should not be needed but is required
// because of how some unit test functionality is written (ex. 'backprop_derivative') // because of how some unit test functionality is written (ex. 'backprop_derivative')
// This needs to be removed // This needs to be removed
m_descriptor->set_tensor_view_layout( m_descriptor->set_tensor_view_layout(std::make_shared<runtime::cpu::LayoutDescriptor>(
std::make_shared<runtime::cpu::LayoutDescriptor>(*m_descriptor, *m_descriptor, runtime::cpu::LayoutDescriptor::create_native_axis_order(shape.size())));
runtime::cpu::LayoutDescriptor::create_native_axis_order(shape.size())));
buffer_size = shape_size(shape) * element_type.size(); buffer_size = shape_size(shape) * element_type.size();
if (buffer_size) if (buffer_size)
......
...@@ -31,13 +31,10 @@ namespace ngraph ...@@ -31,13 +31,10 @@ namespace ngraph
{ {
namespace MKLDNN namespace MKLDNN
{ {
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
const std::unordered_set<std::type_index> OpRegistry{ const std::unordered_set<std::type_index> OpRegistry{
TI(ngraph::op::Convolution), TI(ngraph::op::Convolution), TI(ngraph::op::AvgPool), TI(ngraph::op::MaxPool),
TI(ngraph::op::AvgPool),
TI(ngraph::op::MaxPool),
}; };
bool IsMKLDNNOp(ngraph::Node& op) bool IsMKLDNNOp(ngraph::Node& op)
...@@ -45,18 +42,15 @@ namespace ngraph ...@@ -45,18 +42,15 @@ namespace ngraph
return (OpRegistry.find(TI(op)) != OpRegistry.end()); return (OpRegistry.find(TI(op)) != OpRegistry.end());
} }
mkldnn::memory::format CreateNativeDataFormat(const ngraph::runtime::cpu::LayoutDescriptor& layout) mkldnn::memory::format
CreateNativeDataFormat(const ngraph::runtime::cpu::LayoutDescriptor& layout)
{ {
switch(layout.get_shape().size()) switch (layout.get_shape().size())
{ {
case 1: case 1: return mkldnn::memory::format::x;
return mkldnn::memory::format::x; case 2: return mkldnn::memory::format::nc;
case 2: case 4: return mkldnn::memory::format::nchw;
return mkldnn::memory::format::nc; default: return mkldnn::memory::format::format_undef;
case 4:
return mkldnn::memory::format::nchw;
default:
return mkldnn::memory::format::format_undef;
} }
} }
} }
......
...@@ -32,7 +32,8 @@ namespace ngraph ...@@ -32,7 +32,8 @@ namespace ngraph
namespace MKLDNN namespace MKLDNN
{ {
bool IsMKLDNNOp(ngraph::Node& op); bool IsMKLDNNOp(ngraph::Node& op);
mkldnn::memory::format CreateNativeDataFormat(const ngraph::runtime::cpu::LayoutDescriptor& layout); mkldnn::memory::format
CreateNativeDataFormat(const ngraph::runtime::cpu::LayoutDescriptor& layout);
} }
} }
} }
......
...@@ -40,10 +40,11 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) ...@@ -40,10 +40,11 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
auto& tensor = tv->get_tensor(); auto& tensor = tv->get_tensor();
auto rank = tvt->get_shape().size(); auto rank = tvt->get_shape().size();
auto native_axis_order = ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank); auto native_axis_order =
ngraph::runtime::cpu::LayoutDescriptor::create_native_axis_order(rank);
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>( auto layout =
*tv, native_axis_order); std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(*tv, native_axis_order);
if (tensor.is_output() || tensor.is_input() || tensor.is_constant()) if (tensor.is_output() || tensor.is_input() || tensor.is_constant())
{ {
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "ngraph/descriptor/tensor_view.hpp" #include "ngraph/descriptor/tensor_view.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "ngraph/types/element_type.hpp" #include "ngraph/types/element_type.hpp"
#include "ngraph/util.hpp"
namespace ngraph namespace ngraph
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment