Commit 4356b2cd authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

WIP

parent 6bd5db00
...@@ -24,8 +24,25 @@ namespace ngraph ...@@ -24,8 +24,25 @@ namespace ngraph
const AxisVector LayoutDescriptor::Native4DAxisOrder{0, 1, 2, 3}; const AxisVector LayoutDescriptor::Native4DAxisOrder{0, 1, 2, 3};
const AxisVector LayoutDescriptor::CHWNAxisOrder{1, 2, 3, 0}; const AxisVector LayoutDescriptor::CHWNAxisOrder{1, 2, 3, 0};
size_t LayoutDescriptor::LayoutDescriptor(const ngraph::descriptor::TensorView& tv,
LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices) const AxisVector& tv_axis_order)
: TensorViewLayout(tv)
, axis_order(tv_axis_order)
, offset(0)
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, mkldnn_format(mkldnn_format_undef)
{
if (tv_axis_order == Native2DAxisOrder ||
tv_axis_order == Native4DAxisOrder) {
strides = ngraph::row_major_strides(get_shape());
}
else
{
throw ngraph_error("Axis ordering not handled yet");
}
}
size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
{ {
if (indices.size() != strides.size()) if (indices.size() != strides.size())
{ {
......
...@@ -33,22 +33,7 @@ namespace ngraph ...@@ -33,22 +33,7 @@ namespace ngraph
{ {
public: public:
LayoutDescriptor(const ngraph::descriptor::TensorView& tv, LayoutDescriptor(const ngraph::descriptor::TensorView& tv,
const AxisVector& tv_axis_order) const AxisVector& tv_axis_order);
: TensorViewLayout(tv)
, axis_order(tv_axis_order)
, offset(0)
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, mkldnn_format(mkldnn_format_undef)
{
if (tv_axis_order == Native2DAxisOrder ||
tv_axis_order == Native4DAxisOrder) {
strides = ngraph::row_major_strides(get_shape());
}
else
{
throw ngraph_error("Axis ordering not handled yet");
}
}
size_t get_size() override { return size; } size_t get_size() override { return size; }
size_t get_offset() const { return offset; } size_t get_offset() const { return offset; }
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include "cpu_layout.hpp" #include "cpu_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
using namespace ngraph::runtime::cpu::pass; using namespace ngraph::runtime::cpu::pass;
...@@ -20,7 +22,21 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) ...@@ -20,7 +22,21 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{ {
for (const auto& node : nodes) for (const auto& node : nodes)
{ {
for (size_t i = 0; i < node->get_output_size(); ++i)
{
auto tv = node->get_output_tensor_view(i);
if (tv->get_tensor_view_layout())
{
continue;
}
auto tvt = tv.get_tensor_view_type();
auto rank = tvt->get_shape().size();
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(*tv);
tv->set_tensor_view_layout(layout);
}
} }
return false; return false;
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#pragma once #pragma once
#include "ngraph/pass/pass.hpp" #include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
namespace ngraph namespace ngraph
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment