Commit 4a1e308a authored by Ilya Churaev's avatar Ilya Churaev Committed by Scott Cyphers

Added reproducer for specialize_function (#3931)

* Added reproducer for specialize_function

* Remove GOEs that might have been introduced while cloning nodes in function specialization

* Address PR feedback

* Fix incorrect merge
parent 35a25cc8
...@@ -18,30 +18,25 @@ ...@@ -18,30 +18,25 @@
#include "get_output_element_elimination.hpp" #include "get_output_element_elimination.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
bool pass::GetOutputElementElimination::run_on_node(shared_ptr<Node> n) bool pass::GetOutputElementElimination::run_on_function(shared_ptr<Function> f)
{ {
bool optimized = false; bool optimized = false;
for (auto& input : n->inputs()) for (auto& n : f->get_ops())
{ {
if (auto goe = dynamic_cast<op::GetOutputElement*>(input.get_source_output().get_node())) for (auto& input : n->inputs())
{ {
input.replace_source_output(goe->input(0).get_source_output()); if (auto goe = as_type<op::GetOutputElement>(input.get_source_output().get_node()))
// we don't need to fix anything w.r.t GetOutputElement as it will become unreachable {
optimized = true; input.replace_source_output(goe->input(0).get_source_output());
// we don't need to fix anything w.r.t GetOutputElement as it will become
// unreachable
optimized = true;
}
} }
} }
return optimized; return optimized;
......
...@@ -26,8 +26,8 @@ namespace ngraph ...@@ -26,8 +26,8 @@ namespace ngraph
} }
} }
class NGRAPH_API ngraph::pass::GetOutputElementElimination : public NodePass class NGRAPH_API ngraph::pass::GetOutputElementElimination : public FunctionPass
{ {
public: public:
bool run_on_node(std::shared_ptr<Node> node) override; bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
}; };
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/specialize_function.hpp" #include "ngraph/specialize_function.hpp"
#include <pass/constant_folding.hpp> #include <pass/constant_folding.hpp>
#include <pass/get_output_element_elimination.hpp>
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -120,5 +121,8 @@ std::shared_ptr<Function> ...@@ -120,5 +121,8 @@ std::shared_ptr<Function>
{ {
ngraph::pass::ConstantFolding().run_on_function(function); ngraph::pass::ConstantFolding().run_on_function(function);
} }
ngraph::pass::GetOutputElementElimination().run_on_function(function);
return function; return function;
} }
...@@ -362,3 +362,28 @@ TEST(specialize_function, share_constants_with_cf) ...@@ -362,3 +362,28 @@ TEST(specialize_function, share_constants_with_cf)
ASSERT_EQ(add_const_1->output(0).get_target_inputs().size(), 1); ASSERT_EQ(add_const_1->output(0).get_target_inputs().size(), 1);
ASSERT_EQ(add_const_2->output(0).get_target_inputs().size(), 1); ASSERT_EQ(add_const_2->output(0).get_target_inputs().size(), 1);
} }
TEST(specialize_function, copy_network_with_split)
{
auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 3, 64, 64});
auto split = std::make_shared<op::Split>(p0, 1, 3);
auto res1 = std::make_shared<op::Result>(split->output(0));
auto res2 = std::make_shared<op::Result>(split->output(1));
auto res3 = std::make_shared<op::Result>(split->output(2));
ResultVector res = {res1, res2, res3};
auto f = std::make_shared<Function>(res, ParameterVector{p0});
auto f_specialized = specialize_function(
f, {element::f32}, {PartialShape{1, 3, 64, 64}}, {nullptr}, false, false);
for (const auto& op : f->get_ops())
{
ASSERT_FALSE(is_type<op::GetOutputElement>(op));
}
for (const auto& op : f_specialized->get_ops())
{
ASSERT_FALSE(is_type<op::GetOutputElement>(op));
}
ASSERT_EQ(5, f->get_ops().size());
ASSERT_EQ(f_specialized->get_ops().size(), f->get_ops().size());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment