Commit f10022cc authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

GetOutputElement Elimination (#644)

* rough draft but needs to use get_n to get the right input

* v2 fully working but hacky

* remove hacks ; switch back build_users() to users()

* rollback hacks to node.cpp

* perms, remove prints, format
parent e7abc0f3
......@@ -88,6 +88,7 @@ set (SRC
ops/util/unary_elementwise.cpp
pass/assign_placement.cpp
pass/dump_sorted.cpp
pass/get_output_element_elimination.cpp
pass/graph_rewrite.cpp
pass/inliner.cpp
pass/liveness.cpp
......
......@@ -37,6 +37,10 @@
namespace ngraph
{
namespace pass
{
class GetOutputElementElimination;
}
namespace op
{
class Parameter;
......@@ -63,6 +67,8 @@ namespace ngraph
std::shared_ptr<Node> dst_node,
std::shared_ptr<op::Parameter> p_node);
friend class ngraph::pass::GetOutputElementElimination;
protected:
Node(const std::string& node_type, const NodeVector& arguments);
virtual ~Node()
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <set>
#include "get_output_element_elimination.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/avg_pool.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/constant.hpp"
#include "ngraph/ops/convolution.hpp"
#include "ngraph/ops/get_output_element.hpp"
#include "ngraph/ops/max_pool.hpp"
#include "ngraph/ops/pad.hpp"
#include "ngraph/ops/product.hpp"
#include "ngraph/ops/sum.hpp"
using namespace ngraph;
bool ngraph::pass::GetOutputElementElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool optimized = false;
for (auto n : f->get_ordered_ops())
{
for (auto& input : n->get_inputs())
{
if (auto goe =
std::dynamic_pointer_cast<op::GetOutputElement>(input.get_output().get_node()))
{
auto multi = goe->get_inputs().at(0).get_output().get_node();
input.replace_output(goe->get_inputs().at(goe->get_n()).get_output());
//fix node arguments
auto& n_args =
const_cast<ngraph::NodeVector&>(n->get_arguments_FOR_GRAPH_REWRITE_ONLY());
auto it = std::find(begin(n_args), end(n_args), goe);
if (it == end(n_args))
{
throw ngraph_error("Expected to find GetOutputElement in n's inputs");
}
*it = multi;
//fix multi's users
const_cast<std::multiset<Node*>&>(multi->users()).insert(n.get());
//we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
optimized = true;
}
}
}
return optimized;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class GetOutputElementElimination;
}
}
class ngraph::pass::GetOutputElementElimination : public FunctionPass
{
public:
GetOutputElementElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -97,6 +97,7 @@
#include "ngraph/ops/tan.hpp"
#include "ngraph/ops/tanh.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
......@@ -290,6 +291,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
......@@ -619,7 +621,7 @@ using namespace ngraph::runtime;
if (!res->needs_copy())
{
shared_ptr<descriptor::TensorView> itv =
res->get_input_op(0)->get_output_tensor_view();
res->get_inputs().at(0).get_output().get_tensor_view();
m_variable_name_map[itv->get_tensor().get_name()] = ss.str();
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment