Commit a095c587 authored by Diego Caballero, committed by Scott Cyphers

[MLIR] Fix naming convention in MLIR files (#3292)

* [MLIR] Fix naming convention in MLIR files

Add a naming convention note to each file stating whether it should follow
the nGraph or the MLIR naming convention, and align the naming in those
files with that note.

* Remove m-prefix
parent ddc261b9
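In practice the rename amounts to dropping nGraph's m_ member prefix and turning snake_case locals into camelCase in the MLIR-internal files, while files that remain part of the public nGraph API keep the nGraph convention (see the per-file NOTE comments below). A minimal before/after sketch, condensed from the type.hpp hunk in this diff:

// Before: nGraph naming convention (m_-prefixed members)
class NGTensorTypeStorage
{
private:
    EltType m_eltType;
    Shape m_shape;
};

// After: MLIR naming convention (camelCase, no m_ prefix)
class NGTensorTypeStorage
{
private:
    EltType eltType;
    Shape shape;
};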
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and naming convention since it
+// exposes a public API to the rest of nGraph codebase.
#include "compiler.hpp"
#include "dialect/dialect.hpp"
......
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and naming convention since it
+// exposes a public API to the rest of nGraph codebase.
#pragma once
#include "memory_manager.hpp"
......
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "dialect.hpp"
#include "ngraph/check.hpp"
#include "ops.hpp"
......@@ -41,12 +44,12 @@ void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
case NG_TENSOR_TYPE_ID:
{
os << "tensor<";
-auto tensor_ty = type.cast<NGTensorType>();
-for (auto dim : tensor_ty.getShape())
+auto tensorTy = type.cast<NGTensorType>();
+for (auto dim : tensorTy.getShape())
{
os << dim << 'x';
}
-os << tensor_ty.getElementType() << '>';
+os << tensorTy.getElementType() << '>';
return;
}
case NG_I8_TYPE_ID:
......@@ -58,8 +61,8 @@ void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
case NG_U32_TYPE_ID:
case NG_U64_TYPE_ID:
{
-auto int_ty = type.cast<NGIntegerType>();
-os << "i" << int_ty.getWidth();
+auto intTy = type.cast<NGIntegerType>();
+os << "i" << intTy.getWidth();
return;
}
case NG_BOOL_TYPE_ID:
......
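For context, printType above renders nGraph types as text built from the shape and the element bit width. Assuming the element type is printed through the same hook, the output looks roughly like the sketch below (illustrative, not captured compiler output):

// NGTensorType of shape {2, 3} with an 8-bit integer element  ->  "tensor<2x3xi8>"
// NGIntegerType of width 32                                   ->  "i32"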
......@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once
#include "mlir/IR/Dialect.h"
......@@ -23,6 +27,7 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "ngraph/check.hpp"
namespace mlir
{
class NGraphOpsDialect : public mlir::Dialect
......
......@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "ops.hpp"
#include "assertion.hpp"
#include "llvm/Support/ErrorHandling.h"
......
......@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once
#include <cstdarg>
......
......@@ -18,6 +18,9 @@
//
//===----------------------------------------------------------------------===//
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
include "mlir/IR/OpBase.td"
// nGraph Dialect operations definitions
......
......@@ -12,6 +12,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "type.hpp"
......
......@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once
#include "mlir/IR/Dialect.h"
......@@ -198,19 +202,19 @@ namespace mlir
return new (storage) NGTensorTypeStorage(eltType, shape);
}
-Shape getShape() const { return m_shape; }
-int64_t getRank() const { return m_shape.size(); }
-EltType getElementType() const { return m_eltType; }
+Shape getShape() const { return shape; }
+int64_t getRank() const { return shape.size(); }
+EltType getElementType() const { return eltType; }
private:
NGTensorTypeStorage(EltType eltType, Shape shape)
-: m_eltType(eltType)
-, m_shape(shape)
+: eltType(eltType)
+, shape(shape)
{
}
private:
-EltType m_eltType;
-Shape m_shape;
+EltType eltType;
+Shape shape;
};
/// NGraph Tensor Type
......
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "lowerer.hpp"
#include "compiler.hpp"
......@@ -51,12 +54,12 @@ namespace
public:
NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
: ConversionPattern(rootOpName, /*benefit=*/1, context)
-, m_pass(pass){};
+, pass(pass){};
protected:
// Back-reference to the lowering pass which contains the lowering state, including the
// nGraph type converter.
-DialectLoweringPass& m_pass;
+DialectLoweringPass& pass;
};
// Conversion classes declarations
......@@ -81,13 +84,13 @@ namespace
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
-DialectLoweringPass& m_pass);
+DialectLoweringPass& pass);
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
-DialectLoweringPass& m_pass);
+DialectLoweringPass& pass);
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
......@@ -106,7 +109,7 @@ namespace
{
public:
DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
-: m_compiler(compiler)
+: compiler(compiler)
{
}
......@@ -129,13 +132,13 @@ namespace
Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
private:
-NGraphTypeConverter m_typeConverter;
+NGraphTypeConverter typeConverter;
// Value holding mem manager passed pointer
-SmallVector<Value*, 4> m_memMgrDefs;
+SmallVector<Value*, 4> memMgrDefs;
// list of results values to add to func signature
-SmallVector<Value*, 4> m_loweredOutputValues;
-ngmlir::MLIRCompiler& m_compiler;
+SmallVector<Value*, 4> loweredOutputValues;
+ngmlir::MLIRCompiler& compiler;
};
void DialectLoweringPass::runOnModule()
......@@ -203,7 +206,7 @@ namespace
// TODO: This resize is making debugging obscure. When the container is not populated due
// to a bug, null pointers are used by the consumer leading to a crash more difficult to
// root-cause. We should try to change the current approach or introduce verification code.
-m_loweredOutputValues.resize(outputCount, nullptr);
+loweredOutputValues.resize(outputCount, nullptr);
}
/// Inserts a fake def for Mem Mgr pointer at converted func start
......@@ -217,7 +220,7 @@ namespace
auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
IndexType::get(&getContext()));
// will be fixed later to read passed arg instead.
-m_memMgrDefs.push_back(op.getResult());
+memMgrDefs.push_back(op.getResult());
return op.getResult();
}
......@@ -233,19 +236,18 @@ namespace
unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(),
-m_typeConverter.convertType(origResult->getType()) /* convert to lowered type */
+typeConverter.convertType(origResult->getType()) /* convert to lowered type */
);
// Fake instrution is short-lived. Verify here.
fakeOp.verify();
auto newResult = fakeOp.getResult();
newResults.push_back(newResult);
-m_loweredOutputValues[argId] = newResult;
+loweredOutputValues[argId] = newResult;
}
else
{
auto tensorType = origResult->getType().cast<NGTensorType>();
-auto newResult =
-createTempTensor(m_typeConverter.convertType(tensorType), rewriter);
+auto newResult = createTempTensor(typeConverter.convertType(tensorType), rewriter);
newResults.push_back(newResult);
}
}
......@@ -302,7 +304,7 @@ namespace
// RAUW fake outputs with result values
unsigned i = 0;
-for (auto value : m_loweredOutputValues)
+for (auto value : loweredOutputValues)
{
auto op = value->getDefiningOp();
NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
......@@ -310,9 +312,9 @@ namespace
op->erase();
i++;
}
-for (auto v : m_memMgrDefs)
+for (auto v : memMgrDefs)
{
-v->replaceAllUsesWith(entryBlock->getArgument(m_compiler.get_mem_mgr_arg_id(f)));
+v->replaceAllUsesWith(entryBlock->getArgument(compiler.get_mem_mgr_arg_id(f)));
v->getDefiningOp()->erase();
}
}
......@@ -357,27 +359,27 @@ namespace
// We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since NGraphTypeConverter is private.
-if (auto tensor_type = type.dyn_cast<NGTensorType>())
+if (auto tensorType = type.dyn_cast<NGTensorType>())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
-return MemRefType::get(tensor_type.getShape(),
-convertType(tensor_type.getElementType()),
+return MemRefType::get(tensorType.getShape(),
+convertType(tensorType.getElementType()),
{/* no map used */},
0);
}
-if (auto float_type = type.dyn_cast<NGFloatType>())
+if (auto floatType = type.dyn_cast<NGFloatType>())
{
// Float types are already std type.
-return float_type;
+return floatType;
}
-if (auto int_type = type.dyn_cast<NGIntegerType>())
+if (auto intType = type.dyn_cast<NGIntegerType>())
{
-return mlir::IntegerType::get(int_type.getWidth(), int_type.getContext());
+return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
}
-if (auto bool_type = type.dyn_cast<NGBoolType>())
+if (auto boolType = type.dyn_cast<NGBoolType>())
{
-return mlir::IntegerType::get(1 /* width */, bool_type.getContext());
+return mlir::IntegerType::get(1 /* width */, boolType.getContext());
}
NGRAPH_CHECK(false, "Unsupported type to lower");
......@@ -391,61 +393,61 @@ namespace
// ADD
REWRITER(NGAddOp)
{
-lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGSubOp)
{
-lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMulOp)
{
-lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGDivOp)
{
-lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGGreaterOp)
{
-lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGLessOp)
{
-lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMaxOp)
{
-lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMinOp)
{
-lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, m_pass);
+lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGArgMaxRedOp)
{
-lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass);
+lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGArgMinRedOp)
{
-lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass);
+lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, pass);
return matchSuccess();
}
......@@ -454,7 +456,7 @@ namespace
{
auto loc = cast<NGReluOp>(op).getLoc();
-auto result = m_pass.buildOutputDefs(op, rewriter)[0];
+auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// Note that builder's current function is still the original function body.
// use getBlock to get the new block instead.
......@@ -513,18 +515,18 @@ namespace
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* rhs = operands[1];
-Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
+Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");
-auto result_ty = result->getType().dyn_cast<MemRefType>();
-auto lhs_ty = lhs->getType().dyn_cast<MemRefType>();
-auto rhs_ty = rhs->getType().dyn_cast<MemRefType>();
-NGRAPH_CHECK(result_ty, "Unexpected non-memref result type");
-NGRAPH_CHECK(lhs_ty, "Unexpected non-memref LHS type");
-NGRAPH_CHECK(rhs_ty, "Unexpected non-memref RHS type");
+auto resultTy = result->getType().dyn_cast<MemRefType>();
+auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
+auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
+NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
+NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
+NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type");
-Type elem_ty = result_ty.getElementType();
-NGRAPH_CHECK(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType(),
+Type elemTy = resultTy.getElementType();
+NGRAPH_CHECK(elemTy == lhsTy.getElementType() && elemTy == rhsTy.getElementType(),
"Types mismatch in DotOp");
// Create the following loop nest for matmul operation:
......@@ -534,9 +536,9 @@ namespace
// res[n, k] += lhs[n, m] * rhs[m, k]
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
-MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs);
+MemRefView vRes(result), vLhs(lhs), vRhs(rhs);
-NGRAPH_CHECK(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2,
+NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
"Dot operation is only supported for 2D tensors");
// Create induction variables, lower bounds, upper bounds and steps of the loop nest.
......@@ -544,22 +546,21 @@ namespace
// i.e., fastest varying dimension is the last one, slowest varying dimention is the first
// one.
IndexHandle n, m, k;
-unsigned n_dim = v_lhs.fastestVarying() - 1;
-unsigned m_dim = v_rhs.fastestVarying();
-unsigned k_dim = v_rhs.fastestVarying();
-IndexHandle n_lb(v_lhs.lb(n_dim)), m_lb(v_lhs.lb(m_dim)), k_lb(v_rhs.lb(k_dim));
-IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim));
-int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim);
+unsigned nDim = vLhs.fastestVarying() - 1;
+unsigned mDim = vRhs.fastestVarying();
+unsigned kDim = vRhs.fastestVarying();
+IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
+IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
+int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);
// Constants and indexed values to be used inside the loop nest.
-IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
-ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
-LoopBuilder(&n, n_lb, n_ub, n_step)([&] {
-LoopBuilder(&k, k_lb, k_ub, k_step)([&] {
-i_res(n, k) = zero_init;
-LoopBuilder(&m, m_lb, m_ub, m_step)(
-[&] { i_res(n, k) += i_lhs(n, m) * i_rhs(m, k); });
+IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
+ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));
+LoopBuilder(&n, nLb, nUb, nStep)([&] {
+LoopBuilder(&k, kLb, kUb, kStep)([&] {
+iRes(n, k) = zeroInit;
+LoopBuilder(&m, mLb, mUb, mStep)([&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
});
});
......@@ -575,7 +576,7 @@ namespace
ScopedContext scope(rewriter, loc);
// Create Value for result, and extract type info.
-Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
+Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
// Create view to write into result.
......@@ -653,9 +654,8 @@ namespace
ScopedContext scope(rewriter, loc);
// Get operands
-Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
+Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
auto resultTy = result->getType().cast<MemRefType>();
Value* params = operands[0];
Value* indices = operands[1];
......@@ -775,10 +775,10 @@ namespace
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
-DialectLoweringPass& m_pass)
+DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
-auto result = m_pass.buildOutputDefs(op, rewriter)[0];
+auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// get new operands
Value* lhs = operands[0];
......@@ -850,7 +850,7 @@ namespace
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
-DialectLoweringPass& m_pass)
+DialectLoweringPass& pass)
{
static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
"Template parameter is not supported by lowerIndexReduction");
......@@ -870,7 +870,7 @@ namespace
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
-Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
+Value* result = pass.buildOutputDefs(op, rewriter)[0];
// Views
MemRefView vRes(result), vArg(arg);
......
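The NGDotOp rewriter in the hunk above builds the naive loop nest described in its comment. In plain C-like form, the computation it sets up is equivalent to the sketch below; the bounds and steps are illustrative stand-ins for the MemRefView lb/ub/step values used in the real code:

// Scalar equivalent of the generated matmul loop nest:
// res[n, k] is zero-initialized, then reduced over m.
for (int64_t n = nLb; n < nUb; n += nStep)
    for (int64_t k = kLb; k < kUb; k += kStep)
    {
        res[n][k] = 0;
        for (int64_t m = mLb; m < mUb; m += mStep)
            res[n][k] += lhs[n][m] * rhs[m][k];
    }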
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once
#include "contrib/mlir/compiler.hpp"
......
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "memory_manager.hpp"
#include <memory>
#include "ngraph/ngraph_visibility.hpp"
......@@ -21,9 +24,9 @@
using namespace ngraph::runtime::ngmlir;
/// Call back to allocate memory for temps from JIT'ed code
extern "C" NGRAPH_API void* __mlir_allocate(MLIRMemMgr* mem_mgr, size_t size)
extern "C" NGRAPH_API void* __mlir_allocate(MLIRMemMgr* memMgr, size_t size)
{
-return mem_mgr->allocate(size);
+return memMgr->allocate(size);
}
void* MLIRMemMgr::allocate(size_t size)
......
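Because __mlir_allocate is exported with C linkage, JIT'ed code can resolve it by symbol name and call it to obtain temporary buffers. A hypothetical call site (the memMgr pointer and the size are illustrative stand-ins, not code from this commit):

// Hypothetical use of the allocation callback from lowered/JIT'ed code.
extern "C" void* __mlir_allocate(MLIRMemMgr* memMgr, size_t size);

void* tempBuffer = __mlir_allocate(memMgr, 4 * 1024); // e.g. a 4 KB scratch buffer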
......@@ -14,11 +14,15 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and MLIR naming convention since it does
+// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once
#include <stdint.h>
#include <stdlib.h>
#include <vector>
namespace ngraph
{
namespace runtime
......
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and naming convention since it
+// exposes a public API to the rest of nGraph codebase.
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
......
......@@ -14,6 +14,9 @@
// limitations under the License.
//*****************************************************************************
+// NOTE: This file follows nGraph format style and naming convention since it
+// exposes a public API to the rest of nGraph codebase.
#pragma once
#include <mutex>
......