Commit bad0c25b authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

[MLIR] Encapsulate internal nodes of CompiledKernel Op to prevent later passes from accessing them. (#3583)

* [MLIR] Encapsulate internal nodes of CompiledKernel Op to prevent later passes from accessing them.

* Fix style and namespace.

* Address PR feedback.

* Move encapsulate_nodes from pass to CompiledKernel Op.

* Address PR feedback.

* Follow MLIR naming convention.
parent de16c5cf
......@@ -639,11 +639,25 @@ mlir::Operation* MLIRCompiler::createGenericOp(const ngraph::Node* ngNode)
{
std::vector<mlir::Value*> argValues;
std::vector<mlir::Type> resTypes;
for (auto& arg : ngNode->get_arguments())
auto inputMap = m_compiledKernel->get_input_map();
std::shared_ptr<descriptor::Tensor> argTensor;
for (auto& argOutput : ngNode->input_values())
{
auto argTensor = arg->get_output_tensor_ptr();
auto argv = getTensorValue(argTensor.get()).m_value;
argValues.push_back(argv);
auto argOutputNode = argOutput.get_node();
if (as_type<op::Parameter>(argOutputNode))
{
auto it = inputMap.find(argOutputNode->shared_from_this());
NGRAPH_CHECK(it != inputMap.end(), "Parameter not in CK input map");
argTensor = m_compiledKernel->input_values().at(it->second).get_tensor_ptr();
}
else
{
argTensor = argOutput.get_tensor_ptr();
}
auto argV = getTensorValue(argTensor.get()).m_value;
argValues.push_back(argV);
}
for (auto& output : ngNode->outputs())
......
......@@ -132,6 +132,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
#endif
clean_up();
return true;
}
......@@ -366,19 +367,19 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No
}
}
// Any input to CK must also have at least one user in the sub-graph body
// Any input to CK must not have any user in the sub-graph body
for (auto& arg : ck_node->get_arguments())
{
bool found = false;
for (auto& user : arg->get_users())
{
found = (node_set.find(user) != node_set.end());
found = (node_set.find(user) == node_set.end());
if (found)
{
break;
}
}
NGRAPH_CHECK(found, "CK input is not input to sub-graph");
NGRAPH_CHECK(found, "CK input is input to sub-graph");
}
}
}
......
......@@ -45,7 +45,15 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
OutputVector cur_args;
for (auto a : n->input_values())
{
cur_args.push_back(a.for_node(nm.at(a.get_node())));
if (as_type<op::Parameter>(a.get_node()))
{
// dummy parameter
cur_args.push_back(a);
}
else
{
cur_args.push_back(a.for_node(nm.at(a.get_node())));
}
}
auto new_n = n->copy_with_new_inputs(cur_args);
nm[n.get()] = new_n;
......@@ -58,7 +66,12 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
new_outputs.push_back(nm.at(o.get()));
}
return std::make_shared<CompiledKernel>(new_node_list, new_outputs, new_args);
auto ck = std::make_shared<CompiledKernel>(new_node_list, new_outputs, new_args);
for (auto it : m_input_map)
{
ck->insert_to_input_map(it.first, it.second);
}
return ck;
}
ngraph::op::CompiledKernel::CompiledKernel(const OutputVector& node_list,
......@@ -76,6 +89,7 @@ ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
, m_output_nodes(outputs)
{
constructor_validate_and_infer_types();
encapsulate_nodes();
set_output_size(m_output_nodes.size());
for (size_t i = 0; i < outputs.size(); ++i)
......@@ -89,3 +103,35 @@ ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
set_output_type(i, o->get_element_type(), o->get_shape());
}
}
// Sever direct graph edges from outside producers into nodes captured inside this
// CompiledKernel (CK). For each internal node B fed by an external output A, the
// A->B edge is replaced with a fresh dummy Parameter, and that Parameter is recorded
// in m_input_map together with the index of the CK argument it stands for, so later
// passes cannot reach CK-internal nodes through A.
void ngraph::op::CompiledKernel::encapsulate_nodes()
{
    // Fast membership test: which nodes belong to this CK's body.
    std::unordered_set<std::shared_ptr<Node>> node_set(m_node_list.begin(), m_node_list.end());
    // Go through each non-CK user of input to CK
    int ck_arg_idx = 0;
    for (auto& arg_output : input_values())
    {
        for (auto& input : arg_output.get_target_inputs())
        {
            auto user = input.get_node();
            // Rewrite only edges whose consumer is a node inside this CK's body;
            // consumers that are themselves CompiledKernel ops (including this CK,
            // which consumes arg_output as its argument) are left untouched.
            if (!as_type<op::CompiledKernel>(user) &&
                node_set.find(user->shared_from_this()) != node_set.end())
            {
                // NOTE(review): remove_target_input mutates the collection that
                // get_target_inputs() describes while it is being iterated — confirm
                // get_target_inputs() returns a snapshot copy, otherwise this risks
                // iterator invalidation.
                arg_output.remove_target_input(input);
                // Use a dummy Parameter as input for now, will replace later with the correct
                // one.
                auto temp_input_param = std::make_shared<ngraph::op::Parameter>(
                    arg_output.get_element_type(), arg_output.get_partial_shape());
                input.replace_source_output(temp_input_param->output(0));
                // Remember which CK argument this dummy Parameter stands in for.
                insert_to_input_map(temp_input_param, ck_arg_idx);
            }
        }
        ck_arg_idx++;
    }
}
// Record that `node` (a dummy Parameter created by encapsulate_nodes) stands in for
// the CK argument at position `ck_arg_idx`. An existing entry for `node` is kept
// unchanged (insert does not overwrite).
void ngraph::op::CompiledKernel::insert_to_input_map(std::shared_ptr<Node> node, size_t ck_arg_idx)
{
    m_input_map.insert({node, ck_arg_idx});
}
......@@ -46,9 +46,22 @@ namespace ngraph
const NodeVector& get_node_list() const { return m_node_list; }
const NodeVector& get_kernel_outputs() const { return m_output_nodes; }
// For node B inside CompiledKernel ck such that A->B and A is outside of ck:
// replace input to B with a dummy Parameter Op and add an entry to ck's
// m_input_map.
void encapsulate_nodes();
// Read-only view of the mapping from dummy Parameter nodes (created by
// encapsulate_nodes) to the index of the CK argument each one stands for.
const std::unordered_map<std::shared_ptr<Node>, size_t>& get_input_map() const
{
    return m_input_map;
}
void insert_to_input_map(std::shared_ptr<Node>, size_t);
private:
NodeVector m_node_list;
NodeVector m_output_nodes;
// Used to store the information of internal nodes that have input coming from outside
// of CK
std::unordered_map<std::shared_ptr<Node>, size_t> m_input_map;
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment