Commit 222de9fe authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Move layout assignment to a pass

parent b098aaf4
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <exception>
#include <sstream>
#include "ngraph/descriptor/output.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
template <typename LT>
class AssignLayout : public CallGraphPass
{
public:
virtual bool run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) override
{
for (const std::shared_ptr<Node>& node : nodes)
{
try
{
for (const descriptor::Output& output : node->get_outputs())
{
auto tv = output.get_tensor_view();
if (nullptr == tv->get_tensor_view_layout())
{
auto layout = std::make_shared<LT>(*tv);
tv->set_tensor_view_layout(layout);
}
}
}
catch (const std::exception& e)
{
std::stringstream ss;
ss << "Error with node " << *node << ": ";
ss << e.what();
throw std::invalid_argument(ss.str());
}
}
return false;
}
};
}
}
...@@ -50,6 +50,7 @@ ...@@ -50,6 +50,7 @@
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/assign_tensors.hpp" #include "ngraph/pass/assign_tensors.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/propagate_types.hpp" #include "ngraph/pass/propagate_types.hpp"
...@@ -123,29 +124,13 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -123,29 +124,13 @@ void ExternalFunction::compile(FunctionMap& function_map)
return; return;
} }
// This will be replaced with the pass manager
// Get the ordered list of ops in execution order
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<pass::TopologicalSort>(); pass_manager.register_pass<pass::TopologicalSort>();
pass_manager.register_pass<pass::PropagateTypes>(); pass_manager.register_pass<pass::PropagateTypes>();
pass_manager.register_pass<pass::AssignTensors>(); pass_manager.register_pass<pass::AssignTensors>();
pass_manager.run_passes(m_function);
// Turn this into a pass
// Assign layouts
// For now, just make everyone row-major. // For now, just make everyone row-major.
for (shared_ptr<Node> node : m_function->get_ordered_ops()) pass_manager.register_pass<pass::AssignLayout<DenseTensorViewLayout>>();
{ pass_manager.run_passes(m_function);
for (const descriptor::Output& output : node->get_outputs())
{
auto tv = output.get_tensor_view();
if (nullptr == tv->get_tensor_view_layout())
{
auto layout = std::make_shared<DenseTensorViewLayout>(*tv);
tv->set_tensor_view_layout(layout);
}
}
}
// Determine tensor requirements for the call frame // Determine tensor requirements for the call frame
unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index; unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment