Commit 4356b2cd authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

WIP

parent 6bd5db00
......@@ -24,8 +24,25 @@ namespace ngraph
const AxisVector LayoutDescriptor::Native4DAxisOrder{0, 1, 2, 3};
const AxisVector LayoutDescriptor::CHWNAxisOrder{1, 2, 3, 0};
size_t
LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
LayoutDescriptor::LayoutDescriptor(const ngraph::descriptor::TensorView& tv,
const AxisVector& tv_axis_order)
: TensorViewLayout(tv)
, axis_order(tv_axis_order)
, offset(0)
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, mkldnn_format(mkldnn_format_undef)
{
if (tv_axis_order == Native2DAxisOrder ||
tv_axis_order == Native4DAxisOrder) {
strides = ngraph::row_major_strides(get_shape());
}
else
{
throw ngraph_error("Axis ordering not handled yet");
}
}
size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
{
if (indices.size() != strides.size())
{
......
......@@ -33,22 +33,7 @@ namespace ngraph
{
public:
LayoutDescriptor(const ngraph::descriptor::TensorView& tv,
const AxisVector& tv_axis_order)
: TensorViewLayout(tv)
, axis_order(tv_axis_order)
, offset(0)
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, mkldnn_format(mkldnn_format_undef)
{
if (tv_axis_order == Native2DAxisOrder ||
tv_axis_order == Native4DAxisOrder) {
strides = ngraph::row_major_strides(get_shape());
}
else
{
throw ngraph_error("Axis ordering not handled yet");
}
}
const AxisVector& tv_axis_order);
size_t get_size() override { return size; }
size_t get_offset() const { return offset; }
......
......@@ -13,6 +13,8 @@
// ----------------------------------------------------------------------------
#include "cpu_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
using namespace ngraph::runtime::cpu::pass;
......@@ -20,7 +22,21 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
{
for (const auto& node : nodes)
{
for (size_t i = 0; i < node->get_output_size(); ++i)
{
auto tv = node->get_output_tensor_view(i);
if (tv->get_tensor_view_layout())
{
continue;
}
auto tvt = tv.get_tensor_view_type();
auto rank = tvt->get_shape().size();
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(*tv);
tv->set_tensor_view_layout(layout);
}
}
return false;
......
......@@ -15,7 +15,6 @@
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
namespace ngraph
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment