Commit 7b3c323b authored by Nagy Mostafa's avatar Nagy Mostafa Committed by omarkanawi

[MLIR] Fix style in compiler and lowerer files (#3564)

* Fix style in compiler and lowerer files

* Fix comment in headers

* Revert "Fix comment in headers"

This reverts commit d52eed4c1bdf371f3cc7d3f601d9d2b1b0c233e8.

* Fix compiler.* header. Fix code style in other files
parent 76e4485b
This diff is collapsed.
......@@ -14,8 +14,8 @@
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
// NOTE: This file follows nGraph format style.
// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
#pragma once
......@@ -69,7 +69,7 @@ namespace ngraph
using TypeList = llvm::SmallVector<mlir::Type, 4>;
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel)
: m_compiled_kernel(compiled_kernel)
: m_compiledKernel(compiled_kernel)
{
}
......@@ -77,7 +77,7 @@ namespace ngraph
void compile();
/// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors);
void run(std::vector<void*>& externalTensors);
private:
struct TensorInfo
......@@ -87,66 +87,65 @@ namespace ngraph
};
private:
void build_ng_dialect_module();
void lower_ng_dialect();
void buildNgDialectModule();
void lowerNgDialect();
void optimize();
void bind_arguments(std::vector<void*>& external_tensors);
void bindArguments(std::vector<void*>& externalTensors);
void execute();
void cleanup();
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
mlir::Type get_mlir_type(const ngraph::Node* node);
mlir::Type getMlirType(const descriptor::Tensor* tensor);
mlir::Type getMlirType(const element::Type& type);
mlir::Type getMlirType(const ngraph::Node* node);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
TensorInfo getTensorValue(descriptor::Tensor* tensor);
void updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value);
void build_ng_dialect();
void buildNgDialect();
template <typename Op>
static mlir::Operation* create_op(MLIRCompiler& compiler,
const ngraph::Node* ng_node)
static mlir::Operation* createOp(MLIRCompiler& compiler, const ngraph::Node* ngNode)
{
throw std::runtime_error("Unimplemented op '" + ng_node->description() +
throw std::runtime_error("Unimplemented op '" + ngNode->description() +
"' in MLIR Compiler");
}
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename Op>
mlir::Operation* create_generic_op(const ngraph::Node* ng_node);
mlir::Operation* createGenericOp(const ngraph::Node* ngNode);
template <typename RedOp>
mlir::Operation* create_index_reduction(const ngraph::Node* ng_node);
mlir::Operation* createIndexReduction(const ngraph::Node* ngNode);
void create_return();
void createReturn();
/// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocate_memref_args();
llvm::SmallVector<void*, 8> allocateMemrefArgs();
/// Helper to allocate a mem ref object. Handles static shapes only for now.
mlir::StaticFloatMemRef* allocate_memref_descriptor();
mlir::StaticFloatMemRef* allocateMemrefDescriptor();
/// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg.
void dump_mlir_module(const std::string msg);
void dumpMlirModule(const std::string msg);
/// Converts nGraph shape-like types \p ng_shape to MLIR shape \p mlir_shape.
template <typename T>
void get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape);
void getMlirShape(T ngShape, llvm::SmallVectorImpl<int64_t>& mlirShape);
/// Converts an ngraph shape to an I64 array attribute
template <typename T>
mlir::ArrayAttr get_shape_as_attr(T ng_shape);
mlir::ArrayAttr getShapeAsAttr(T ngShape);
private:
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
const ngraph::op::CompiledKernel* m_compiledKernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
std::vector<void*>* m_external_tensors;
std::vector<void*>* m_externalTensors;
// Arguments for the MLIR function generated for the nGraph sub-graph.
llvm::SmallVector<void*, 8> m_invoke_args;
llvm::SmallVector<void*, 8> m_invokeArgs;
// MLIR context that holds all the MLIR information related to the sub-graph
// compilation.
......@@ -164,11 +163,11 @@ namespace ngraph
// Maps tensor to the value it represents in the IR
// use for MLIR dialect gen
TensorToInfoMap m_tensor_to_value_map;
static const MLIRCompOpMap op_dispatcher;
TensorToInfoMap m_tensorToValueMap;
static const MLIRCompOpMap opDispatcher;
// Optimization level used by MLIR and LLVM compilers.
static unsigned mlir_opt_level;
static unsigned mlirOptLevel;
// LLVM target machine to be used by this MLIR compiler instance to retrieve
// information about target features.
......@@ -178,7 +177,7 @@ namespace ngraph
// machine or configuration flags.
// TODO: Move target machine to external nGraph backend when multiple backends start
// to use MLIR.
static std::unique_ptr<llvm::TargetMachine> target_machine;
static std::unique_ptr<llvm::TargetMachine> targetMachine;
};
}
}
......
......@@ -51,7 +51,9 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
for (auto operand : opr->getOperands())
{
if (i == 0)
{
continue;
}
mlir::Type t = operand->getType();
mlir::NGTensorType opType = t.cast<NGTensorType>();
if (!opType.isCompatible(opType0))
......
......@@ -82,10 +82,14 @@ bool NGTensorType::isCompatible(NGTensorType& other) const
{
// Exact same tensor
if (this == &other)
{
return true;
}
// different tensors, check if of same element type and compatible shapes
if (getElementType() != other.getElementType())
{
return false;
}
// TODO: Handle dynamic ranks
// MLIR MemRefType doesn't seem to support it at the moment.
return isCompatibleShape(other);
......@@ -97,7 +101,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
auto otherShape = other.getShape();
if (shape.size() != otherShape.size())
{
return false;
}
for (auto i = 0; i < shape.size(); i++)
{
......@@ -105,7 +111,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]);
if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
{
continue;
}
return false;
}
return true;
......
......@@ -104,13 +104,19 @@ namespace
// Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
{
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
{
return matchFailure();
}
}
// Convert the original function results.
SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
{
return matchFailure();
}
// Add result types as input args without mapping
result.addInputs(convertedResults);
......@@ -139,16 +145,16 @@ namespace
DialectLoweringPass& pass);
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
template <typename OP>
void lower_unary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
ValueHandle createZeroConstant(mlir::Type type);
......@@ -376,49 +382,49 @@ namespace
REWRITER(NGAddOp)
{
lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGSubOp)
{
lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMulOp)
{
lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGDivOp)
{
lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGGreaterOp)
{
lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGLessOp)
{
lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMaxOp)
{
lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMinOp)
{
lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
return matchSuccess();
}
......@@ -477,7 +483,7 @@ namespace
// Negative
REWRITER(NGNegOp)
{
lower_unary_elementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
lowerUnaryElementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
return matchSuccess();
}
......@@ -950,10 +956,10 @@ namespace
#undef REWRITER
/// End of pattern matchers
template <typename OP>
void lower_unary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
......@@ -999,10 +1005,10 @@ namespace
}
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0];
......@@ -1138,7 +1144,9 @@ namespace
for (auto i = 0; i < vArg.rank(); i++)
{
if (i != axis)
{
nonRedIVs.push_back(allIVs[i]);
}
}
// Load current min index with integer data type and convert it to index data type.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment