Commit 9db8f874 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Enable nGraph dialect in ngraph-opt (#3657)

* [MLIR] Add support for parsing nGraph tensor type

Initial commit that enables nGraph parsing. It's needed for testing.

* [MLIR] Enable nGraph dialect in ngraph-opt

This PR registers nGraph dialect in ngraph-opt and prepares
nGraph lowering pass for LIT testing, fixing all the related issues.
Among other things, lowering pass has to be turned into a function pass,
dead argument in constructor was removed and `convert-ngraph-to-affine`
flag was added.

* Fix issue with function name and multiple functions

* Extend module_function.mlir lit test

* Improve module_function.mlir test

Remove ngraph to affine dialect conversion since we just need to verify
that we can parse and print modules and functions.
Add verification for parsing the printed code.

* [MLIR] Add support for parsing nGraph element types (#3665)

* [MLIR] Add support for parsing nGraph element types

It introduces initial support for parsing nGraph signed/unsigned
integer and floating point data types.

* Improve LIT tests

Test parsing and printing of types separately from lowering to affine
since these tests will evolve differently, particularly for tensor
types.

* Missed file

I left this file behind in the previous commit
parent 68f6110c
......@@ -24,6 +24,7 @@ set(SRC
pass/mlir_subgraph_extraction.cpp
pass/mlir_subgraph_extraction.hpp
pass/memory_optimization.cpp
tools.cpp
)
add_library(mlir_backend SHARED ${SRC})
......
......@@ -47,6 +47,7 @@
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp"
#include "pass/memory_optimization.hpp"
#include "tools.hpp"
#include <llvm/ADT/STLExtras.h>
#include <llvm/Analysis/TargetTransformInfo.h>
......@@ -173,7 +174,7 @@ void MLIRCompiler::init_mlir()
if (!initialized)
{
mlir::registerDialect<mlir::NGraphOpsDialect>();
initializeNGraphMLIR();
// Register MLIR command line options in the pool of supported flags and and process flags
// from environment variable to be used by nGraph, MLIR and LLVM.
......@@ -352,7 +353,7 @@ void MLIRCompiler::lowerNgDialect()
{
// Lower NG dialect to Affine
mlir::PassManager pm(&m_context);
pm.addPass(mlir::createDialectLoweringPass(this));
pm.addPass(mlir::createDialectLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
// Apply any generic pass manager command line options.
......
......@@ -43,6 +43,8 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo
{
StringRef origTypeStr = tyData;
MLIRContext* context = getContext();
// Process nGraph tensor type.
if (tyData.consume_front("tensor"))
{
if (!tyData.consume_front("<") || !tyData.consume_back(">"))
......@@ -72,6 +74,7 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo
shape.push_back(dim);
}
// Parse nGraph element type.
auto elem_ty = mlir::parseType(subStrings.back(), context);
if (!elem_ty)
{
......@@ -81,6 +84,41 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo
return NGTensorType::get(context, elem_ty, shape);
}
// Process nGraph integer element types.
if (tyData.startswith("i") || tyData.startswith("u"))
{
bool isSigned = tyData.consume_front("i");
bool isUnsigned = tyData.consume_front("u");
NGRAPH_CHECK(isSigned != isUnsigned, "nGraph integer cannot be signed and unsigned");
unsigned width = 0;
// NOTE: `consumeInteger` returns false if an integer was parsed successfully.
if (tyData.consumeInteger(/*Radix=*/10, width) || width == 0 || !tyData.empty())
{
return (emitError(loc, "Unexpected nGraph integer type: " + origTypeStr), Type());
}
switch (width)
{
case 8:
return isSigned ? NGIntegerType::getInt8(context) : NGIntegerType::getUInt8(context);
case 16:
return isSigned ? NGIntegerType::getInt16(context) : NGIntegerType::getUInt16(context);
case 32:
return isSigned ? NGIntegerType::getInt32(context) : NGIntegerType::getUInt32(context);
case 64:
return isSigned ? NGIntegerType::getInt64(context) : NGIntegerType::getUInt64(context);
default:
return (emitError(loc, "Unexpected width for nGraph integer type: " + origTypeStr),
Type());
}
}
// nGraph reuses standard dialect floating point element types.
NGRAPH_CHECK(!tyData.startswith("f"),
"Floating point types should be processed by standard parser");
// NOTE: We may hit this error if the nGraph type is not yet supported in parser.
return (emitError(loc, "Unknown nGraph type: " + origTypeStr), Type());
}
......
......@@ -19,7 +19,6 @@
#include "lowerer.hpp"
#include "compiler.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
#include "ngraph/assertion.hpp"
......@@ -36,6 +35,9 @@
#include <map>
#define PASS_NAME "convert-ngraph-to-affine"
#define DEBUG_TYPE PASS_NAME
// anonymous namespace
// no need to expose any of the following outside of this file
namespace
......@@ -111,15 +113,19 @@ namespace
}
}
// Convert the original function results.
SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
auto funcTypeResults = type.getResults();
if (!funcTypeResults.empty())
{
return matchFailure();
}
// Convert the original function results.
SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(funcTypeResults, convertedResults)))
{
return matchFailure();
}
// Add result types as input args without mapping
result.addInputs(convertedResults);
// Add result types as input args without mapping
result.addInputs(convertedResults);
}
// Create a new function with an updated signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
......@@ -174,12 +180,8 @@ namespace
class DialectLoweringPass : public ModulePass<DialectLoweringPass>
{
public:
DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
: compiler(compiler)
{
}
void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
Value* createTempTensor(Type type, PatternRewriter& rewriter);
......@@ -204,7 +206,8 @@ namespace
using IdToMemRefMap = std::unordered_map<unsigned, Value*>;
IdToMemRefMap m_id_to_memref;
ngmlir::MLIRCompiler& compiler;
// TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470.
std::string funcName;
};
void DialectLoweringPass::runOnModule()
......@@ -225,18 +228,32 @@ namespace
return typeConverter.isSignatureLegal(op.getType());
});
// capture output values by looking for the Return and grabbing the values
// the order of the returned values matches the order of the lowered func signature for
// results. This is used to find the arg_id that a defined value maps to if it is an output
findOutputValues();
// Gather functions to be processed. Note that new functions will be added to module as part
// of the function signature conversion so we have to collect the original ones before hand.
SmallVector<FuncOp, 2> origFuncOps(getModule().getOps<FuncOp>());
if (failed(applyFullConversion(getModule(), target, std::move(patterns), &converter)))
for (auto origFunc : origFuncOps)
{
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
signalPassFailure();
}
// TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470.
funcName = origFunc.getName();
// Capture output values by looking for the Return and grabbing the values the order of
// the returned values matches the order of the lowered func signature for results. This
// is used to find the arg_id that a defined value maps to if it is an output.
findOutputValues();
// NOTE: Function signature conversion creates a new FuncOp that is inserted in the
// module. References the original FuncOp are no longer valid after this point.
if (failed(applyFullConversion(origFunc, target, std::move(patterns), &converter)))
{
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
signalPassFailure();
}
insertNoAliasArgAttrs();
// TODO: Encode no alias attribute as part of the function signature conversion or as a
// separate rewrite pattern. Retrieve new function after signature conversion.
insertNoAliasArgAttrs();
}
}
void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
......@@ -254,8 +271,9 @@ namespace
void DialectLoweringPass::findOutputValues()
{
// get original function
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");
SmallVector<Value*, 4> outputList;
unsigned outputCount = 0;
unsigned inputCount = f.getType().getNumInputs();
......@@ -280,13 +298,15 @@ namespace
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
PatternRewriter& rewriter)
{
FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");
SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults())
{
// find output arg if this operation produces any sub-graph outputs
if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
{
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
mlir::Block* entryBlock = &*(f.begin());
unsigned argId = (unsigned)attr.getInt();
newResults.push_back(entryBlock->getArgument(argId));
......@@ -350,7 +370,9 @@ namespace
/// by nGraph op semantics.
void DialectLoweringPass::insertNoAliasArgAttrs()
{
auto func = getModule().lookupSymbol<mlir::FuncOp>("main");
FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(func, "FuncOp '" + funcName + "' not found");
unsigned int argIdx = 0;
for (auto* arg : func.getArguments())
{
......@@ -1315,8 +1337,11 @@ namespace
namespace mlir
{
std::unique_ptr<Pass> createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler)
std::unique_ptr<Pass> createDialectLoweringPass()
{
return std::make_unique<DialectLoweringPass>(*compiler);
return std::make_unique<DialectLoweringPass>();
}
} // namespace mlir
static PassRegistration<DialectLoweringPass> pass(PASS_NAME,
"Convert nGraph dialect to affine dialect");
......@@ -36,6 +36,5 @@ namespace ngraph
namespace mlir
{
std::unique_ptr<Pass>
createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler);
std::unique_ptr<Pass> createDialectLoweringPass();
}
......@@ -158,6 +158,3 @@ namespace mlir
return std::make_unique<MemoryOptimizationPass>();
}
} // namespace mlir
static PassRegistration<MemoryOptimizationPass> pass("ng-inplace-mem-opt",
"Performs in-place memory optimizations");
\ No newline at end of file
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "tools.hpp"
#include "dialect/dialect.hpp"
#include <mlir/IR/Dialect.h>
void ngraph::runtime::ngmlir::initializeNGraphMLIR()
{
mlir::registerDialect<mlir::NGraphOpsDialect>();
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once
namespace ngraph
{
namespace runtime
{
namespace ngmlir
{
/// Common nGraph dialect initialization code. Used by nGraph compiler and tools that
/// require nGraph dialect initialization.
void initializeNGraphMLIR();
} // namespace ngmlir
} // namespace runtime
} // namespace ngraph
......@@ -14,25 +14,18 @@
# limitations under the License.
# ******************************************************************************
set(LIB_LIBS
MLIRPass
)
add_library(ngraph_opt_lib
ngraph_opt.cpp
)
target_link_libraries(ngraph_opt_lib ${LIB_LIBS})
set(LIBS
mlir_backend
MLIROptMain
MLIRPass
MLIRParser
LLVMSupport
)
add_executable(ngraph-opt
ngraph_opt.cpp
)
#whole_archive_link(ngraph-opt ${LIBS})
target_link_libraries(ngraph-opt PRIVATE ngraph_opt_lib ${LIBS} LLVMSupport)
target_link_libraries(ngraph-opt PRIVATE ${LIBS})
install(TARGETS ngraph-opt RUNTIME DESTINATION ${NGRAPH_INSTALL_BIN})
......@@ -24,6 +24,7 @@
/// small sequence of passes without running the whole compiler pipeline. Please, refer to
/// ngraph_repo_path/tests/mlir/ for examples.
#include "contrib/mlir/compiler/tools.hpp"
#include "ngraph/check.hpp"
#include <llvm/Support/CommandLine.h>
......@@ -33,6 +34,7 @@
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include "llvm/Support/InitLLVM.h"
static llvm::cl::opt<std::string>
input_filename(llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
......@@ -63,7 +65,8 @@ static std::vector<const mlir::PassRegistryEntry*>* pass_list;
int main(int argc, char** argv)
{
// TODO: Init nGraph MLIR Compiler here, when necessary.
llvm::InitLLVM y(argc, argv);
ngraph::runtime::ngmlir::initializeNGraphMLIR();
// Register any pass manager command line options.
mlir::registerPassManagerCLOptions();
......
// RUN: ngraph-opt %s -convert-ngraph-to-affine -split-input-file | FileCheck %s
// These tests verify that we can parse nGraph dialect types and lower them to affine.
// -----
// CHECK-LABEL: func @f32
// CHECK-SAME: (%{{.*}}: f32)
func @f32(%arg0: f32) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @f64
// CHECK-SAME: (%{{.*}}: f64)
func @f64(%arg0: f64) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i8
// CHECK-SAME: (%{{.*}}: i8)
func @i8(%arg0: !ng.i8) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i16
// CHECK-SAME: (%{{.*}}: i16)
func @i16(%arg0: !ng.i16) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i32
// CHECK-SAME: (%{{.*}}: i32)
func @i32(%arg0: !ng.i32) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i64
// CHECK-SAME: (%{{.*}}: i64)
func @i64(%arg0: !ng.i64) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @u8
// CHECK-SAME: (%{{.*}}: i8)
func @u8(%arg0: !ng.u8) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @u16
// CHECK-SAME: (%{{.*}}: i16)
func @u16(%arg0: !ng.u16) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @u32
// CHECK-SAME: (%{{.*}}: i32)
func @u32(%arg0: !ng.u32) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @u64
// CHECK-SAME (%{{.*}}: i64)
func @u64(%arg0: !ng.u64) {
"ng.return"() : () -> ()
}
// RUN: ngraph-opt %s -split-input-file | FileCheck %s
// Verify the printed output can be parsed.
// RUN: ngraph-opt %s -split-input-file | ngraph-opt | FileCheck %s
// These tests verify parsing and printing of various combinations of nGraph module and function
// ops.
// -----
// CHECK: module {
// CHECK: func @empty_func() {
// CHECK: return
module {
func @empty_func() -> () {
"ng.return"() : () -> ()
}
}
// -----
// CHECK: module {
// CHECK: func @empty_func1() {
// CHECK: ng.return
// CHECK: func @empty_func2() {
// CHECK: ng.return
module {
func @empty_func1() -> () {
"ng.return"() : () -> ()
}
func @empty_func2() -> () {
"ng.return"() : () -> ()
}
}
// -----
// Empty module must be automatically generated.
// CHECK: module {
// CHECK: func @no_module() {
// CHECK: ng.return
func @no_module() -> () {
"ng.return"() : () -> ()
}
// RUN: ngraph-opt %s -split-input-file | FileCheck %s
// Verify the printed output can be parsed.
// RUN: ngraph-opt %s -split-input-file | ngraph-opt | FileCheck %s
// These tests verify parsing and printing of nGraph types.
// -----
// CHECK-LABEL: func @f32
// CHECK-SAME: (%{{.*}}: f32)
func @f32(%arg0: f32) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @f64
// CHECK-SAME: (%{{.*}}: f64)
func @f64(%arg0: f64) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i8
// CHECK-SAME: (%{{.*}}: !ng.i8)
func @i8(%arg0: !ng.i8) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i16
// CHECK-SAME: (%{{.*}}: !ng.i16)
func @i16(%arg0: !ng.i16) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i32
// CHECK-SAME: (%{{.*}}: !ng.i32)
func @i32(%arg0: !ng.i32) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @i64
// CHECK-SAME: (%{{.*}}: !ng.i64)
func @i64(%arg0: !ng.i64) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @u8
// CHECK-SAME: (%{{.*}}: !ng.i8)
func @u8(%arg0: !ng.u8) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @u16
// CHECK-SAME: (%{{.*}}: !ng.i16)
func @u16(%arg0: !ng.u16) {
"ng.return"() : () -> ()
}
// -----
// CHECK-LABEL: func @u32
// CHECK-SAME: (%{{.*}}: !ng.i32)
func @u32(%arg0: !ng.u32) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @u64
// CHECK-SAME (%{{.*}}: !ng.i64)
func @u64(%arg0: !ng.u64) {
"ng.return"() : () -> ()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment