Unverified Commit f227b591 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Simplify the access of placement in Function (#2405)

* prep work

* wip

* remove debug

* style

* update unit test
parent 6beb6732
...@@ -214,3 +214,13 @@ size_t Function::get_graph_size() const ...@@ -214,3 +214,13 @@ size_t Function::get_graph_size() const
} }
return total_size; return total_size;
} }
size_t Function::get_placement() const
{
return m_placement;
}
void Function::set_placement(size_t placement)
{
m_placement = placement;
}
...@@ -90,6 +90,9 @@ namespace ngraph ...@@ -90,6 +90,9 @@ namespace ngraph
/// graphs and should not be considered the actual memory consumption of a graph. /// graphs and should not be considered the actual memory consumption of a graph.
size_t get_graph_size() const; size_t get_graph_size() const;
size_t get_placement() const;
void set_placement(size_t placement);
protected: protected:
ResultVector m_results; ResultVector m_results;
ParameterVector m_parameters; ParameterVector m_parameters;
...@@ -104,5 +107,6 @@ namespace ngraph ...@@ -104,5 +107,6 @@ namespace ngraph
size_t m_instance_id; size_t m_instance_id;
std::string m_name; std::string m_name;
const std::string m_unique_name; const std::string m_unique_name;
size_t m_placement;
}; };
} }
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
add_library(hybrid_base STATIC add_library(hybrid_base STATIC
hybrid_backend.cpp hybrid_backend.cpp
hybrid_util.cpp hybrid_util.cpp
pass/assign_placement.cpp pass/default_placement.cpp
pass/dump.cpp pass/dump.cpp
pass/fix_get_output_element.cpp pass/fix_get_output_element.cpp
pass/liveness.cpp pass/liveness.cpp
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/hybrid/hybrid_util.hpp" #include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/runtime/hybrid/pass/assign_placement.hpp" #include "ngraph/runtime/hybrid/pass/default_placement.hpp"
#include "ngraph/runtime/hybrid/pass/dump.hpp" #include "ngraph/runtime/hybrid/pass/dump.hpp"
#include "ngraph/runtime/hybrid/pass/fix_get_output_element.hpp" #include "ngraph/runtime/hybrid/pass/fix_get_output_element.hpp"
#include "ngraph/runtime/hybrid/pass/liveness.hpp" #include "ngraph/runtime/hybrid/pass/liveness.hpp"
...@@ -74,7 +74,7 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun ...@@ -74,7 +74,7 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
// Run placement pass // Run placement pass
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::hybrid::pass::AssignPlacement>(m_backend_list); pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>(); pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
pass_manager.register_pass<runtime::hybrid::pass::Liveness>(); pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump"); pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
...@@ -94,7 +94,7 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun ...@@ -94,7 +94,7 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
size_t subfunction_number = 0; size_t subfunction_number = 0;
for (shared_ptr<Function>& sub_function : instance.m_sub_functions) for (shared_ptr<Function>& sub_function : instance.m_sub_functions)
{ {
size_t placement = runtime::hybrid::get_colocated_function_placement(sub_function); size_t placement = sub_function->get_placement();
if (m_debug_enabled) if (m_debug_enabled)
{ {
string name = "subfunction_" + to_string(subfunction_number++); string name = "subfunction_" + to_string(subfunction_number++);
...@@ -149,7 +149,7 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func, ...@@ -149,7 +149,7 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
for (const shared_ptr<Function>& sub_function : instance.m_sub_functions) for (const shared_ptr<Function>& sub_function : instance.m_sub_functions)
{ {
// Init backend // Init backend
size_t placement = runtime::hybrid::get_colocated_function_placement(sub_function); size_t placement = sub_function->get_placement();
auto backend = m_backend_list[placement]; auto backend = m_backend_list[placement];
// Prepare parameter Tensors // Prepare parameter Tensors
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/runtime/hybrid/hybrid_util.hpp" #include "ngraph/runtime/hybrid/hybrid_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp" #include "ngraph/pass/visualize_tree.hpp"
...@@ -25,20 +26,20 @@ static Node* take_independent_node_with_placement_priority( ...@@ -25,20 +26,20 @@ static Node* take_independent_node_with_placement_priority(
map<size_t, deque<Node*>>& independent_nodes_by_placement, size_t placement) map<size_t, deque<Node*>>& independent_nodes_by_placement, size_t placement)
{ {
Node* selected_node = nullptr; Node* selected_node = nullptr;
if (independent_nodes_by_placement.find(placement) != independent_nodes_by_placement.end() && auto it = independent_nodes_by_placement.find(placement);
independent_nodes_by_placement.at(placement).size() != 0) if (it != independent_nodes_by_placement.end() && it->second.size() != 0)
{ {
selected_node = independent_nodes_by_placement.at(placement).front(); selected_node = it->second.front();
independent_nodes_by_placement.at(placement).pop_front(); it->second.pop_front();
} }
else else
{ {
for (auto& it : independent_nodes_by_placement) for (auto& p : independent_nodes_by_placement)
{ {
if (it.second.size() > 0) if (p.second.size() > 0)
{ {
selected_node = it.second.front(); selected_node = p.second.front();
it.second.pop_front(); p.second.pop_front();
break; break;
} }
} }
...@@ -238,8 +239,10 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shar ...@@ -238,8 +239,10 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shar
{ {
ParameterVector par_vector; ParameterVector par_vector;
ResultVector res_vector; ResultVector res_vector;
size_t placement = -1;
for (auto node : cluster) for (auto node : cluster)
{ {
placement = node->get_placement_index();
if (auto res_node = dynamic_pointer_cast<op::Result>(node)) if (auto res_node = dynamic_pointer_cast<op::Result>(node))
{ {
res_vector.push_back(res_node); res_vector.push_back(res_node);
...@@ -250,6 +253,7 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shar ...@@ -250,6 +253,7 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shar
} }
} }
auto sub_function = make_shared<Function>(res_vector, par_vector); auto sub_function = make_shared<Function>(res_vector, par_vector);
sub_function->set_placement(placement);
sub_functions.push_back(sub_function); sub_functions.push_back(sub_function);
#ifdef HYBRID_DEBUG #ifdef HYBRID_DEBUG
ngraph::pass::Manager pass_manager; ngraph::pass::Manager pass_manager;
...@@ -261,26 +265,3 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shar ...@@ -261,26 +265,3 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shar
return make_pair(sub_functions, map_parameter_to_result); return make_pair(sub_functions, map_parameter_to_result);
} }
// Assert that nodes in the function is colocated and return that placement
size_t runtime::hybrid::get_colocated_function_placement(shared_ptr<Function> func)
{
auto ops = func->get_ops();
//it's okay to not do Placement::DEFAULT check; the same node will be checked in the loop below
size_t function_placement = ops.front()->get_placement_index();
for (auto op : ops)
{
size_t node_placement = op->get_placement_index();
if (node_placement == Node::placement_invalid)
{
throw ngraph_error("Node " + op->get_name() + " should have a device placement");
}
if (function_placement != node_placement)
{
throw ngraph_error("Function contains nodes of two different placements");
}
}
return function_placement;
}
...@@ -35,9 +35,6 @@ namespace ngraph ...@@ -35,9 +35,6 @@ namespace ngraph
std::vector<std::shared_ptr<Function>>, std::vector<std::shared_ptr<Function>>,
std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Result>>> std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Result>>>
split_function_by_placement(const std::shared_ptr<Function>& f); split_function_by_placement(const std::shared_ptr<Function>& f);
// Assert that nodes in the function is colocated and return that placement
size_t get_colocated_function_placement(std::shared_ptr<Function> func);
} }
} }
} }
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "ngraph/runtime/hybrid/pass/assign_placement.hpp" #include "ngraph/runtime/hybrid/pass/default_placement.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/placement.hpp" #include "ngraph/placement.hpp"
...@@ -23,13 +23,13 @@ ...@@ -23,13 +23,13 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
runtime::hybrid::pass::AssignPlacement::AssignPlacement( runtime::hybrid::pass::DefaultPlacement::DefaultPlacement(
const vector<shared_ptr<runtime::Backend>>& placement_backends) const vector<shared_ptr<runtime::Backend>>& placement_backends)
: m_placement_backends(placement_backends) : m_placement_backends(placement_backends)
{ {
} }
bool runtime::hybrid::pass::AssignPlacement::run_on_node(shared_ptr<Node> node) bool runtime::hybrid::pass::DefaultPlacement::run_on_node(shared_ptr<Node> node)
{ {
size_t backend_index = 0; size_t backend_index = 0;
for (auto backend : m_placement_backends) for (auto backend : m_placement_backends)
......
...@@ -30,16 +30,16 @@ namespace ngraph ...@@ -30,16 +30,16 @@ namespace ngraph
{ {
namespace pass namespace pass
{ {
class AssignPlacement; class DefaultPlacement;
} }
} }
} }
} }
class ngraph::runtime::hybrid::pass::AssignPlacement : public ngraph::pass::NodePass class ngraph::runtime::hybrid::pass::DefaultPlacement : public ngraph::pass::NodePass
{ {
public: public:
AssignPlacement( DefaultPlacement(
const std::vector<std::shared_ptr<ngraph::runtime::Backend>>& placement_backends); const std::vector<std::shared_ptr<ngraph::runtime::Backend>>& placement_backends);
private: private:
......
...@@ -56,7 +56,7 @@ TEST(HYBRID, abc) ...@@ -56,7 +56,7 @@ TEST(HYBRID, abc)
auto t1 = A * B; auto t1 = A * B;
auto t2 = t1 * D; auto t2 = t1 * D;
auto C = make_shared<op::Parameter>(element::f32, shape); auto C = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>((t2 + C) * t1, ParameterVector{A, B, C, D}); auto f = make_shared<Function>(((t2 + C) + A) * t1, ParameterVector{A, B, C, D});
shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1"); shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1");
static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true); static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true);
...@@ -75,5 +75,5 @@ TEST(HYBRID, abc) ...@@ -75,5 +75,5 @@ TEST(HYBRID, abc)
auto handle = backend->compile(f); auto handle = backend->compile(f);
backend->call_with_validate(handle, {result}, {a, b, c, d}); backend->call_with_validate(handle, {result}, {a, b, c, d});
EXPECT_EQ(read_vector<float>(result), (vector<float>{145, 552, 1113, 1408})); EXPECT_EQ(read_vector<float>(result), (vector<float>{150, 576, 1176, 1536}));
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment