Commit 261a1161 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] Improve Ops verification for unary and binary ops (#38)

* Improve verification for unary, binary, cmp and select ops. Refactor .td file to enable that.
parent 4504a5e2
......@@ -34,20 +34,50 @@ using namespace mlir;
// to Ops classes, we will add helper classes with static methods for each Op that needs it
// Additional verification methods
// Tensor type checks are already verified by the caller of these methods
/// Checks if all operands and results are of compatible shapes
template <typename T>
static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkResult = true)
{
    // TODO: Check matching element types
    // Operand 0 is the reference type; all other operands (and optionally all
    // results) must be compatible with it.
    mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
    mlir::NGTensorType opType0 = t0.cast<NGTensorType>();

    Operation* opr = op->getOperation();
    auto i = 0;
    for (auto operand : opr->getOperands())
    {
        // Skip the reference operand itself. The increment must happen even on
        // the skip path, otherwise `i` stays 0 and no operand is ever checked.
        if (i++ == 0)
            continue;
        mlir::Type t = operand->getType();
        mlir::NGTensorType opType = t.cast<NGTensorType>();
        if (!opType.isCompatible(opType0))
            return op->emitOpError("Incompatible operand shape");
    }

    if (checkResult)
    {
        for (auto result : opr->getResults())
        {
            mlir::Type t = result->getType();
            mlir::NGTensorType resType = t.cast<NGTensorType>();
            if (!resType.isCompatible(opType0))
                return op->emitOpError("Incompatible result shape");
        }
    }
    return mlir::success();
}
// Additional verification methods
// Tensor type checks are already verified by the caller of these methods
template <typename T>
static mlir::LogicalResult verifyUnaryArithOp(T* op)
{
    // A unary arithmetic op requires its single operand and its result to be
    // compatible tensors; delegate to the shared operand/result check.
    auto status = verifyCompatibleOperandsAndResults(op);
    return status;
}
template <typename T>
static mlir::LogicalResult verifyBinaryArithOp(T* op)
{
    // Both operands and the result must have compatible tensor types.
    // (A leftover `return mlir::success();` before this call made the actual
    // verification unreachable.)
    return verifyCompatibleOperandsAndResults(op);
}
template <typename T>
......@@ -59,13 +89,54 @@ static mlir::LogicalResult verifyOp(T* op)
template <>
mlir::LogicalResult verifyOp(NGDotOp* op)
{
    // Run the generic binary-op verification first; previously its result was
    // computed but silently discarded, so dot ops were never rejected.
    mlir::LogicalResult result = verifyBinaryArithOp(op);
    if (failed(result))
        return result;
    // TODO(dcab): Improve verification: proper shapes, etc.
    return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGSelectOp* op)
{
    // select(pred, lhs, rhs): pred picks element-wise between lhs and rhs.
    Operation* operation = op->getOperation();
    NGTensorType predType = operation->getOperand(0)->getType().cast<NGTensorType>();
    NGTensorType lhsType = operation->getOperand(1)->getType().cast<NGTensorType>();
    NGTensorType rhsType = operation->getOperand(2)->getType().cast<NGTensorType>();
    NGTensorType resultType = operation->getResult(0)->getType().cast<NGTensorType>();

    // lhs (arg1) and rhs (arg2) must agree in shape and element type.
    if (!lhsType.isCompatible(rhsType))
        return op->emitOpError("Incompatible operand shapes or types for select op");

    // The predicate (arg0) must match that shape and carry boolean elements.
    if (!predType.isCompatibleShape(lhsType) || !predType.getElementType().isa<NGBoolType>())
        return op->emitOpError("Incompatible shape for arg0 of select op");

    // The result mirrors arg1/arg2 in shape and element type.
    if (!resultType.isCompatible(lhsType))
        return op->emitOpError("Incompatible result shape or type for select op");

    return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifyCmpOp(T* op)
{
    // Operands must be mutually compatible. The result is excluded here
    // (checkResult = false) because a comparison yields a boolean tensor,
    // not a tensor of the operand element type.
    mlir::LogicalResult compat = verifyCompatibleOperandsAndResults(op, false /*checkResult*/);
    if (failed(compat))
    {
        return compat;
    }

    Operation* operation = op->getOperation();
    NGTensorType operandType = operation->getOperand(0)->getType().cast<NGTensorType>();
    NGTensorType resultType = operation->getResult(0)->getType().cast<NGTensorType>();

    // The result must have the operands' shape and a boolean element type.
    if (!resultType.isCompatibleShape(operandType) || !resultType.getElementType().isa<NGBoolType>())
        return op->emitOpError("Incompatible result shape or type for comparison op");

    return mlir::success();
}
......
......@@ -72,10 +72,9 @@ class NG_MemRefDef_Op<string mnemonic, list<OpTrait> traits = []> :
// Base class for ops that produce no results.
class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs)> {}
// Arithmetic binary operations
// Input and outputs have same type
// Base class for arithmetic unary operations without side effects.
class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect, SameValueType], traits)>,
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$arg)>
{
// TODO: Implement
......@@ -85,9 +84,18 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
}
// Base class for binary operations without side effects.
// (The unmerged diff left the old NG_Binary_Arith_Op header line above the
// new NG_Binary_Op header for the same body, which is invalid TableGen.)
class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
    NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
    Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
  // TODO: Implement
  let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
}
// Base class for arithmetic binary operations with verifier.
// (A stray diff hunk marker was embedded mid-class; removed so the record
// parses as a single definition.)
class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
    NG_OneResult_Op<mnemonic, traits>,
    Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
  // TODO: Implement
  let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
  // Operands and result must be mutually compatible tensors.
  let verifier = [{ return verifyBinaryArithOp(this); }];
}
// Base class for comparison operations with verifier.
// Operands must be mutually compatible; the result matches the operand shape
// but has a boolean element type (enforced by verifyCmpOp).
class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, traits>,
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
let verifier = [{ return verifyCmpOp(this); }];
}
// Base class for ternary operations without side effects.
// No default verifier here; derived ops attach their own (e.g. NGSelectOp
// uses verifyOp).
class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
}
// Base class for terminator operations.
class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, !listconcat(traits, [Terminator])>,
......@@ -122,28 +151,31 @@ def NGTanhOp : NG_Unary_Arith_Op<"tanh">;
def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">;
// Binary Operations
// (The unmerged diff left both the old SameValueType-based defs and the new
// defs in the text, defining every op twice; only the new defs are kept.)
def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>;
def NGAndOp : NG_Binary_Arith_Op<"and", [Commutative]>;
def NGSubOp : NG_Binary_Arith_Op<"sub">;
def NGDivOp : NG_Binary_Arith_Op<"div">;
def NGMaxOp : NG_Binary_Arith_Op<"max", [Commutative]>;
def NGMinOp : NG_Binary_Arith_Op<"min", [Commutative]>;
def NGMulOp : NG_Binary_Arith_Op<"mul", [Commutative]>;
def NGPowOp : NG_Binary_Arith_Op<"pow">;
// Comparison
// (Duplicate pre-diff NG_OneResult_Op defs removed; the NG_Cmp_Op variants
// carry the comparison verifier.)
def NGEqOp : NG_Cmp_Op<"equal">;
def NGGreaterOp : NG_Cmp_Op<"greater">;
def NGGreaterEqOp : NG_Cmp_Op<"greater.eq">;
def NGLessOp : NG_Cmp_Op<"less">;
def NGLessEqOp : NG_Cmp_Op<"less.eq">;
def NGNotEqOp : NG_Cmp_Op<"not.equal">;
// Other
// (Duplicate pre-diff NG_OneResult_Op def removed; select is a ternary op
// with a dedicated verifier.)
def NGSelectOp : NG_Ternary_Op<"select">
{
  let verifier = [{ return verifyOp(this); }];
}
// Matrix Multiply
def NGDotOp : NG_Binary_Arith_Op<"dot">
// Dot Product
def NGDotOp : NG_Binary_Op<"dot">
{
// TODO: Add reduction axis attribute when needed.
let verifier = [{ return verifyOp(this); }];
......
......@@ -73,3 +73,36 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap
{
return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
}
bool NGTensorType::isCompatible(NGTensorType& other) const
{
    // A tensor is trivially compatible with itself.
    if (this == &other)
        return true;
    // Otherwise require identical element types and compatible shapes.
    // TODO: Handle dynamic ranks
    // MLIR MemRefType doesn't seem to support it at the moment.
    return getElementType() == other.getElementType() && isCompatibleShape(other);
}
bool NGTensorType::isCompatibleShape(NGTensorType& other) const
{
    auto shape = getShape();
    auto otherShape = other.getShape();
    auto rank = shape.size();

    // Ranks must match exactly; dynamic rank is not modeled yet.
    if (rank != otherShape.size())
        return false;

    // Use the (unsigned) size type for the index; `auto i = 0` deduced int and
    // triggered a signed/unsigned comparison against size().
    for (decltype(rank) i = 0; i < rank; i++)
    {
        NGRAPH_ASSERT(shape[i] >= -1) << "Invalid tensor shape";
        NGRAPH_ASSERT(otherShape[i] >= -1) << "Invalid tensor shape";
        // A dimension of -1 is dynamic and matches any extent.
        if (shape[i] != -1 && otherShape[i] != -1 && shape[i] != otherShape[i])
            return false;
    }
    return true;
}
......@@ -235,8 +235,18 @@ namespace mlir
// Multiply times element size
return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
}
/// Checks if two tensors are compatible. Compatible means:
/// Exactly same element types
/// Compatible shapes: see isCompatibleShape.
bool isCompatible(NGTensorType& other) const;
/// Checks whether the shapes have the same rank and matching dimensions, treating a dynamic dimension as matching anything.
bool isCompatibleShape(NGTensorType& other) const;
/// Creates a unique tensor type based on element type and shape.
static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
/// for llvm RTTI
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_TENSOR_TYPE_ID; }
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment