Unverified Commit 4fc1a478 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Eliminate redundant copies due to op::Result (#612)

* removing extra copies due to op::Result

* remove comment

* fix comment

* switch to a flag version

* add copyright header #pragma once

* add impl file, rename result_elimination.hpp to result_copy_elimination.hpp to match the opt name

* add cpp suffix to result_copy_elimination

* use member in-class member init
parent 5885c09a
...@@ -97,6 +97,7 @@ set (SRC ...@@ -97,6 +97,7 @@ set (SRC
pass/memory_visualize.cpp pass/memory_visualize.cpp
pass/pass.cpp pass/pass.cpp
pass/reshape_elimination.cpp pass/reshape_elimination.cpp
pass/result_copy_elimination.cpp
pass/visualize_tree.cpp pass/visualize_tree.cpp
pattern/matcher.cpp pattern/matcher.cpp
pattern/core_fusion.cpp pattern/core_fusion.cpp
......
...@@ -49,5 +49,7 @@ std::shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args) ...@@ -49,5 +49,7 @@ std::shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args)
throw ngraph_error("Expected a single-output argument"); throw ngraph_error("Expected a single-output argument");
} }
return std::make_shared<Result>(new_args.at(0)); auto res = std::make_shared<Result>(new_args.at(0));
res->set_needs_copy(res->needs_copy());
return res;
} }
...@@ -36,12 +36,17 @@ namespace ngraph ...@@ -36,12 +36,17 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_output() const override { return true; } virtual bool is_output() const override { return true; }
void set_needs_copy(bool val) { m_needs_copy = val; }
bool needs_copy() const { return m_needs_copy; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override const std::shared_ptr<Node>& delta) override
{ {
adjoints.add_delta(get_input_op(0), delta); adjoints.add_delta(get_input_op(0), delta);
} }
private:
bool m_needs_copy{true};
}; };
} }
} }
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "result_copy_elimination.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/result.hpp"
#include "ngraph/util.hpp"
bool ngraph::pass::ResultCopyElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
{
std::set<std::shared_ptr<Node>> seen;
for (auto res : f->get_results())
{
auto arg = res->get_input_op(0);
//we need a copy
if (arg->is_parameter() || arg->is_constant())
{
continue;
}
//TODO: check if broadcast replace op::Result w/ a copy of broadcast node
//TODO: consider other cases where it's easier to recompute than make a copy
//we will compute the result directly into output[]
if (seen.count(arg) == 0)
{
res->set_needs_copy(false);
seen.insert(arg);
}
}
return true;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ResultCopyElimination;
}
}
class ngraph::pass::ResultCopyElimination : public ngraph::pass::FunctionPass
{
public:
ResultCopyElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
...@@ -3550,6 +3550,13 @@ namespace ngraph ...@@ -3550,6 +3550,13 @@ namespace ngraph
template <> template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Result) void CPU_Emitter::EMITTER_DECL(ngraph::op::Result)
{ {
const ngraph::op::Result* result = static_cast<const ngraph::op::Result*>(node);
if (!result->needs_copy())
{
return;
}
writer << "kernel::result<" << out[0].get_type() << ">(" << args[0].get_name() writer << "kernel::result<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n"; << ",\n";
writer << " " << out[0].get_name() << ",\n"; writer << " " << out[0].get_name() << ",\n";
......
...@@ -100,6 +100,7 @@ ...@@ -100,6 +100,7 @@
#include "ngraph/pass/liveness.hpp" #include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp" #include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/result_copy_elimination.hpp"
#include "ngraph/pattern/core_fusion.hpp" #include "ngraph/pattern/core_fusion.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp" #include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp" #include "ngraph/runtime/cpu/cpu_call_frame.hpp"
...@@ -278,10 +279,11 @@ void runtime::cpu::CPU_ExternalFunction::compile() ...@@ -278,10 +279,11 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(); pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this); pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this); pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>(); pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment); pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.run_passes(m_function); pass_manager.run_passes(m_function);
codegen::CodeWriter writer; codegen::CodeWriter writer;
bool include_mkldnn_headers = false; bool include_mkldnn_headers = false;
...@@ -638,6 +640,16 @@ using namespace ngraph::runtime; ...@@ -638,6 +640,16 @@ using namespace ngraph::runtime;
stringstream ss; stringstream ss;
ss << "((" << type << "*)(outputs[" << i << "]))"; ss << "((" << type << "*)(outputs[" << i << "]))";
m_variable_name_map[tv->get_tensor().get_name()] = ss.str(); m_variable_name_map[tv->get_tensor().get_name()] = ss.str();
//it should be safe to assign both descriptors to one output*
//since needs_copy == false makes `op::Result` an nop
auto res = std::dynamic_pointer_cast<ngraph::op::Result>(op);
if (!res->needs_copy())
{
shared_ptr<descriptor::TensorView> itv =
res->get_input_op(0)->get_output_tensor_view();
m_variable_name_map[itv->get_tensor().get_name()] = ss.str();
}
} }
for (shared_ptr<Node> node : current_function->get_ordered_ops()) for (shared_ptr<Node> node : current_function->get_ordered_ops())
...@@ -829,7 +841,6 @@ using namespace ngraph::runtime; ...@@ -829,7 +841,6 @@ using namespace ngraph::runtime;
} }
// TODO: Cleanup and make this a utility function // TODO: Cleanup and make this a utility function
file_util::make_directory(s_output_dir); file_util::make_directory(s_output_dir);
string filename = file_util::path_join(s_output_dir, m_function_name + "_codegen.cpp"); string filename = file_util::path_join(s_output_dir, m_function_name + "_codegen.cpp");
ofstream out(filename); ofstream out(filename);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment