Commit 4a1e308a authored by Ilya Churaev, committed by Scott Cyphers

Added reproducer for specialize_function (#3931)

* Added reproducer for specialize_function

* Remove GOEs that might have been introduced while cloning nodes in function specialization

* Address PR feedback

* Fix incorrect merge
parent 35a25cc8
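For context: the reproducer below targets graphs with multi-output nodes such as Split, where cloning during specialization could leave GetOutputElement (GOE) wrappers between a producer and its consumers. The following is a minimal sketch of such a wrapped graph, not part of this commit, assuming the nGraph 0.x API used in the test further down; `make_wrapped_graph` is a hypothetical helper and header paths are approximate.

```cpp
#include "ngraph/ngraph.hpp"
#include "ngraph/op/get_output_element.hpp"

using namespace ngraph;

std::shared_ptr<Function> make_wrapped_graph()
{
    auto data = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
    auto split = std::make_shared<op::Split>(data, 1, 3); // three outputs along axis 1

    // A GOE node merely forwards one output of the multi-output Split; a consumer
    // could read split->output(0) directly, so the wrapper is redundant. That
    // redundancy is what GetOutputElementElimination exploits in the pass below:
    // it rewires each consumer input to the GOE's own source output.
    auto goe0 = std::make_shared<op::GetOutputElement>(split, 0);

    auto result = std::make_shared<op::Result>(goe0->output(0));
    return std::make_shared<Function>(ResultVector{result}, ParameterVector{data});
}
```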
@@ -18,30 +18,25 @@
 #include "get_output_element_elimination.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/log.hpp"
-#include "ngraph/op/avg_pool.hpp"
-#include "ngraph/op/broadcast.hpp"
-#include "ngraph/op/constant.hpp"
-#include "ngraph/op/convolution.hpp"
 #include "ngraph/op/get_output_element.hpp"
-#include "ngraph/op/max_pool.hpp"
-#include "ngraph/op/pad.hpp"
-#include "ngraph/op/product.hpp"
-#include "ngraph/op/sum.hpp"

 using namespace ngraph;
 using namespace std;

-bool pass::GetOutputElementElimination::run_on_node(shared_ptr<Node> n)
+bool pass::GetOutputElementElimination::run_on_function(shared_ptr<Function> f)
 {
     bool optimized = false;
-    for (auto& input : n->inputs())
+    for (auto& n : f->get_ops())
     {
-        if (auto goe = dynamic_cast<op::GetOutputElement*>(input.get_source_output().get_node()))
+        for (auto& input : n->inputs())
         {
-            input.replace_source_output(goe->input(0).get_source_output());
-            // we don't need to fix anything w.r.t GetOutputElement as it will become unreachable
-            optimized = true;
+            if (auto goe = as_type<op::GetOutputElement>(input.get_source_output().get_node()))
+            {
+                input.replace_source_output(goe->input(0).get_source_output());
+                // we don't need to fix anything w.r.t GetOutputElement as it will become
+                // unreachable
+                optimized = true;
+            }
         }
     }
     return optimized;
@@ -26,8 +26,8 @@ namespace ngraph
     }
 }

-class NGRAPH_API ngraph::pass::GetOutputElementElimination : public NodePass
+class NGRAPH_API ngraph::pass::GetOutputElementElimination : public FunctionPass
 {
 public:
-    bool run_on_node(std::shared_ptr<Node> node) override;
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
 };
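With the base class switched from NodePass to FunctionPass, the pass now runs once over a whole Function instead of once per node. Outside of specialize_function it would typically be applied through ngraph::pass::Manager; here is a minimal sketch, not part of this commit, reusing the internal header path from the specialize_function.cpp hunk below (`strip_goes` is a hypothetical helper).

```cpp
#include <pass/get_output_element_elimination.hpp> // internal header, as included below
#include "ngraph/pass/manager.hpp"

#include <memory>

void strip_goes(const std::shared_ptr<ngraph::Function>& f)
{
    ngraph::pass::Manager manager;
    // run_on_function receives the whole graph once, replacing the old
    // per-node run_on_node hook.
    manager.register_pass<ngraph::pass::GetOutputElementElimination>();
    manager.run_passes(f);
}
```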
@@ -16,6 +16,7 @@
 #include "ngraph/specialize_function.hpp"
 #include <pass/constant_folding.hpp>
+#include <pass/get_output_element_elimination.hpp>
 #include "ngraph/op/constant.hpp"

 using namespace ngraph;
@@ -120,5 +121,8 @@ std::shared_ptr<Function>
     {
         ngraph::pass::ConstantFolding().run_on_function(function);
     }
+    ngraph::pass::GetOutputElementElimination().run_on_function(function);
     return function;
 }
@@ -362,3 +362,28 @@ TEST(specialize_function, share_constants_with_cf)
     ASSERT_EQ(add_const_1->output(0).get_target_inputs().size(), 1);
     ASSERT_EQ(add_const_2->output(0).get_target_inputs().size(), 1);
 }
+
+TEST(specialize_function, copy_network_with_split)
+{
+    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
+    auto split = std::make_shared<op::Split>(p0, 1, 3);
+    auto res1 = std::make_shared<op::Result>(split->output(0));
+    auto res2 = std::make_shared<op::Result>(split->output(1));
+    auto res3 = std::make_shared<op::Result>(split->output(2));
+    ResultVector res = {res1, res2, res3};
+    auto f = std::make_shared<Function>(res, ParameterVector{p0});
+
+    auto f_specialized = specialize_function(
+        f, {element::f32}, {PartialShape{1, 3, 64, 64}}, {nullptr}, false, false);
+
+    for (const auto& op : f->get_ops())
+    {
+        ASSERT_FALSE(is_type<op::GetOutputElement>(op));
+    }
+    for (const auto& op : f_specialized->get_ops())
+    {
+        ASSERT_FALSE(is_type<op::GetOutputElement>(op));
+    }
+    ASSERT_EQ(5, f->get_ops().size());
+    ASSERT_EQ(f_specialized->get_ops().size(), f->get_ops().size());
+}
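In the new test the expected op count of 5 is one Parameter, one Split, and three Results, i.e. no GetOutputElement nodes remain in either the original or the specialized function. The sketch below, not part of the commit, shows a small helper in the same style that makes that check explicit; `count_goes` is a hypothetical name and header paths are approximate.

```cpp
#include <cstddef>
#include <memory>

#include "ngraph/ngraph.hpp"
#include "ngraph/op/get_output_element.hpp"

// Counts GetOutputElement nodes in a function; for the graphs built in the
// test above this should be zero for both f and f_specialized.
static size_t count_goes(const std::shared_ptr<ngraph::Function>& f)
{
    size_t goes = 0;
    for (const auto& op : f->get_ops())
    {
        if (ngraph::is_type<ngraph::op::GetOutputElement>(op))
        {
            ++goes;
        }
    }
    return goes;
}
```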