Commit a095c587 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Fix naming convention in MLIR files (#3292)

* [MLIR] Fix naming convention in MLIR files

Add naming convention note per file to state which files should use
nGraph naming convention and which MLIR naming convention and align
naming convention in those files with such a note.

* Remove m-prefix
parent ddc261b9
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
#include "compiler.hpp" #include "compiler.hpp"
#include "dialect/dialect.hpp" #include "dialect/dialect.hpp"
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
#pragma once #pragma once
#include "memory_manager.hpp" #include "memory_manager.hpp"
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "dialect.hpp" #include "dialect.hpp"
#include "ngraph/check.hpp" #include "ngraph/check.hpp"
#include "ops.hpp" #include "ops.hpp"
...@@ -41,12 +44,12 @@ void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const ...@@ -41,12 +44,12 @@ void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
case NG_TENSOR_TYPE_ID: case NG_TENSOR_TYPE_ID:
{ {
os << "tensor<"; os << "tensor<";
auto tensor_ty = type.cast<NGTensorType>(); auto tensorTy = type.cast<NGTensorType>();
for (auto dim : tensor_ty.getShape()) for (auto dim : tensorTy.getShape())
{ {
os << dim << 'x'; os << dim << 'x';
} }
os << tensor_ty.getElementType() << '>'; os << tensorTy.getElementType() << '>';
return; return;
} }
case NG_I8_TYPE_ID: case NG_I8_TYPE_ID:
...@@ -58,8 +61,8 @@ void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const ...@@ -58,8 +61,8 @@ void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
case NG_U32_TYPE_ID: case NG_U32_TYPE_ID:
case NG_U64_TYPE_ID: case NG_U64_TYPE_ID:
{ {
auto int_ty = type.cast<NGIntegerType>(); auto intTy = type.cast<NGIntegerType>();
os << "i" << int_ty.getWidth(); os << "i" << intTy.getWidth();
return; return;
} }
case NG_BOOL_TYPE_ID: case NG_BOOL_TYPE_ID:
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once #pragma once
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
...@@ -23,6 +27,7 @@ ...@@ -23,6 +27,7 @@
#include "mlir/IR/TypeSupport.h" #include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "ngraph/check.hpp" #include "ngraph/check.hpp"
namespace mlir namespace mlir
{ {
class NGraphOpsDialect : public mlir::Dialect class NGraphOpsDialect : public mlir::Dialect
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "ops.hpp" #include "ops.hpp"
#include "assertion.hpp" #include "assertion.hpp"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once #pragma once
#include <cstdarg> #include <cstdarg>
......
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
// nGraph Dialect operations definitions // nGraph Dialect operations definitions
......
...@@ -12,6 +12,10 @@ ...@@ -12,6 +12,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "type.hpp" #include "type.hpp"
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once #pragma once
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
...@@ -198,19 +202,19 @@ namespace mlir ...@@ -198,19 +202,19 @@ namespace mlir
return new (storage) NGTensorTypeStorage(eltType, shape); return new (storage) NGTensorTypeStorage(eltType, shape);
} }
Shape getShape() const { return m_shape; } Shape getShape() const { return shape; }
int64_t getRank() const { return m_shape.size(); } int64_t getRank() const { return shape.size(); }
EltType getElementType() const { return m_eltType; } EltType getElementType() const { return eltType; }
private: private:
NGTensorTypeStorage(EltType eltType, Shape shape) NGTensorTypeStorage(EltType eltType, Shape shape)
: m_eltType(eltType) : eltType(eltType)
, m_shape(shape) , shape(shape)
{ {
} }
private: private:
EltType m_eltType; EltType eltType;
Shape m_shape; Shape shape;
}; };
/// NGraph Tensor Type /// NGraph Tensor Type
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "lowerer.hpp" #include "lowerer.hpp"
#include "compiler.hpp" #include "compiler.hpp"
...@@ -51,12 +54,12 @@ namespace ...@@ -51,12 +54,12 @@ namespace
public: public:
NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass) NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
: ConversionPattern(rootOpName, /*benefit=*/1, context) : ConversionPattern(rootOpName, /*benefit=*/1, context)
, m_pass(pass){}; , pass(pass){};
protected: protected:
// Back-reference to the lowering pass which contains the lowering state, including the // Back-reference to the lowering pass which contains the lowering state, including the
// nGraph type converter. // nGraph type converter.
DialectLoweringPass& m_pass; DialectLoweringPass& pass;
}; };
// Conversion classes declarations // Conversion classes declarations
...@@ -81,13 +84,13 @@ namespace ...@@ -81,13 +84,13 @@ namespace
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& m_pass); DialectLoweringPass& pass);
template <typename OP> template <typename OP>
void lower_binary_elementwise(Operation* op, void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& m_pass); DialectLoweringPass& pass);
/// Conversion from types in the nGraph dialect to the Standard dialect. /// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter class NGraphTypeConverter : public TypeConverter
...@@ -106,7 +109,7 @@ namespace ...@@ -106,7 +109,7 @@ namespace
{ {
public: public:
DialectLoweringPass(ngmlir::MLIRCompiler& compiler) DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
: m_compiler(compiler) : compiler(compiler)
{ {
} }
...@@ -129,13 +132,13 @@ namespace ...@@ -129,13 +132,13 @@ namespace
Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr); Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
private: private:
NGraphTypeConverter m_typeConverter; NGraphTypeConverter typeConverter;
// Value holding mem manager passed pointer // Value holding mem manager passed pointer
SmallVector<Value*, 4> m_memMgrDefs; SmallVector<Value*, 4> memMgrDefs;
// list of results values to add to func signature // list of results values to add to func signature
SmallVector<Value*, 4> m_loweredOutputValues; SmallVector<Value*, 4> loweredOutputValues;
ngmlir::MLIRCompiler& m_compiler; ngmlir::MLIRCompiler& compiler;
}; };
void DialectLoweringPass::runOnModule() void DialectLoweringPass::runOnModule()
...@@ -203,7 +206,7 @@ namespace ...@@ -203,7 +206,7 @@ namespace
// TODO: This resize is making debugging obscure. When the container is not populated due // TODO: This resize is making debugging obscure. When the container is not populated due
// to a bug, null pointers are used by the consumer leading to a crash more difficult to // to a bug, null pointers are used by the consumer leading to a crash more difficult to
// root-cause. We should try to change the current approach or introduce verification code. // root-cause. We should try to change the current approach or introduce verification code.
m_loweredOutputValues.resize(outputCount, nullptr); loweredOutputValues.resize(outputCount, nullptr);
} }
/// Inserts a fake def for Mem Mgr pointer at converted func start /// Inserts a fake def for Mem Mgr pointer at converted func start
...@@ -217,7 +220,7 @@ namespace ...@@ -217,7 +220,7 @@ namespace
auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(), auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
IndexType::get(&getContext())); IndexType::get(&getContext()));
// will be fixed later to read passed arg instead. // will be fixed later to read passed arg instead.
m_memMgrDefs.push_back(op.getResult()); memMgrDefs.push_back(op.getResult());
return op.getResult(); return op.getResult();
} }
...@@ -233,19 +236,18 @@ namespace ...@@ -233,19 +236,18 @@ namespace
unsigned argId = (int)attr.getInt(); unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>( auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(), op->getLoc(),
m_typeConverter.convertType(origResult->getType()) /* convert to lowered type */ typeConverter.convertType(origResult->getType()) /* convert to lowered type */
); );
// Fake instrution is short-lived. Verify here. // Fake instrution is short-lived. Verify here.
fakeOp.verify(); fakeOp.verify();
auto newResult = fakeOp.getResult(); auto newResult = fakeOp.getResult();
newResults.push_back(newResult); newResults.push_back(newResult);
m_loweredOutputValues[argId] = newResult; loweredOutputValues[argId] = newResult;
} }
else else
{ {
auto tensorType = origResult->getType().cast<NGTensorType>(); auto tensorType = origResult->getType().cast<NGTensorType>();
auto newResult = auto newResult = createTempTensor(typeConverter.convertType(tensorType), rewriter);
createTempTensor(m_typeConverter.convertType(tensorType), rewriter);
newResults.push_back(newResult); newResults.push_back(newResult);
} }
} }
...@@ -302,7 +304,7 @@ namespace ...@@ -302,7 +304,7 @@ namespace
// RAUW fake outputs with result values // RAUW fake outputs with result values
unsigned i = 0; unsigned i = 0;
for (auto value : m_loweredOutputValues) for (auto value : loweredOutputValues)
{ {
auto op = value->getDefiningOp(); auto op = value->getDefiningOp();
NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?"); NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
...@@ -310,9 +312,9 @@ namespace ...@@ -310,9 +312,9 @@ namespace
op->erase(); op->erase();
i++; i++;
} }
for (auto v : m_memMgrDefs) for (auto v : memMgrDefs)
{ {
v->replaceAllUsesWith(entryBlock->getArgument(m_compiler.get_mem_mgr_arg_id(f))); v->replaceAllUsesWith(entryBlock->getArgument(compiler.get_mem_mgr_arg_id(f)));
v->getDefiningOp()->erase(); v->getDefiningOp()->erase();
} }
} }
...@@ -357,27 +359,27 @@ namespace ...@@ -357,27 +359,27 @@ namespace
// We may need to refactor this code to a external utility if type conversion is needed // We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since NGraphTypeConverter is private. // outside of the lowering context since NGraphTypeConverter is private.
if (auto tensor_type = type.dyn_cast<NGTensorType>()) if (auto tensorType = type.dyn_cast<NGTensorType>())
{ {
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType. // Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future. // This may change in the future.
return MemRefType::get(tensor_type.getShape(), return MemRefType::get(tensorType.getShape(),
convertType(tensor_type.getElementType()), convertType(tensorType.getElementType()),
{/* no map used */}, {/* no map used */},
0); 0);
} }
if (auto float_type = type.dyn_cast<NGFloatType>()) if (auto floatType = type.dyn_cast<NGFloatType>())
{ {
// Float types are already std type. // Float types are already std type.
return float_type; return floatType;
} }
if (auto int_type = type.dyn_cast<NGIntegerType>()) if (auto intType = type.dyn_cast<NGIntegerType>())
{ {
return mlir::IntegerType::get(int_type.getWidth(), int_type.getContext()); return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
} }
if (auto bool_type = type.dyn_cast<NGBoolType>()) if (auto boolType = type.dyn_cast<NGBoolType>())
{ {
return mlir::IntegerType::get(1 /* width */, bool_type.getContext()); return mlir::IntegerType::get(1 /* width */, boolType.getContext());
} }
NGRAPH_CHECK(false, "Unsupported type to lower"); NGRAPH_CHECK(false, "Unsupported type to lower");
...@@ -391,61 +393,61 @@ namespace ...@@ -391,61 +393,61 @@ namespace
// ADD // ADD
REWRITER(NGAddOp) REWRITER(NGAddOp)
{ {
lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGSubOp) REWRITER(NGSubOp)
{ {
lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGMulOp) REWRITER(NGMulOp)
{ {
lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGDivOp) REWRITER(NGDivOp)
{ {
lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGGreaterOp) REWRITER(NGGreaterOp)
{ {
lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGLessOp) REWRITER(NGLessOp)
{ {
lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGMaxOp) REWRITER(NGMaxOp)
{ {
lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGMinOp) REWRITER(NGMinOp)
{ {
lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, m_pass); lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGArgMaxRedOp) REWRITER(NGArgMaxRedOp)
{ {
lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass); lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGArgMinRedOp) REWRITER(NGArgMinRedOp)
{ {
lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass); lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, pass);
return matchSuccess(); return matchSuccess();
} }
...@@ -454,7 +456,7 @@ namespace ...@@ -454,7 +456,7 @@ namespace
{ {
auto loc = cast<NGReluOp>(op).getLoc(); auto loc = cast<NGReluOp>(op).getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0]; auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>()); NGRAPH_CHECK(result->getType().isa<MemRefType>());
// Note that builder's current function is still the original function body. // Note that builder's current function is still the original function body.
// use getBlock to get the new block instead. // use getBlock to get the new block instead.
...@@ -513,18 +515,18 @@ namespace ...@@ -513,18 +515,18 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* lhs = operands[0]; Value* lhs = operands[0];
Value* rhs = operands[1]; Value* rhs = operands[1];
Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp"); NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");
auto result_ty = result->getType().dyn_cast<MemRefType>(); auto resultTy = result->getType().dyn_cast<MemRefType>();
auto lhs_ty = lhs->getType().dyn_cast<MemRefType>(); auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
auto rhs_ty = rhs->getType().dyn_cast<MemRefType>(); auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
NGRAPH_CHECK(result_ty, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhs_ty, "Unexpected non-memref LHS type"); NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
NGRAPH_CHECK(rhs_ty, "Unexpected non-memref RHS type"); NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type");
Type elem_ty = result_ty.getElementType(); Type elemTy = resultTy.getElementType();
NGRAPH_CHECK(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType(), NGRAPH_CHECK(elemTy == lhsTy.getElementType() && elemTy == rhsTy.getElementType(),
"Types mismatch in DotOp"); "Types mismatch in DotOp");
// Create the following loop nest for matmul operation: // Create the following loop nest for matmul operation:
...@@ -534,9 +536,9 @@ namespace ...@@ -534,9 +536,9 @@ namespace
// res[n, k] += lhs[n, m] * rhs[m, k] // res[n, k] += lhs[n, m] * rhs[m, k]
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout. // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs); MemRefView vRes(result), vLhs(lhs), vRhs(rhs);
NGRAPH_CHECK(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2, NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
"Dot operation is only supported for 2D tensors"); "Dot operation is only supported for 2D tensors");
// Create induction variables, lower bounds, upper bounds and steps of the loop nest. // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
...@@ -544,22 +546,21 @@ namespace ...@@ -544,22 +546,21 @@ namespace
// i.e., fastest varying dimension is the last one, slowest varying dimention is the first // i.e., fastest varying dimension is the last one, slowest varying dimention is the first
// one. // one.
IndexHandle n, m, k; IndexHandle n, m, k;
unsigned n_dim = v_lhs.fastestVarying() - 1; unsigned nDim = vLhs.fastestVarying() - 1;
unsigned m_dim = v_rhs.fastestVarying(); unsigned mDim = vRhs.fastestVarying();
unsigned k_dim = v_rhs.fastestVarying(); unsigned kDim = vRhs.fastestVarying();
IndexHandle n_lb(v_lhs.lb(n_dim)), m_lb(v_lhs.lb(m_dim)), k_lb(v_rhs.lb(k_dim)); IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim)); IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim); int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);
// Constants and indexed values to be used inside the loop nest. // Constants and indexed values to be used inside the loop nest.
IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs); IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty))); ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));
LoopBuilder(&n, n_lb, n_ub, n_step)([&] { LoopBuilder(&n, nLb, nUb, nStep)([&] {
LoopBuilder(&k, k_lb, k_ub, k_step)([&] { LoopBuilder(&k, kLb, kUb, kStep)([&] {
i_res(n, k) = zero_init; iRes(n, k) = zeroInit;
LoopBuilder(&m, m_lb, m_ub, m_step)( LoopBuilder(&m, mLb, mUb, mStep)([&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
[&] { i_res(n, k) += i_lhs(n, m) * i_rhs(m, k); });
}); });
}); });
...@@ -575,7 +576,7 @@ namespace ...@@ -575,7 +576,7 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Create Value for result, and extract type info. // Create Value for result, and extract type info.
Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp"); NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
// Create view to write into result. // Create view to write into result.
...@@ -653,9 +654,8 @@ namespace ...@@ -653,9 +654,8 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Get operands // Get operands
Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp"); NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
auto resultTy = result->getType().cast<MemRefType>();
Value* params = operands[0]; Value* params = operands[0];
Value* indices = operands[1]; Value* indices = operands[1];
...@@ -775,10 +775,10 @@ namespace ...@@ -775,10 +775,10 @@ namespace
void lower_binary_elementwise(Operation* op, void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& m_pass) DialectLoweringPass& pass)
{ {
auto loc = cast<OP>(op).getLoc(); auto loc = cast<OP>(op).getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0]; auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>()); NGRAPH_CHECK(result->getType().isa<MemRefType>());
// get new operands // get new operands
Value* lhs = operands[0]; Value* lhs = operands[0];
...@@ -850,7 +850,7 @@ namespace ...@@ -850,7 +850,7 @@ namespace
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& m_pass) DialectLoweringPass& pass)
{ {
static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(), static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
"Template parameter is not supported by lowerIndexReduction"); "Template parameter is not supported by lowerIndexReduction");
...@@ -870,7 +870,7 @@ namespace ...@@ -870,7 +870,7 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* arg = operands[0]; Value* arg = operands[0];
Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; Value* result = pass.buildOutputDefs(op, rewriter)[0];
// Views // Views
MemRefView vRes(result), vArg(arg); MemRefView vRes(result), vArg(arg);
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once #pragma once
#include "contrib/mlir/compiler.hpp" #include "contrib/mlir/compiler.hpp"
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "memory_manager.hpp" #include "memory_manager.hpp"
#include <memory> #include <memory>
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
...@@ -21,9 +24,9 @@ ...@@ -21,9 +24,9 @@
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
/// Call back to allocate memory for temps from JIT'ed code /// Call back to allocate memory for temps from JIT'ed code
extern "C" NGRAPH_API void* __mlir_allocate(MLIRMemMgr* mem_mgr, size_t size) extern "C" NGRAPH_API void* __mlir_allocate(MLIRMemMgr* memMgr, size_t size)
{ {
return mem_mgr->allocate(size); return memMgr->allocate(size);
} }
void* MLIRMemMgr::allocate(size_t size) void* MLIRMemMgr::allocate(size_t size)
......
...@@ -14,11 +14,15 @@ ...@@ -14,11 +14,15 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include <stdlib.h> #include <stdlib.h>
#include <vector> #include <vector>
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
#include "mlir_subgraph_extraction.hpp" #include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
#pragma once #pragma once
#include <mutex> #include <mutex>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment