Commit a5c99754 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] Initial PoC: NG dialect, dialect code-gen, dialect lowering to affine, no JIT yet

* Link MLIR static libs to cpu backend

* Use LLVMConfig.cmake

* Initial commit. Link fails with undefined reference to typeinfo for mlir::Dialect

* Added AddOp

* initial compiler class

* Initialize module/function, and map tensors to arguments

* Code compiles. Moved MLIR building to correct DEX handler

* NGDialect code-gen working

* Use vector instead of sets for i/o tensors. Use functor in executor

* Misc fixes

* style-apply

* WIP: Adding support for dialect lowering.

* WIP: Lowered to affine. Crash on constant ops have side effects in Constant Folding

* Fixed missing whole package linkage.

* Removed fake instruction and update func type

*  Enable lowering to LLVM dialect and IR

* Made loop nest builder handle any rank

* Fixes per PR feedback. Major ones:
- Removed ngdialect namespace
- renamed dialect classes to start with NG prefixwq:w

* Add unreachable assert

* Add reading of LLVM options from an env var MLIR_LLVM_OPTIONS (#5)
parent 021399a1
......@@ -16,6 +16,13 @@
include(FindOpenMP)
set(MLIR_SRC
mlir/dialect/dialect.cpp
mlir/dialect/type.cpp
mlir/dialect/ops.cpp
mlir/compiler.cpp
mlir/lowerer.cpp
)
set(SRC
cpu_backend.cpp
cpu_builder.cpp
......@@ -123,6 +130,7 @@ set(SRC
pass/cpu_rnn_fusion.cpp
pass/cpu_workspace_insertion.cpp
ngraph_version.cpp
${MLIR_SRC}
)
if (NOT NGRAPH_DEX_ONLY)
......@@ -197,11 +205,20 @@ if (NGRAPH_CPU_ENABLE)
endif()
target_include_directories(cpu_backend SYSTEM PUBLIC libmkldnn)
if (NOT APPLE AND NOT MSVS)
# CPU backend uses third-party libraries like Eigen that might be linked in and
# exported by other DSOs as well. In the absence of versioning, this could lead to the
# CPU backend picking up the wrong version or even multiple versions of the
# third-party library. -Bsymbolic-functions tells the linker to prefer the internal
# version inside cpu_backend over what is available through the global symbol table
set_property(TARGET cpu_backend APPEND PROPERTY LINK_FLAGS "-Wl,-Bsymbolic-functions -Wl,--exclude-libs=ALL")
endif()
# Link LLVM and MLIR
find_package(LLVM REQUIRED CONFIG)
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
message(STATUS "LLVM RTTI is ${LLVM_ENABLE_RTTI}")
add_definitions(${LLVM_DEFINITIONS})
target_include_directories(cpu_backend PRIVATE ${LLVM_INCLUDE_DIRS})
......@@ -214,7 +231,7 @@ if (NGRAPH_CPU_ENABLE)
# Link MLIR libs
target_link_libraries(
cpu_backend PRIVATE
cpu_backend PUBLIC
MLIRAnalysis
MLIREDSC
MLIRExecutionEngine
......@@ -225,24 +242,36 @@ if (NGRAPH_CPU_ENABLE)
MLIRTargetLLVMIR
MLIRTransforms
MLIRSupport
MLIRAffineOps
MLIRStandardOps
)
# some libs need whole archive linkage because of Globals static initialization
# TODO: move this helper somewhere else.
function(whole_archive_link target)
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Darwin")
set(link_flags "-Llib -Wl,-all_load ")
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "${LIB}")
ENDFOREACH(LIB)
else()
set(link_flags "-Llib -Wl,--whole-archive,")
FOREACH(LIB ${ARGN})
string(CONCAT link_flags ${link_flags} "${LIB},")
ENDFOREACH(LIB)
string(CONCAT link_flags ${link_flags} "--no-whole-archive")
endif()
message(STATUS "MLIR Ops link flag: ${link_flags}" )
set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
endfunction(whole_archive_link)
whole_archive_link(cpu_backend
${LLVM_BUILD_LIBRARY_DIR}/libMLIRAffineOps.a
${LLVM_BUILD_LIBRARY_DIR}/libMLIRStandardOps.a
)
# Link LLVM libs
target_link_libraries(
cpu_backend PRIVATE
${llvm_libs}
)
if (NOT APPLE AND NOT MSVS)
# CPU backend uses third-party libraries like Eigen that might be linked in and
# exported by other DSOs as well. In the absence of versioning, this could lead to the
# CPU backend picking up the wrong version or even multiple versions of the
# third-party library. -Bsymbolic-functions tells the linker to prefer the internal
# version inside cpu_backend over what is available through the global symbol table
set_property(TARGET cpu_backend APPEND PROPERTY LINK_FLAGS "-Wl,-Bsymbolic-functions -Wl,--exclude-libs=ALL")
endif()
install(TARGETS cpu_backend DESTINATION ${NGRAPH_INSTALL_LIB})
endif()
......@@ -191,6 +191,7 @@
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
#include "ngraph/runtime/cpu/mlir/compiler.hpp"
using namespace std;
using namespace ngraph;
......@@ -1392,6 +1393,12 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
// After processing inputs, outputs, constants, and intermediates, set the buffer size.
m_buffer_size = buffer_index;
if (std::getenv("NGRAPH_MLIR") != nullptr)
{
// Initialize MLIR compiler
MLIRCompiler::init_mlir();
}
for (shared_ptr<Node> node : m_function->get_ordered_ops())
{
if (node->is_parameter() || node->is_constant())
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/StandardOps/Ops.h"
#include "llvm/ADT/STLExtras.h"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
class MLIRCompiler
{
public:
using TensorList = std::vector<descriptor::Tensor*>;
using TypeList = llvm::SmallVector<mlir::Type, 4>;
MLIRCompiler(const std::vector<const Node*>& sub_graph)
: m_sub_graph(sub_graph.begin(), sub_graph.end())
{
}
static void init_mlir();
// compiles and runs a subgraph in MLIR
void compile();
private:
struct TensorInfo
{
mlir::Value* m_value; /* mlir value this tensor maps to */
// More info here ?
};
private:
void build_module();
void lower_dialect();
void lower_to_llvm();
void build_tensors_list();
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
void build_ng_dialect();
template <typename OP>
static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node)
{
throw std::runtime_error("Unimplemented op '" + ng_node->description() +
"' in MLIR Compiler");
}
template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
void create_return();
private:
mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module;
std::unique_ptr<mlir::FuncBuilder> m_builder;
using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>;
using MLIRCompOpFunction =
std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
llvm::SmallVector<const Node*, 4> m_sub_graph;
// Maps tensor to the value it represents in the IR
// use for MLIR dialect gen
TensorToInfoMap m_tensor_to_value_map;
// List of input and output tensors in the graph
TensorList m_ip_tensors, m_op_tensors;
static const MLIRCompOpMap op_dispatcher;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "dialect.hpp"
#include "ops.hpp"
#include "type.hpp"
namespace ngraph
{
using namespace runtime::cpu;
/// Register a dialect and its types
/// Usage:
/// mlir::registerDialect<ngraph::runtime::cpu::ngdialect::Dialect>();
NGDialect::NGDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("ng", ctx)
{
addTypes<NGTensorType>();
addOperations<NG_AddOp>();
addOperations<NG_ReturnOp>();
addOperations<NG_FakeOutput>();
}
void NGDialect::printType(mlir::Type type, raw_ostream& os) const
{
auto arrayTy = type.dyn_cast<NGTensorType>();
if (!arrayTy)
{
NGRAPH_ASSERT(0) << "Incorrect type to print?";
}
os << "tensor";
if (!arrayTy.getShape().empty())
{
os << "<";
mlir::interleaveComma(arrayTy.getShape(), os);
os << ">";
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "ngraph/assertion.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
class NGDialect : public mlir::Dialect
{
public:
explicit NGDialect(mlir::MLIRContext* ctx);
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
{
NGRAPH_ASSERT(0) << "Unsupported type parsing.";
return mlir::Type();
}
void printType(mlir::Type type, llvm::raw_ostream& os) const override;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ops.hpp"
#include "assertion.hpp"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "type.hpp"
using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
namespace ngraph
{
namespace runtime
{
namespace cpu
{
template <typename T>
static mlir::LogicalResult verifyBinOperands(T* op)
{
if (!op->getOperand(0)->getType().template isa<NGTensorType>())
{
std::string msg;
raw_string_ostream os(msg);
os << "expects a Tensor type for LHS, got " << op->getOperand(0)->getType();
return op->emitOpError(os.str());
}
if (!op->getOperand(1)->getType().template isa<NGTensorType>())
{
std::string msg;
raw_string_ostream os(msg);
os << "expects a Tensor type for RHS, got " << op->getOperand(0)->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifySingleOperand(T* op)
{
if (!op->getOperand()->getType().template isa<NGTensorType>())
{
std::string msg;
raw_string_ostream os(msg);
os << "expects a Tensor Type for its argument, got "
<< op->getOperand()->getType();
return op->emitOpError(os.str());
}
return mlir::success();
}
}
}
void runtime::cpu::NG_FakeOutput::build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Type resultType)
{
state->types.push_back(std::move(resultType));
}
mlir::LogicalResult runtime::cpu::NG_FakeOutput::verify()
{
// TODO: Verify returned tensor types must match function return type.
return mlir::success();
}
void runtime::cpu::NG_AddOp::build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Value* lhs,
mlir::Value* rhs)
{
state->types.push_back(lhs->getType());
state->operands.push_back(lhs);
state->operands.push_back(rhs);
}
mlir::LogicalResult runtime::cpu::NG_AddOp::verify()
{
// TODO: verify matching elt types
verifyBinOperands(this);
return mlir::success();
}
void runtime::cpu::NG_ReturnOp::build(mlir::Builder* builder,
mlir::OperationState* state,
std::vector<mlir::Value*> value_list)
{
for (auto value : value_list)
{
if (value)
state->operands.push_back(value);
}
}
mlir::LogicalResult runtime::cpu::NG_ReturnOp::verify()
{
// TODO: Verify returned tensor types must match function return type.
return mlir::success();
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstdarg>
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
// Fake instructions
class NG_FakeOutput : public mlir::Op<NG_FakeOutput,
mlir::OpTrait::NOperands<0>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect>
{
public:
static llvm::StringRef getOperationName() { return "ng.fake.output"; }
mlir::LogicalResult verify();
static void
build(mlir::Builder* builder, mlir::OperationState* state, mlir::Type type);
/// Inherit constructor.
using Op::Op;
};
// Binary instructions
class NG_AddOp : public mlir::Op<NG_AddOp,
mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect>
{
public:
static llvm::StringRef getOperationName() { return "ng.add"; }
/// custom verification
mlir::LogicalResult verify();
static void build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Value* lhs,
mlir::Value* rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value* getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value* getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// Return operations terminate blocks (and functions as well). They take a
/// single argument and the type must match the function return type.
class NG_ReturnOp : public mlir::Op<NG_ReturnOp,
mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult,
mlir::OpTrait::IsTerminator>
{
public:
static llvm::StringRef getOperationName() { return "ng.return"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.return` operation accepts an optional single array as an argument
/// and does not have any returned value.
static void build(mlir::Builder* builder,
mlir::OperationState* state,
std::vector<mlir::Value*> value_list);
/// Return true if there is a returned value.
bool hasOperand() { return 0 != getNumOperands(); }
/// Helper to return the optional operand. Caller must check if the operand
/// is present before calling this.
mlir::Value* getOperand() { return getOperation()->getOperand(0); }
mlir::Value* getOperand(unsigned i) { return getOperation()->getOperand(i); }
/// Inherit constructor.
using Op::Op;
};
}
}
}
\ No newline at end of file
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "type.hpp"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "ngraph/assertion.hpp"
using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
namespace ngraph
{
using namespace runtime::cpu;
/// Creates TensorType objects. They all point to the same storage if
/// element type and shape are the same.
NGTensorType NGTensorType::get(mlir::MLIRContext* context, EltType eltType, Shape shape)
{
return Base::get(context, NGTypeKind::TENSOR_TYPE_ID, eltType, shape);
}
mlir::MemRefType NGTensorType::toMemref()
{
auto memRefType =
mlir::MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
return memRefType;
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
using llvm::raw_ostream;
enum NGTypeKind
{
// The enum starts at the range reserved for this dialect.
// These values are pre-defined in MLIR lib and not configurable from here.
NG_TYPE = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
TENSOR_TYPE_ID
};
using EltType = mlir::Type;
// TODO: Can we use ngraph::shape here (given the hashing requirements)
using Shape = llvm::ArrayRef<int64_t>;
/// Tensor Type storage. There is a unique instance per type attributes.
/// Tensor Type is combination of the element type and shape. Each different
/// shape is a unique type.
struct NGTensorTypeStorage : public mlir::TypeStorage
{
// Tensor key is its type and shape.
// This is called when the user requests a specific tensor type
using KeyTy = std::tuple<EltType, Shape>;
static unsigned hashKey(const KeyTy& key)
{
return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
}
bool operator==(const KeyTy& key) const
{
return key == KeyTy(getElementType(), getShape());
}
static NGTensorTypeStorage* construct(mlir::TypeStorageAllocator& allocator,
const KeyTy& key)
{
// Deep copy the type shape over to MLIR context
EltType eltType = std::get<0>(key);
Shape shape = allocator.copyInto(std::get<1>(key));
auto* storage = allocator.allocate<NGTensorTypeStorage>();
return new (storage) NGTensorTypeStorage(eltType, shape);
}
Shape getShape() const { return m_shape; }
EltType getElementType() const { return m_eltType; }
private:
NGTensorTypeStorage(EltType eltType, Shape shape)
: m_eltType(eltType)
, m_shape(shape)
{
}
private:
EltType m_eltType;
Shape m_shape;
};
class NGTensorType
: public mlir::Type::TypeBase<NGTensorType, mlir::Type, NGTensorTypeStorage>
{
public:
using Base::Base;
EltType getElementType() const { return getImpl()->getElementType(); }
Shape getShape() const { return getImpl()->getShape(); }
int getRank() { return getShape().size(); }
/// convert to memref native MLIR type. Used for lowering.
mlir::MemRefType toMemref();
/// create a unique tensor type based on element type and shape.
static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
/// for llvm RTTI
static bool kindof(unsigned kind) { return kind == NGTypeKind::TENSOR_TYPE_ID; }
};
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "lowerer.hpp"
#include <map>
#include "llvm/ADT/DenseSet.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/EDSC/Intrinsics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "ngraph/assertion.hpp"
#include "ngraph/runtime/cpu/mlir/dialect/ops.hpp"
#include "ngraph/runtime/cpu/mlir/dialect/type.hpp"
using namespace ngraph::runtime::cpu;
// anonymous namespace
// no need to expose any of the following outside of this file
namespace
{
using namespace mlir;
using namespace mlir::edsc;
using namespace ngraph::runtime::cpu;
class DialectLoweringPass;
#include "op_lowerers.inc"
/// Use Dialect Converson Framework
class DialectLowerer : public DialectConversion
{
public:
DialectLowerer(DialectLoweringPass& pass)
: DialectConversion()
, m_pass(pass)
{
}
Type convertType(Type t) override;
protected:
// Initialize the list of converters.
llvm::DenseSet<DialectOpConversion*> initConverters(MLIRContext* context) override
{
return ConversionListBuilder<NG_AddOpConversion, NG_ReturnOpConversion>::build(
&allocator, context, m_pass);
}
private:
DialectLoweringPass& m_pass;
llvm::BumpPtrAllocator allocator;
};
/// Dialect Lowering Pass to affine ops
class DialectLoweringPass : public ModulePass<DialectLoweringPass>
{
public:
DialectLoweringPass()
: m_dialectLowerer(*this)
{
}
void runOnModule() override;
std::map<Value*, unsigned>& getOutputValueMap() { return m_outputValueMap; };
SmallVector<Value*, 4> buildOutputDefs(Operation* op, FuncBuilder& rewriter);
private:
void findOutputValues();
void fixOutputs();
private:
DialectLowerer m_dialectLowerer;
// maps output ng dialect values to args pos
std::map<Value*, unsigned> m_outputValueMap;
// list of results values to add to func signature
SmallVector<Value*, 4> m_loweredOutputValues;
};
Type DialectLowerer::convertType(Type t)
{
if (auto tensor = t.cast<NGTensorType>())
{
return tensor.toMemref();
}
return t;
}
void DialectLoweringPass::runOnModule()
{
// capture output values by looking for the Return and grabbing the values
// the order of the returned values matches the order of the lowered func signature for
// results. This is used to find the arg_id that a defined value maps to if it is an output
findOutputValues();
if (failed(m_dialectLowerer.convert(&getModule())))
{
getModule().getContext()->emitError(mlir::UnknownLoc::get(getModule().getContext()),
"Error lowering dialect\n");
signalPassFailure();
}
if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
{
getModule().dump();
}
fixOutputs();
if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
{
getModule().dump();
}
}
void DialectLoweringPass::findOutputValues()
{
auto f = getModule().getNamedFunction("main");
SmallVector<Value*, 4> outputList;
unsigned outputCount = 0;
// we find out output values by looking at returned values
// any return should return all outputs of the subgraph
f->walk<NG_ReturnOp>([this, &outputCount](NG_ReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++)
{
this->m_outputValueMap.insert(std::pair<Value*, unsigned>(ret.getOperand(i), i));
}
NGRAPH_ASSERT(outputCount == 0 || outputCount == ret.getNumOperands())
<< "Inconsistent returns in function";
outputCount = ret.getNumOperands();
});
// will be populated with lowered output values later
m_loweredOutputValues.resize(outputCount, nullptr);
}
// NGDialect converters
// ADD
SmallVector<Value*, 4> NG_AddOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{
auto add = op->cast<NG_AddOp>();
auto loc = add.getLoc();
Value *origResult, *newResult;
auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_ASSERT(result->getType().isa<MemRefType>());
// NOte that builder's current function is still the original function body.
// use getBlock to get the new block instead.
// get new operands
Value* lhs = operands[0];
Value* rhs = operands[1];
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
LoopNestBuilder(pivs, lbs, ubs, steps)({// single stmt body
iRes(ivs) = iLHS(ivs) + iRHS(ivs)});
// return result memref
return {result};
}
SmallVector<Value*, 4> NG_ReturnOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{
rewriter.create<ReturnOp>(op->getLoc());
return {};
}
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
FuncBuilder& rewriter)
{
auto& outputMap = getOutputValueMap();
SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults())
{
auto it = outputMap.find(origResult);
// create output def if this operation produces any sub-graph outputs
if (it != outputMap.end())
{
unsigned argId = (*it).second;
auto newResult = rewriter
.create<NG_FakeOutput>(
op->getLoc(),
m_dialectLowerer.convertType(
origResult->getType()) /* convert to lowered type */
)
.getResult();
newResults.push_back(newResult);
m_loweredOutputValues[argId] = newResult;
}
}
return newResults;
}
void DialectLoweringPass::fixOutputs()
{
auto context = getModule().getContext();
auto f = getModule().getNamedFunction("main");
mlir::Block* entryBlock = &*(f->begin());
auto oldFuncType = f->getType();
ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
SmallVector<mlir::Type, 4> allArgs;
// Move all args as inputs in new type
for (auto type : ipArgs)
{
allArgs.push_back(type);
}
for (auto type : opArgs)
{
allArgs.push_back(type);
// add new value for result
entryBlock->addArgument(type);
}
// update type
auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
f->setType(newFuncType);
// RAUW fake outputs with result values
unsigned i = 0;
for (auto value : m_loweredOutputValues)
{
auto op = value->getDefiningOp();
NGRAPH_ASSERT(op->isa<NG_FakeOutput>()) << "output value not defined by fake output?";
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase();
i++;
}
}
}
namespace ngraph
{
namespace runtime
{
namespace cpu
{
Pass* createDialectLoweringPass() { return new DialectLoweringPass(); }
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
mlir::Pass* createDialectLoweringPass();
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// Add new dialect ops lowerers to this file
#define DECL_OP_CONV(OP) \
class OP##Conversion : public mlir::DialectOpConversion \
{\
public:\
explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
: mlir::DialectOpConversion(ngraph::runtime::cpu::OP::getOperationName(), 1, context),\
m_pass(pass)\
{} \
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override; \
DialectLoweringPass& m_pass;\
};
DECL_OP_CONV(NG_AddOp)
DECL_OP_CONV(NG_ReturnOp)
#undef DECL_OP_CONV
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment