Unverified Commit 2da03f8a authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Allow topological sort in Function to be replaced (#4206)

* Add replaceable topological sort to Function

* Cleanup

* Cleanup unit test

* Address review comment

* Fix missed item in merge
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d80c8e42
......@@ -38,6 +38,7 @@ Function::Function(const ResultVector& results,
, m_instance_id(m_next_instance_id.fetch_add(1))
, m_name(name)
, m_unique_name("Function_" + to_string(m_instance_id))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{
init();
}
......@@ -50,6 +51,7 @@ Function::Function(const OutputVector& results,
, m_instance_id(m_next_instance_id.fetch_add(1))
, m_name(name)
, m_unique_name("Function_" + to_string(m_instance_id))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{
init();
}
......@@ -62,6 +64,7 @@ Function::Function(const NodeVector& results,
, m_instance_id(m_next_instance_id.fetch_add(1))
, m_name(name)
, m_unique_name("Function_" + to_string(m_instance_id))
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
{
init();
}
......@@ -98,7 +101,7 @@ void Function::init()
std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_deps) const
{
NodeVector nodes;
vector<shared_ptr<Node>> nodes;
for (auto& r : get_results())
{
nodes.push_back(r);
......@@ -108,7 +111,7 @@ std::vector<shared_ptr<Node>> Function::get_ordered_ops(bool include_control_dep
nodes.push_back(param);
}
return topological_sort(nodes, include_control_deps);
return m_topological_sorter(nodes, include_control_deps);
}
void Function::map_unordered_ops(std::function<void(Node*)> f) const
......@@ -295,3 +298,8 @@ void Function::replace_parameter(size_t parameter_index, const shared_ptr<op::Pa
replace_node(m_parameters[parameter_index], parameter);
m_parameters[parameter_index] = parameter;
}
void Function::set_topological_sort(topological_sort_t sorter)
{
m_topological_sorter = sorter;
}
......@@ -126,6 +126,10 @@ namespace ngraph
void replace_parameter(size_t parameter_index,
const std::shared_ptr<op::Parameter>& parameter);
using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
const std::vector<std::shared_ptr<Node>>& root_nodes, bool include_control_deps)>;
void set_topological_sort(topological_sort_t);
protected:
size_t m_temporary_pool_size;
......@@ -139,5 +143,6 @@ namespace ngraph
std::string m_name;
const std::string m_unique_name;
size_t m_placement{0};
topological_sort_t m_topological_sorter;
};
}
......@@ -691,3 +691,24 @@ TEST(util, clone_function_op_annotations)
EXPECT_TRUE(found_A);
EXPECT_TRUE(found_B);
}
TEST(util, topological_sort_replace)
{
Shape shape{2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(A + B + C, ParameterVector{A, B, C});
bool custom_sorter_used = false;
f->set_topological_sort([&custom_sorter_used](
const std::vector<std::shared_ptr<Node>>& root_nodes, bool include_control_deps) {
custom_sorter_used = true;
return topological_sort(root_nodes, include_control_deps);
});
// Need to now call topological sort but don't care about the results
f->get_ordered_ops();
EXPECT_TRUE(custom_sorter_used);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment