Unverified Commit a315f2dd authored by Nagy Mostafa's avatar Nagy Mostafa Committed by GitHub

[MLIR] Refactor NG dialect builder (#4363)

* Use parameter output tensor as key in value map

* clean up unused function. Already in compiler.cpp

* Style.
Co-authored-by: 's avatarNishant Patel <nishant.b.patel@intel.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent bced1048
......@@ -76,8 +76,6 @@ namespace
void buildNgDialectModule();
void buildNgDialect(mlir::FuncOp function);
void runOnModule() override;
// Applies any nGraph dialect optimizations
void optimizeNgDialect() { /*TODO: Add Core NG dialect optimizations */}
mlir::Type getMlirType(const descriptor::Tensor* tensor);
mlir::Type getMlirType(const element::Type& type);
......@@ -174,11 +172,13 @@ void NgDialectConversionPass::runOnModule()
// populate Tensor->Value maps
int i = 0;
for (auto input : kernelInputs)
for (auto p : m_compiledKernel->get_input_map())
{
auto arg = function.getArgument(i);
TensorInfo tensorInfo{arg};
m_tensorToValueMap.insert(TensorToInfo(input->get_output_tensor_ptr().get(), tensorInfo));
auto paramNode = p.first;
auto argId = p.second;
auto argValue = function.getArgument(argId);
m_tensorToValueMap.insert(
TensorToInfo(paramNode->get_output_tensor_ptr().get(), {argValue}));
i++;
}
......@@ -631,7 +631,6 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
{
std::vector<mlir::Value> argValues;
std::vector<mlir::Type> resTypes;
auto inputMap = m_compiledKernel->get_input_map();
std::shared_ptr<descriptor::Tensor> argTensor;
int i = 0;
for (auto& argOutput : ngNode->input_values())
......@@ -641,18 +640,7 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
break;
}
auto argOutputNode = argOutput.get_node();
if (is_type<op::Parameter>(argOutputNode))
{
auto it = inputMap.find(argOutputNode->shared_from_this());
NGRAPH_CHECK(it != inputMap.end(), "Parameter not in CK input map");
argTensor = m_compiledKernel->input_values().at(it->second).get_tensor_ptr();
}
else
{
argTensor = argOutput.get_tensor_ptr();
}
auto argV = getTensorValue(argTensor.get()).m_value;
argValues.push_back(argV);
i++;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment