Unverified Commit 2da03f8a authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Allow topological sort in Function to be replaced (#4206)

* Add replaceable topological sort to Function

* Cleanup

* Cleanup unit test

* Address review comment

* Fix missed item in merge
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d80c8e42
...@@ -38,6 +38,7 @@ Function::Function(const ResultVector& results, ...@@ -38,6 +38,7 @@ Function::Function(const ResultVector& results,
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
, m_name(name) , m_name(name)
, m_unique_name("Function_" + to_string(m_instance_id)) , m_unique_name("Function_" + to_string(m_instance_id))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{ {
init(); init();
} }
...@@ -50,6 +51,7 @@ Function::Function(const OutputVector& results, ...@@ -50,6 +51,7 @@ Function::Function(const OutputVector& results,
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
, m_name(name) , m_name(name)
, m_unique_name("Function_" + to_string(m_instance_id)) , m_unique_name("Function_" + to_string(m_instance_id))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{ {
init(); init();
} }
...@@ -62,6 +64,7 @@ Function::Function(const NodeVector& results, ...@@ -62,6 +64,7 @@ Function::Function(const NodeVector& results,
, m_instance_id(m_next_instance_id.fetch_add(1)) , m_instance_id(m_next_instance_id.fetch_add(1))
, m_name(name) , m_name(name)
, m_unique_name("Function_" + to_string(m_instance_id)) , m_unique_name("Function_" + to_string(m_instance_id))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{ {
init(); init();
} }
...@@ -98,7 +101,7 @@ void Function::init() ...@@ -98,7 +101,7 @@ void Function::init()
std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{ {
NodeVector nodes; vector<shared_ptr<Node>> nodes;
for (auto& r : get_results()) for (auto& r : get_results())
{ {
nodes.push_back(r); nodes.push_back(r);
...@@ -108,7 +111,7 @@ std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_dep ...@@ -108,7 +111,7 @@ std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_dep
nodes.push_back(param); nodes.push_back(param);
} }
return topological_sort(nodes, include_control_deps); return m_topological_sorter(nodes, include_control_deps);
} }
void Function::map_unordered_ops(std::function<void(Node*)> f) const void Function::map_unordered_ops(std::function<void(Node*)> f) const
...@@ -295,3 +298,8 @@ void Function::replace_parameter(size_t parameter_index, const shared_ptr<op::Pa ...@@ -295,3 +298,8 @@ void Function::replace_parameter(size_t parameter_index, const shared_ptr<op::Pa
replace_node(m_parameters[parameter_index], parameter); replace_node(m_parameters[parameter_index], parameter);
m_parameters[parameter_index] = parameter; m_parameters[parameter_index] = parameter;
} }
void Function::set_topological_sort(topological_sort_t sorter)
{
m_topological_sorter = sorter;
}
...@@ -126,6 +126,10 @@ namespace ngraph ...@@ -126,6 +126,10 @@ namespace ngraph
void replace_parameter(size_t parameter_index, void replace_parameter(size_t parameter_index,
const std::shared_ptr<op::Parameter>& parameter); const std::shared_ptr<op::Parameter>& parameter);
using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
const std::vector<std::shared_ptr<Node>>& root_nodes, bool include_control_deps)>;
void set_topological_sort(topological_sort_t);
protected: protected:
size_t m_temporary_pool_size; size_t m_temporary_pool_size;
...@@ -139,5 +143,6 @@ namespace ngraph ...@@ -139,5 +143,6 @@ namespace ngraph
std::string m_name; std::string m_name;
const std::string m_unique_name; const std::string m_unique_name;
size_t m_placement{0}; size_t m_placement{0};
topological_sort_t m_topological_sorter;
}; };
} }
...@@ -691,3 +691,24 @@ TEST(util, clone_function_op_annotations) ...@@ -691,3 +691,24 @@ TEST(util, clone_function_op_annotations)
EXPECT_TRUE(found_A); EXPECT_TRUE(found_A);
EXPECT_TRUE(found_B); EXPECT_TRUE(found_B);
} }
TEST(util, topological_sort_replace)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(A + B + C, ParameterVector{A, B, C});
bool custom_sorter_used = false;
f->set_topological_sort([&custom_sorter_used](
const std::vector<std::shared_ptr<Node>>& root_nodes, bool include_control_deps) {
custom_sorter_used = true;
return topological_sort(root_nodes, include_control_deps);
});
// Need to now call topological sort but don't care about the results
f->get_ordered_ops();
EXPECT_TRUE(custom_sorter_used);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment