Unverified commit f227b591, authored by Robert Kimball and committed by GitHub

Simplify the access of placement in Function (#2405)

* prep work

* wip

* remove debug

* style

* update unit test
parent 6beb6732
@@ -214,3 +214,13 @@ size_t Function::get_graph_size() const
     }
     return total_size;
 }
+
+size_t Function::get_placement() const
+{
+    return m_placement;
+}
+
+void Function::set_placement(size_t placement)
+{
+    m_placement = placement;
+}
@@ -90,6 +90,9 @@ namespace ngraph
         /// graphs and should not be considered the actual memory consumption of a graph.
         size_t get_graph_size() const;
 
+        size_t get_placement() const;
+        void set_placement(size_t placement);
+
     protected:
         ResultVector m_results;
         ParameterVector m_parameters;
@@ -104,5 +107,6 @@ namespace ngraph
         size_t m_instance_id;
         std::string m_name;
         const std::string m_unique_name;
+        size_t m_placement;
     };
 }
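Taken together, the two hunks above move placement from a per-node graph walk to a value stored directly on `Function`. A minimal sketch of how the new accessors are meant to be used; the helper function below is hypothetical, only `get_placement`/`set_placement` come from the commit:

```cpp
#include <memory>

#include "ngraph/function.hpp"

// Hypothetical helper illustrating the new API; names other than
// get_placement/set_placement are not from the nGraph sources.
void tag_and_query(const std::shared_ptr<ngraph::Function>& sub_function,
                   size_t backend_index)
{
    // Record which backend this sub-function was clustered onto.
    sub_function->set_placement(backend_index);

    // Later reads are a plain member access; the removed
    // get_colocated_function_placement helper re-walked every node instead.
    size_t placement = sub_function->get_placement();
    (void)placement; // e.g. index into a backend list
}
```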
@@ -17,7 +17,7 @@
 add_library(hybrid_base STATIC
     hybrid_backend.cpp
     hybrid_util.cpp
-    pass/assign_placement.cpp
+    pass/default_placement.cpp
     pass/dump.cpp
     pass/fix_get_output_element.cpp
     pass/liveness.cpp
@@ -20,7 +20,7 @@
 #include "ngraph/pass/visualize_tree.hpp"
 #include "ngraph/runtime/host_tensor.hpp"
 #include "ngraph/runtime/hybrid/hybrid_util.hpp"
-#include "ngraph/runtime/hybrid/pass/assign_placement.hpp"
+#include "ngraph/runtime/hybrid/pass/default_placement.hpp"
 #include "ngraph/runtime/hybrid/pass/dump.hpp"
 #include "ngraph/runtime/hybrid/pass/fix_get_output_element.hpp"
 #include "ngraph/runtime/hybrid/pass/liveness.hpp"
@@ -74,7 +74,7 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> func)
     // Run placement pass
     ngraph::pass::Manager pass_manager;
-    pass_manager.register_pass<runtime::hybrid::pass::AssignPlacement>(m_backend_list);
+    pass_manager.register_pass<runtime::hybrid::pass::DefaultPlacement>(m_backend_list);
     pass_manager.register_pass<runtime::hybrid::pass::FixGetOutputElement>();
     pass_manager.register_pass<runtime::hybrid::pass::Liveness>();
     pass_manager.register_pass<runtime::hybrid::pass::Dump>("graph.dump");
@@ -94,7 +94,7 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> func)
     size_t subfunction_number = 0;
     for (shared_ptr<Function>& sub_function : instance.m_sub_functions)
     {
-        size_t placement = runtime::hybrid::get_colocated_function_placement(sub_function);
+        size_t placement = sub_function->get_placement();
         if (m_debug_enabled)
         {
             string name = "subfunction_" + to_string(subfunction_number++);
@@ -149,7 +149,7 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
     for (const shared_ptr<Function>& sub_function : instance.m_sub_functions)
     {
         // Init backend
-        size_t placement = runtime::hybrid::get_colocated_function_placement(sub_function);
+        size_t placement = sub_function->get_placement();
         auto backend = m_backend_list[placement];
 
         // Prepare parameter Tensors
@@ -15,6 +15,7 @@
 //*****************************************************************************
 
 #include "ngraph/runtime/hybrid/hybrid_util.hpp"
 #include "ngraph/log.hpp"
+#include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/visualize_tree.hpp"
@@ -25,20 +26,20 @@ static Node* take_independent_node_with_placement_priority(
     map<size_t, deque<Node*>>& independent_nodes_by_placement, size_t placement)
 {
     Node* selected_node = nullptr;
-    if (independent_nodes_by_placement.find(placement) != independent_nodes_by_placement.end() &&
-        independent_nodes_by_placement.at(placement).size() != 0)
+    auto it = independent_nodes_by_placement.find(placement);
+    if (it != independent_nodes_by_placement.end() && it->second.size() != 0)
     {
-        selected_node = independent_nodes_by_placement.at(placement).front();
-        independent_nodes_by_placement.at(placement).pop_front();
+        selected_node = it->second.front();
+        it->second.pop_front();
     }
     else
     {
-        for (auto& it : independent_nodes_by_placement)
+        for (auto& p : independent_nodes_by_placement)
         {
-            if (it.second.size() > 0)
+            if (p.second.size() > 0)
             {
-                selected_node = it.second.front();
-                it.second.pop_front();
+                selected_node = p.second.front();
+                p.second.pop_front();
                 break;
             }
         }
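The rewrite above replaces three separate map lookups (one `find` plus two `at` calls) with a single cached iterator, and renames the range-for variable so it no longer shadows that iterator. A self-contained sketch of the same pattern; the data and names here are illustrative only:

```cpp
#include <cstddef>
#include <deque>
#include <iostream>
#include <map>

int main()
{
    std::map<std::size_t, std::deque<int>> queues{{0, {1, 2}}, {1, {3}}};

    // One lookup: the cached iterator gives direct access to the mapped
    // deque, where repeated at() calls would re-search the map each time.
    auto it = queues.find(0);
    if (it != queues.end() && !it->second.empty())
    {
        int front = it->second.front();
        it->second.pop_front();
        std::cout << "took " << front << " from queue 0\n";
    }
    return 0;
}
```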
@@ -238,8 +239,10 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Result>>>
     {
         ParameterVector par_vector;
         ResultVector res_vector;
+        size_t placement = -1;
         for (auto node : cluster)
         {
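+            // Clusters are colocated by construction (the removed
+            // get_colocated_function_placement asserted exactly this), so the
+            // last node's placement index is the whole cluster's placement.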
+            placement = node->get_placement_index();
             if (auto res_node = dynamic_pointer_cast<op::Result>(node))
             {
                 res_vector.push_back(res_node);
@@ -250,6 +253,7 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Result>>>
             }
         }
         auto sub_function = make_shared<Function>(res_vector, par_vector);
+        sub_function->set_placement(placement);
         sub_functions.push_back(sub_function);
 #ifdef HYBRID_DEBUG
         ngraph::pass::Manager pass_manager;
@@ -261,26 +265,3 @@ pair<vector<shared_ptr<Function>>, unordered_map<shared_ptr<op::Parameter>, shared_ptr<op::Result>>>
     return make_pair(sub_functions, map_parameter_to_result);
 }
-
-// Assert that the nodes in the function are colocated and return that placement
-size_t runtime::hybrid::get_colocated_function_placement(shared_ptr<Function> func)
-{
-    auto ops = func->get_ops();
-
-    // It's okay to skip the Placement::DEFAULT check here; the same node is checked in the loop below
-    size_t function_placement = ops.front()->get_placement_index();
-    for (auto op : ops)
-    {
-        size_t node_placement = op->get_placement_index();
-        if (node_placement == Node::placement_invalid)
-        {
-            throw ngraph_error("Node " + op->get_name() + " should have a device placement");
-        }
-        if (function_placement != node_placement)
-        {
-            throw ngraph_error("Function contains nodes of two different placements");
-        }
-    }
-    return function_placement;
-}
@@ -35,9 +35,6 @@ namespace ngraph
                 std::vector<std::shared_ptr<Function>>,
                 std::unordered_map<std::shared_ptr<op::Parameter>, std::shared_ptr<op::Result>>>
                 split_function_by_placement(const std::shared_ptr<Function>& f);
-
-            // Assert that the nodes in the function are colocated and return that placement
-            size_t get_colocated_function_placement(std::shared_ptr<Function> func);
         }
     }
 }
@@ -14,7 +14,7 @@
 // limitations under the License.
 //*****************************************************************************
 
-#include "ngraph/runtime/hybrid/pass/assign_placement.hpp"
+#include "ngraph/runtime/hybrid/pass/default_placement.hpp"
 #include "ngraph/log.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/placement.hpp"
@@ -23,13 +23,13 @@
 using namespace ngraph;
 using namespace std;
 
-runtime::hybrid::pass::AssignPlacement::AssignPlacement(
+runtime::hybrid::pass::DefaultPlacement::DefaultPlacement(
     const vector<shared_ptr<runtime::Backend>>& placement_backends)
     : m_placement_backends(placement_backends)
 {
 }
 
-bool runtime::hybrid::pass::AssignPlacement::run_on_node(shared_ptr<Node> node)
+bool runtime::hybrid::pass::DefaultPlacement::run_on_node(shared_ptr<Node> node)
 {
     size_t backend_index = 0;
     for (auto backend : m_placement_backends)
@@ -30,16 +30,16 @@ namespace ngraph
         {
             namespace pass
             {
-                class AssignPlacement;
+                class DefaultPlacement;
             }
         }
     }
 }
 
-class ngraph::runtime::hybrid::pass::AssignPlacement : public ngraph::pass::NodePass
+class ngraph::runtime::hybrid::pass::DefaultPlacement : public ngraph::pass::NodePass
 {
 public:
-    AssignPlacement(
+    DefaultPlacement(
         const std::vector<std::shared_ptr<ngraph::runtime::Backend>>& placement_backends);
 
 private:
@@ -56,7 +56,7 @@ TEST(HYBRID, abc)
     auto t1 = A * B;
     auto t2 = t1 * D;
     auto C = make_shared<op::Parameter>(element::f32, shape);
-    auto f = make_shared<Function>((t2 + C) * t1, ParameterVector{A, B, C, D});
+    auto f = make_shared<Function>(((t2 + C) + A) * t1, ParameterVector{A, B, C, D});
 
     shared_ptr<runtime::Backend> backend = runtime::Backend::create("H1");
     static_pointer_cast<runtime::hybrid::HybridBackend>(backend)->set_debug_enabled(true);
@@ -75,5 +75,5 @@ TEST(HYBRID, abc)
     auto handle = backend->compile(f);
     backend->call_with_validate(handle, {result}, {a, b, c, d});
 
-    EXPECT_EQ(read_vector<float>(result), (vector<float>{145, 552, 1113, 1408}));
+    EXPECT_EQ(read_vector<float>(result), (vector<float>{150, 576, 1176, 1536}));
 }
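A quick consistency check on the new expectations: since ((t2 + C) + A) * t1 = (t2 + C) * t1 + A * t1, each new expected value should exceed the old one by a[i] * t1[i] = a[i] * a[i] * b[i]. The deltas {150 - 145, 576 - 552, 1176 - 1113, 1536 - 1408} = {5, 24, 63, 128} factor exactly as a[i]^2 * b[i] for A = {1, 2, 3, 4} and B = {5, 6, 7, 8}, which appears to match the test inputs not shown in this diff.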