Unverified Commit a315f2dd authored by Nagy Mostafa's avatar Nagy Mostafa Committed by GitHub

[MLIR] Refactor NG dialect builder (#4363)

* Use parameter output tensor as key in value map

* clean up unused function. Already in compiler.cpp

* Style.
Co-authored-by: 's avatarNishant Patel <nishant.b.patel@intel.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent bced1048
...@@ -76,8 +76,6 @@ namespace ...@@ -76,8 +76,6 @@ namespace
void buildNgDialectModule(); void buildNgDialectModule();
void buildNgDialect(mlir::FuncOp function); void buildNgDialect(mlir::FuncOp function);
void runOnModule() override; void runOnModule() override;
// Applies any nGraph dialect optimizations
void optimizeNgDialect() { /*TODO: Add Core NG dialect optimizations */}
mlir::Type getMlirType(const descriptor::Tensor* tensor); mlir::Type getMlirType(const descriptor::Tensor* tensor);
mlir::Type getMlirType(const element::Type& type); mlir::Type getMlirType(const element::Type& type);
...@@ -174,11 +172,13 @@ void NgDialectConversionPass::runOnModule() ...@@ -174,11 +172,13 @@ void NgDialectConversionPass::runOnModule()
// populate Tensor->Value maps // populate Tensor->Value maps
int i = 0; int i = 0;
for (auto input : kernelInputs) for (auto p : m_compiledKernel->get_input_map())
{ {
auto arg = function.getArgument(i); auto paramNode = p.first;
TensorInfo tensorInfo{arg}; auto argId = p.second;
m_tensorToValueMap.insert(TensorToInfo(input->get_output_tensor_ptr().get(), tensorInfo)); auto argValue = function.getArgument(argId);
m_tensorToValueMap.insert(
TensorToInfo(paramNode->get_output_tensor_ptr().get(), {argValue}));
i++; i++;
} }
...@@ -631,7 +631,6 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng ...@@ -631,7 +631,6 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
{ {
std::vector<mlir::Value> argValues; std::vector<mlir::Value> argValues;
std::vector<mlir::Type> resTypes; std::vector<mlir::Type> resTypes;
auto inputMap = m_compiledKernel->get_input_map();
std::shared_ptr<descriptor::Tensor> argTensor; std::shared_ptr<descriptor::Tensor> argTensor;
int i = 0; int i = 0;
for (auto& argOutput : ngNode->input_values()) for (auto& argOutput : ngNode->input_values())
...@@ -641,18 +640,7 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng ...@@ -641,18 +640,7 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
break; break;
} }
auto argOutputNode = argOutput.get_node(); auto argOutputNode = argOutput.get_node();
if (is_type<op::Parameter>(argOutputNode))
{
auto it = inputMap.find(argOutputNode->shared_from_this());
NGRAPH_CHECK(it != inputMap.end(), "Parameter not in CK input map");
argTensor = m_compiledKernel->input_values().at(it->second).get_tensor_ptr();
}
else
{
argTensor = argOutput.get_tensor_ptr(); argTensor = argOutput.get_tensor_ptr();
}
auto argV = getTensorValue(argTensor.get()).m_value; auto argV = getTensorValue(argTensor.get()).m_value;
argValues.push_back(argV); argValues.push_back(argV);
i++; i++;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment