Commit 7b3c323b authored by Nagy Mostafa's avatar Nagy Mostafa Committed by omarkanawi

[MLIR] Fix style in compiler and lowerer files (#3564)

* Fix style in compiler and lowerer files

* Fix comment in headers

* Revert "Fix comment in headers"

This reverts commit d52eed4c1bdf371f3cc7d3f601d9d2b1b0c233e8.

* Fix compiler.* header. Fix code style in other files
parent 76e4485b
This diff is collapsed.
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it // NOTE: This file follows nGraph format style.
// exposes a public API to the rest of nGraph codebase. // Follows nGraph naming convention for public APIs only, else MLIR naming convention.
#pragma once #pragma once
...@@ -69,7 +69,7 @@ namespace ngraph ...@@ -69,7 +69,7 @@ namespace ngraph
using TypeList = llvm::SmallVector<mlir::Type, 4>; using TypeList = llvm::SmallVector<mlir::Type, 4>;
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel) MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel)
: m_compiled_kernel(compiled_kernel) : m_compiledKernel(compiled_kernel)
{ {
} }
...@@ -77,7 +77,7 @@ namespace ngraph ...@@ -77,7 +77,7 @@ namespace ngraph
void compile(); void compile();
/// Executes a pre-compiled subgraph /// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors); void run(std::vector<void*>& externalTensors);
private: private:
struct TensorInfo struct TensorInfo
...@@ -87,66 +87,65 @@ namespace ngraph ...@@ -87,66 +87,65 @@ namespace ngraph
}; };
private: private:
void build_ng_dialect_module(); void buildNgDialectModule();
void lower_ng_dialect(); void lowerNgDialect();
void optimize(); void optimize();
void bind_arguments(std::vector<void*>& external_tensors); void bindArguments(std::vector<void*>& externalTensors);
void execute(); void execute();
void cleanup(); void cleanup();
mlir::Type get_mlir_type(const descriptor::Tensor* tensor); mlir::Type getMlirType(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type); mlir::Type getMlirType(const element::Type& type);
mlir::Type get_mlir_type(const ngraph::Node* node); mlir::Type getMlirType(const ngraph::Node* node);
TensorInfo get_tensor_value(descriptor::Tensor* tensor); TensorInfo getTensorValue(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value); void updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value);
void build_ng_dialect(); void buildNgDialect();
template <typename Op> template <typename Op>
static mlir::Operation* create_op(MLIRCompiler& compiler, static mlir::Operation* createOp(MLIRCompiler& compiler, const ngraph::Node* ngNode)
const ngraph::Node* ng_node)
{ {
throw std::runtime_error("Unimplemented op '" + ng_node->description() + throw std::runtime_error("Unimplemented op '" + ngNode->description() +
"' in MLIR Compiler"); "' in MLIR Compiler");
} }
// Generic op lowerer to ng dialect. // Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic. // Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename Op> template <typename Op>
mlir::Operation* create_generic_op(const ngraph::Node* ng_node); mlir::Operation* createGenericOp(const ngraph::Node* ngNode);
template <typename RedOp> template <typename RedOp>
mlir::Operation* create_index_reduction(const ngraph::Node* ng_node); mlir::Operation* createIndexReduction(const ngraph::Node* ngNode);
void create_return(); void createReturn();
/// Helper to create memref arguments for MLIR function signature /// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocate_memref_args(); llvm::SmallVector<void*, 8> allocateMemrefArgs();
/// Helper to allocate a mem ref object. Handles static shapes only for now. /// Helper to allocate a mem ref object. Handles static shapes only for now.
mlir::StaticFloatMemRef* allocate_memref_descriptor(); mlir::StaticFloatMemRef* allocateMemrefDescriptor();
/// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg. /// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg.
void dump_mlir_module(const std::string msg); void dumpMlirModule(const std::string msg);
/// Converts nGraph shape-like types \p ng_shape to MLIR shape \p mlir_shape. /// Converts nGraph shape-like types \p ng_shape to MLIR shape \p mlir_shape.
template <typename T> template <typename T>
void get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape); void getMlirShape(T ngShape, llvm::SmallVectorImpl<int64_t>& mlirShape);
/// Converts an ngraph shape to an I64 array attribute /// Converts an ngraph shape to an I64 array attribute
template <typename T> template <typename T>
mlir::ArrayAttr get_shape_as_attr(T ng_shape); mlir::ArrayAttr getShapeAsAttr(T ngShape);
private: private:
// Sub-graph to be compiled and executed with MLIR. // Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel; const ngraph::op::CompiledKernel* m_compiledKernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors. // Pointers to externally allocated memory for sub-graph's input and output tensors.
std::vector<void*>* m_external_tensors; std::vector<void*>* m_externalTensors;
// Arguments for the MLIR function generated for the nGraph sub-graph. // Arguments for the MLIR function generated for the nGraph sub-graph.
llvm::SmallVector<void*, 8> m_invoke_args; llvm::SmallVector<void*, 8> m_invokeArgs;
// MLIR context that holds all the MLIR information related to the sub-graph // MLIR context that holds all the MLIR information related to the sub-graph
// compilation. // compilation.
...@@ -164,11 +163,11 @@ namespace ngraph ...@@ -164,11 +163,11 @@ namespace ngraph
// Maps tensor to the value it represents in the IR // Maps tensor to the value it represents in the IR
// use for MLIR dialect gen // use for MLIR dialect gen
TensorToInfoMap m_tensor_to_value_map; TensorToInfoMap m_tensorToValueMap;
static const MLIRCompOpMap op_dispatcher; static const MLIRCompOpMap opDispatcher;
// Optimization level used by MLIR and LLVM compilers. // Optimization level used by MLIR and LLVM compilers.
static unsigned mlir_opt_level; static unsigned mlirOptLevel;
// LLVM target machine to be used by this MLIR compiler instance to retrieve // LLVM target machine to be used by this MLIR compiler instance to retrieve
// information about target features. // information about target features.
...@@ -178,7 +177,7 @@ namespace ngraph ...@@ -178,7 +177,7 @@ namespace ngraph
// machine or configuration flags. // machine or configuration flags.
// TODO: Move target machine to external nGraph backend when multiple backends start // TODO: Move target machine to external nGraph backend when multiple backends start
// to use MLIR. // to use MLIR.
static std::unique_ptr<llvm::TargetMachine> target_machine; static std::unique_ptr<llvm::TargetMachine> targetMachine;
}; };
} }
} }
......
...@@ -51,7 +51,9 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR ...@@ -51,7 +51,9 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
for (auto operand : opr->getOperands()) for (auto operand : opr->getOperands())
{ {
if (i == 0) if (i == 0)
{
continue; continue;
}
mlir::Type t = operand->getType(); mlir::Type t = operand->getType();
mlir::NGTensorType opType = t.cast<NGTensorType>(); mlir::NGTensorType opType = t.cast<NGTensorType>();
if (!opType.isCompatible(opType0)) if (!opType.isCompatible(opType0))
......
...@@ -82,10 +82,14 @@ bool NGTensorType::isCompatible(NGTensorType& other) const ...@@ -82,10 +82,14 @@ bool NGTensorType::isCompatible(NGTensorType& other) const
{ {
// Exact same tensor // Exact same tensor
if (this == &other) if (this == &other)
{
return true; return true;
}
// different tensors, check if of same element type and compatible shapes // different tensors, check if of same element type and compatible shapes
if (getElementType() != other.getElementType()) if (getElementType() != other.getElementType())
{
return false; return false;
}
// TODO: Handle dynamic ranks // TODO: Handle dynamic ranks
// MLIR MemRefType doesn't seem to support it at the moment. // MLIR MemRefType doesn't seem to support it at the moment.
return isCompatibleShape(other); return isCompatibleShape(other);
...@@ -97,7 +101,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const ...@@ -97,7 +101,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
auto otherShape = other.getShape(); auto otherShape = other.getShape();
if (shape.size() != otherShape.size()) if (shape.size() != otherShape.size())
{
return false; return false;
}
for (auto i = 0; i < shape.size(); i++) for (auto i = 0; i < shape.size(); i++)
{ {
...@@ -105,7 +111,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const ...@@ -105,7 +111,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]); NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]);
if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i]) if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
{
continue; continue;
}
return false; return false;
} }
return true; return true;
......
...@@ -104,13 +104,19 @@ namespace ...@@ -104,13 +104,19 @@ namespace
// Convert the original function arguments. // Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs()); TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
{
if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
{
return matchFailure(); return matchFailure();
}
}
// Convert the original function results. // Convert the original function results.
SmallVector<Type, 4> convertedResults; SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults))) if (failed(converter.convertTypes(type.getResults(), convertedResults)))
{
return matchFailure(); return matchFailure();
}
// Add result types as input args without mapping // Add result types as input args without mapping
result.addInputs(convertedResults); result.addInputs(convertedResults);
...@@ -139,16 +145,16 @@ namespace ...@@ -139,16 +145,16 @@ namespace
DialectLoweringPass& pass); DialectLoweringPass& pass);
template <typename OP> template <typename OP>
void lower_binary_elementwise(Operation* op, void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass); DialectLoweringPass& pass);
template <typename OP> template <typename OP>
void lower_unary_elementwise(Operation* op, void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass); DialectLoweringPass& pass);
ValueHandle createZeroConstant(mlir::Type type); ValueHandle createZeroConstant(mlir::Type type);
...@@ -376,49 +382,49 @@ namespace ...@@ -376,49 +382,49 @@ namespace
REWRITER(NGAddOp) REWRITER(NGAddOp)
{ {
lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGSubOp) REWRITER(NGSubOp)
{ {
lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGMulOp) REWRITER(NGMulOp)
{ {
lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGDivOp) REWRITER(NGDivOp)
{ {
lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGGreaterOp) REWRITER(NGGreaterOp)
{ {
lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGLessOp) REWRITER(NGLessOp)
{ {
lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGMaxOp) REWRITER(NGMaxOp)
{ {
lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGMinOp) REWRITER(NGMinOp)
{ {
lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, pass); lowerBinaryElementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
...@@ -477,7 +483,7 @@ namespace ...@@ -477,7 +483,7 @@ namespace
// Negative // Negative
REWRITER(NGNegOp) REWRITER(NGNegOp)
{ {
lower_unary_elementwise<mlir::NGNegOp>(op, operands, rewriter, pass); lowerUnaryElementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
...@@ -950,10 +956,10 @@ namespace ...@@ -950,10 +956,10 @@ namespace
#undef REWRITER #undef REWRITER
/// End of pattern matchers /// End of pattern matchers
template <typename OP> template <typename OP>
void lower_unary_elementwise(Operation* op, void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass) DialectLoweringPass& pass)
{ {
auto loc = cast<OP>(op).getLoc(); auto loc = cast<OP>(op).getLoc();
...@@ -999,10 +1005,10 @@ namespace ...@@ -999,10 +1005,10 @@ namespace
} }
template <typename OP> template <typename OP>
void lower_binary_elementwise(Operation* op, void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass) DialectLoweringPass& pass)
{ {
auto loc = cast<OP>(op).getLoc(); auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0]; auto result = pass.buildOutputDefs(op, rewriter)[0];
...@@ -1138,7 +1144,9 @@ namespace ...@@ -1138,7 +1144,9 @@ namespace
for (auto i = 0; i < vArg.rank(); i++) for (auto i = 0; i < vArg.rank(); i++)
{ {
if (i != axis) if (i != axis)
{
nonRedIVs.push_back(allIVs[i]); nonRedIVs.push_back(allIVs[i]);
}
} }
// Load current min index with integer data type and convert it to index data type. // Load current min index with integer data type and convert it to index data type.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment