Unverified Commit 4fc1a478 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Eliminate redundant copies due to op::Result (#612)

* removing extra copies due to op::Result

* remove comment

* fix comment

* switch to a flag version

* add copyright header #pragma once

* add impl file, rename result_elimination.hpp to result_copy_elimination.hpp to match the opt name

* add cpp suffix to result_copy_elimination

* use member in-class member init
parent 5885c09a
......@@ -97,6 +97,7 @@ set (SRC
pass/memory_visualize.cpp
pass/pass.cpp
pass/reshape_elimination.cpp
pass/result_copy_elimination.cpp
pass/visualize_tree.cpp
pattern/matcher.cpp
pattern/core_fusion.cpp
......
......@@ -49,5 +49,7 @@ std::shared_ptr<Node> op::Result::copy_with_new_args(const NodeVector& new_args)
throw ngraph_error("Expected a single-output argument");
}
return std::make_shared<Result>(new_args.at(0));
auto res = std::make_shared<Result>(new_args.at(0));
res->set_needs_copy(res->needs_copy());
return res;
}
......@@ -36,12 +36,17 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_output() const override { return true; }
void set_needs_copy(bool val) { m_needs_copy = val; }
bool needs_copy() const { return m_needs_copy; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override
{
adjoints.add_delta(get_input_op(0), delta);
}
private:
bool m_needs_copy{true};
};
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "result_copy_elimination.hpp"
#include "ngraph/node.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/result.hpp"
#include "ngraph/util.hpp"
bool ngraph::pass::ResultCopyElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
{
std::set<std::shared_ptr<Node>> seen;
for (auto res : f->get_results())
{
auto arg = res->get_input_op(0);
//we need a copy
if (arg->is_parameter() || arg->is_constant())
{
continue;
}
//TODO: check if broadcast replace op::Result w/ a copy of broadcast node
//TODO: consider other cases where it's easier to recompute than make a copy
//we will compute the result directly into output[]
if (seen.count(arg) == 0)
{
res->set_needs_copy(false);
seen.insert(arg);
}
}
return true;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class ResultCopyElimination;
}
}
class ngraph::pass::ResultCopyElimination : public ngraph::pass::FunctionPass
{
public:
ResultCopyElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
......@@ -3550,6 +3550,13 @@ namespace ngraph
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Result)
{
const ngraph::op::Result* result = static_cast<const ngraph::op::Result*>(node);
if (!result->needs_copy())
{
return;
}
writer << "kernel::result<" << out[0].get_type() << ">(" << args[0].get_name()
<< ",\n";
writer << " " << out[0].get_name() << ",\n";
......
......@@ -100,6 +100,7 @@
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/result_copy_elimination.hpp"
#include "ngraph/pattern/core_fusion.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
......@@ -278,10 +279,11 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<ngraph::pass::ResultCopyElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(s_memory_pool_alignment);
pass_manager.run_passes(m_function);
codegen::CodeWriter writer;
bool include_mkldnn_headers = false;
......@@ -638,6 +640,16 @@ using namespace ngraph::runtime;
stringstream ss;
ss << "((" << type << "*)(outputs[" << i << "]))";
m_variable_name_map[tv->get_tensor().get_name()] = ss.str();
//it should be safe to assign both descriptors to one output*
//since needs_copy == false makes `op::Result` an nop
auto res = std::dynamic_pointer_cast<ngraph::op::Result>(op);
if (!res->needs_copy())
{
shared_ptr<descriptor::TensorView> itv =
res->get_input_op(0)->get_output_tensor_view();
m_variable_name_map[itv->get_tensor().get_name()] = ss.str();
}
}
for (shared_ptr<Node> node : current_function->get_ordered_ops())
......@@ -829,7 +841,6 @@ using namespace ngraph::runtime;
}
// TODO: Cleanup and make this a utility function
file_util::make_directory(s_output_dir);
string filename = file_util::path_join(s_output_dir, m_function_name + "_codegen.cpp");
ofstream out(filename);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment