Commit cbd3b53e authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

WIP

parent 1f76a2a7
......@@ -83,7 +83,6 @@
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
......@@ -93,6 +92,7 @@
#include "ngraph/runtime/cpu/cpu_emitter.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/ops/matmul_bias.hpp"
#include "ngraph/runtime/cpu/pass/cpu_layout.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
using namespace std;
......@@ -220,11 +220,12 @@ void runtime::cpu::CPU_ExternalFunction::compile()
string function_name = m_function->get_name();
pass::Manager pass_manager;
ngraph::pass::Manager pass_manager;
// For now, just make everyone row-major.
pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<pass::Liveness>();
pass_manager.register_pass<pass::MemoryLayout>(64);
//pass_manager.register_pass<pass::AssignLayout<descriptor::layout::DenseTensorViewLayout>>();
pass_manager.register_pass<runtime::cpu::pass::CPULayout>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(64);
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "cpu_layout_descriptor.hpp"
namespace ngraph
......@@ -32,21 +34,32 @@ namespace ngraph
, size(ngraph::shape_size(tv.get_tensor_view_type()->get_shape()))
, mkldnn_format(mkldnn_format_undef)
{
if (tv_axis_order == Native2DAxisOrder ||
tv_axis_order == Native4DAxisOrder) {
strides = ngraph::row_major_strides(get_shape());
auto shape = get_shape();
size_t s = 1;
if (tv_axis_order.size() != shape.size())
{
throw ngraph_error("Axis order is incomplete");
}
else
for (auto it = tv_axis_order.crbegin(); it != tv_axis_order.crend(); it++)
{
throw ngraph_error("Axis ordering not handled yet");
if (*it >= shape.size())
{
throw ngraph_error("Axis is out of bounds");
}
strides.emplace_back(shape[*it]);
s *= shape[*it];
}
std::reverse(strides.begin(), strides.end());
}
size_t LayoutDescriptor::get_index_offset(const std::vector<size_t>& indices)
{
if (indices.size() != strides.size())
{
throw ngraph_error("Indices have the incorrect rank.");
throw ngraph_error("Indices have incorrect rank");
}
size_t result = 0;
for (int i = 0; i < indices.size(); i++)
......@@ -56,7 +69,8 @@ namespace ngraph
return result;
}
bool LayoutDescriptor::operator==(const ngraph::descriptor::layout::TensorViewLayout& other) const
bool LayoutDescriptor::
operator==(const ngraph::descriptor::layout::TensorViewLayout& other) const
{
const LayoutDescriptor* p_other = dynamic_cast<const LayoutDescriptor*>(&other);
if (!p_other)
......
......@@ -19,8 +19,8 @@
#include <mkldnn_types.h>
#include "ngraph/common.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/descriptor/layout/tensor_view_layout.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/types/type.hpp"
namespace ngraph
......@@ -42,21 +42,19 @@ namespace ngraph
const Strides& get_strides() const override { return strides; }
bool operator==(const TensorViewLayout& other) const override;
mkldnn_memory_format_t get_mkldnn_format() const
{
return mkldnn_format;
}
mkldnn_memory_format_t get_mkldnn_format() const { return mkldnn_format; }
const AxisVector& get_axis_order() const { return axis_order; }
static const AxisVector Native2DAxisOrder;
static const AxisVector Native4DAxisOrder;
static const AxisVector CHWNAxisOrder;
private:
AxisVector axis_order;
Strides strides;
size_t offset;
size_t size;
// Numeric backend-specific fields
// Numeric backend-specific fields
mkldnn_memory_format_t mkldnn_format;
};
}
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include "cpu_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
......@@ -30,12 +32,25 @@ bool CPULayout::run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes)
continue;
}
auto tvt = tv.get_tensor_view_type();
auto tvt = tv->get_tensor_view_type();
auto& tensor = tv->get_tensor();
auto rank = tvt->get_shape().size();
AxisVector native_axis_order(rank);
std::iota(native_axis_order.begin(), native_axis_order.end(), 0);
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(*tv);
tv->set_tensor_view_layout(layout);
if (tensor.is_output() || tensor.is_input() || tensor.is_constant())
{
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(
*tv, native_axis_order);
tv->set_tensor_view_layout(layout);
}
else
{
auto layout = std::make_shared<ngraph::runtime::cpu::LayoutDescriptor>(
*tv, native_axis_order);
tv->set_tensor_view_layout(layout);
}
}
}
......
......@@ -27,7 +27,8 @@ namespace ngraph
class CPULayout : public ngraph::pass::CallGraphPass
{
public:
virtual bool run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
virtual bool
run_on_call_graph(const std::list<std::shared_ptr<Node>>& nodes) override;
};
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment