Unverified Commit e49dd589 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Bob/hybrid multi (#3005)

* handle case where a node's output is connected multiple inputs of another node

* fix creation of the FunctionCall to have the correct outputs

* fix per review comment
parent 08dcd01b
......@@ -223,12 +223,10 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
// we just added
auto sub_function = make_shared<Function>(cluster_outputs, cluster_inputs);
sub_function->set_placement(placement);
auto fc = make_shared<runtime::hybrid::op::FunctionCall>(function_call_outputs,
function_call_inputs,
*sub_function,
backend_list[placement]);
auto fc = make_shared<runtime::hybrid::op::FunctionCall>(
cluster_outputs, function_call_inputs, *sub_function, backend_list[placement]);
fc->set_placement_index(0);
for (size_t i = 0; i < function_call_outputs.size(); i++)
for (size_t i = 0; i < cluster_outputs.size(); i++)
{
// First add a GetOutputElement to the ith output of the FunctionCall
auto goe = make_shared<ngraph::op::GetOutputElement>(fc, i);
......@@ -238,12 +236,10 @@ void runtime::hybrid::rewrite_function(const shared_ptr<Function>& f,
auto target = function_call_outputs[i];
std::vector<Input<Node>> target_inputs = get_inputs_from(*old_source, *target);
NGRAPH_CHECK(target_inputs.size() == 1,
"rewrite_function encountered more than "
"one input between the old source node and the target node");
auto& target_input = target_inputs[0];
target_input.replace_source_output(goe->output(0));
for (Input<Node> target_input : target_inputs)
{
target_input.replace_source_output(goe->output(0));
}
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment