Commit 261a1161 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] Improve Ops verification for unary and binary ops (#38)

* Improve verification for unary, binary, cmp and select ops. Refactor .td file to enable that.
parent 4504a5e2
...@@ -34,20 +34,50 @@ using namespace mlir; ...@@ -34,20 +34,50 @@ using namespace mlir;
// to Ops classes, we will add helper classes with static methods for each Op that needs it // to Ops classes, we will add helper classes with static methods for each Op that needs it
// Additional verification methods // Additional verification methods
// Tensor type checks are already verified by the caller of these methods // Tensor type checks are already verified by the caller of these methods
/// Checks if all operands and results are of compatible shapes
template <typename T> template <typename T>
static mlir::LogicalResult verifyUnaryArithOp(T* op) static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkResult = true)
{ {
// TODO: Check matching element types mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
mlir::NGTensorType opType0 = t0.cast<NGTensorType>();
Operation* opr = op->getOperation();
auto i = 0;
for (auto operand : opr->getOperands())
{
if (i == 0)
continue;
mlir::Type t = operand->getType();
mlir::NGTensorType opType = t.cast<NGTensorType>();
if (!opType.isCompatible(opType0))
return op->emitOpError("Incompatible operand shape");
i++;
}
if (checkResult)
{
for (auto result : opr->getResults())
{
mlir::Type t = result->getType();
mlir::NGTensorType resType = t.cast<NGTensorType>();
if (!resType.isCompatible(opType0))
return op->emitOpError("Incompatible operand shape");
}
}
return mlir::success(); return mlir::success();
} }
// Additional verification methods template <typename T>
// Tensor type checks are already verified by the caller of these methods static mlir::LogicalResult verifyUnaryArithOp(T* op)
{
return verifyCompatibleOperandsAndResults(op);
}
template <typename T> template <typename T>
static mlir::LogicalResult verifyBinaryArithOp(T* op) static mlir::LogicalResult verifyBinaryArithOp(T* op)
{ {
// TODO: Check matching element types return verifyCompatibleOperandsAndResults(op);
return mlir::success();
} }
template <typename T> template <typename T>
...@@ -59,13 +89,54 @@ static mlir::LogicalResult verifyOp(T* op) ...@@ -59,13 +89,54 @@ static mlir::LogicalResult verifyOp(T* op)
template <> template <>
mlir::LogicalResult verifyOp(NGDotOp* op) mlir::LogicalResult verifyOp(NGDotOp* op)
{ {
mlir::LogicalResult result = verifyBinaryArithOp(op); // TODO(dcab): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGSelectOp* op)
{
mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
mlir::Type t1 = op->getOperation()->getOperand(1)->getType();
mlir::Type t2 = op->getOperation()->getOperand(2)->getType();
mlir::Type r0 = op->getOperation()->getResult(0)->getType();
NGTensorType opType0 = t0.cast<NGTensorType>();
NGTensorType opType1 = t1.cast<NGTensorType>();
NGTensorType opType2 = t2.cast<NGTensorType>();
NGTensorType resType = r0.cast<NGTensorType>();
// arg1 arg2 of same shape and elt type
if (!opType1.isCompatible(opType2))
return op->emitOpError("Incompatible operand shapes or types for select op");
// arg0 of same shape and elt type is bool
if (!opType0.isCompatibleShape(opType1) || !opType0.getElementType().isa<NGBoolType>())
return op->emitOpError("Incompatible shape for arg0 of select op");
// result is of same shape and elt type as arg1/2
if (!resType.isCompatible(opType1))
return op->emitOpError("Incompatible result shape or type for select op");
return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifyCmpOp(T* op)
{
mlir::LogicalResult result = verifyCompatibleOperandsAndResults(op, false /*checkResult*/);
if (failed(result)) if (failed(result))
{ {
return result; return result;
} }
// TODO(dcab): Improve verification: proper shapes, etc. mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
mlir::NGTensorType opType0 = t0.cast<NGTensorType>();
mlir::Type r0 = op->getOperation()->getResult(0)->getType();
NGTensorType resType = r0.cast<NGTensorType>();
// result of same shape as input and has bool type
if (!resType.isCompatibleShape(opType0) || !resType.getElementType().isa<NGBoolType>())
return op->emitOpError("Incompatible result shape or type for comparison op");
return mlir::success(); return mlir::success();
} }
......
...@@ -72,10 +72,9 @@ class NG_MemRefDef_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -72,10 +72,9 @@ class NG_MemRefDef_Op<string mnemonic, list<OpTrait> traits = []> :
class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> : class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs)> {} NG_Op<mnemonic, traits>, Results<(outs)> {}
// Arithmetic binary operations // Base class for arithmetic unary operations without side effects.
// Input and outputs have same type
class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect, SameValueType], traits)>, NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$arg)> Arguments<(ins NG_TensorType:$arg)>
{ {
// TODO: Implement // TODO: Implement
...@@ -85,9 +84,18 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -85,9 +84,18 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
} }
// Base class for arithmetic binary operations without side effects. // Base class for arithmetic binary operations without side effects.
class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>, NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)> Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
}
// Base class for arithmetic binary operations with verifier.
class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, traits>,
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }]; let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
...@@ -95,6 +103,27 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -95,6 +103,27 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
let verifier = [{ return verifyBinaryArithOp(this); }]; let verifier = [{ return verifyBinaryArithOp(this); }];
} }
// Base class for comparison operations with verifier.
class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, traits>,
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
let verifier = [{ return verifyCmpOp(this); }];
}
// Base class for ternary operations without side effects.
class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
}
// Base class for terminator operations. // Base class for terminator operations.
class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> : class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, !listconcat(traits, [Terminator])>, NG_Op<mnemonic, !listconcat(traits, [Terminator])>,
...@@ -122,28 +151,31 @@ def NGTanhOp : NG_Unary_Arith_Op<"tanh">; ...@@ -122,28 +151,31 @@ def NGTanhOp : NG_Unary_Arith_Op<"tanh">;
def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">; def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">;
// Binary Operations // Binary Operations
def NGAddOp : NG_Binary_Arith_Op<"add", [SameValueType, Commutative]>; def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>;
def NGAndOp : NG_Binary_Arith_Op<"and", [SameValueType, Commutative]>; def NGAndOp : NG_Binary_Arith_Op<"and", [Commutative]>;
def NGSubOp : NG_Binary_Arith_Op<"sub", [SameValueType]>; def NGSubOp : NG_Binary_Arith_Op<"sub">;
def NGDivOp : NG_Binary_Arith_Op<"div", [SameValueType]>; def NGDivOp : NG_Binary_Arith_Op<"div">;
def NGMaxOp : NG_Binary_Arith_Op<"max", [SameValueType, Commutative]>; def NGMaxOp : NG_Binary_Arith_Op<"max", [Commutative]>;
def NGMinOp : NG_Binary_Arith_Op<"min", [SameValueType, Commutative]>; def NGMinOp : NG_Binary_Arith_Op<"min", [Commutative]>;
def NGMulOp : NG_Binary_Arith_Op<"mul", [SameValueType, Commutative]>; def NGMulOp : NG_Binary_Arith_Op<"mul", [Commutative]>;
def NGPowOp : NG_Binary_Arith_Op<"pow", [SameValueType]>; def NGPowOp : NG_Binary_Arith_Op<"pow">;
// Comparison // Comparison
def NGEqOp : NG_OneResult_Op<"equal", [NoSideEffect]>; def NGEqOp : NG_Cmp_Op<"equal">;
def NGGreaterOp : NG_OneResult_Op<"greater", [NoSideEffect]>; def NGGreaterOp : NG_Cmp_Op<"greater">;
def NGGreaterEqOp : NG_OneResult_Op<"greater.eq", [NoSideEffect]>; def NGGreaterEqOp : NG_Cmp_Op<"greater.eq">;
def NGLessOp : NG_OneResult_Op<"less", [NoSideEffect]>; def NGLessOp : NG_Cmp_Op<"less">;
def NGLessEqOp : NG_OneResult_Op<"less.eq", [NoSideEffect]>; def NGLessEqOp : NG_Cmp_Op<"less.eq">;
def NGNotEqOp : NG_OneResult_Op<"not.equal", [NoSideEffect]>; def NGNotEqOp : NG_Cmp_Op<"not.equal">;
// Other // Other
def NGSelectOp : NG_OneResult_Op<"select", [NoSideEffect]>; def NGSelectOp : NG_Ternary_Op<"select">
{
let verifier = [{ return verifyOp(this); }];
}
// Matrix Multiply // Dot Product
def NGDotOp : NG_Binary_Arith_Op<"dot"> def NGDotOp : NG_Binary_Op<"dot">
{ {
// TODO: Add reduction axis attribute when needed. // TODO: Add reduction axis attribute when needed.
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(this); }];
......
...@@ -73,3 +73,36 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap ...@@ -73,3 +73,36 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap
{ {
return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape); return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
} }
bool NGTensorType::isCompatible(NGTensorType& other) const
{
// Exact same tensor
if (this == &other)
return true;
// different tensors, check if of same element type and compatible shapes
if (getElementType() != other.getElementType())
return false;
// TODO: Handle dynamic ranks
// MLIR MemRefType doesn't seem to support it at the moment.
return isCompatibleShape(other);
}
bool NGTensorType::isCompatibleShape(NGTensorType& other) const
{
auto shape = getShape();
auto otherShape = other.getShape();
if (shape.size() != otherShape.size())
return false;
for (auto i = 0; i < shape.size(); i++)
{
NGRAPH_ASSERT(shape[i] >= -1) << "Invalid tensor shape";
NGRAPH_ASSERT(otherShape[i] >= -1) << "Invalid tensor shape";
if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
continue;
return false;
}
return true;
}
...@@ -235,8 +235,18 @@ namespace mlir ...@@ -235,8 +235,18 @@ namespace mlir
// Multiply times element size // Multiply times element size
return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8); return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
} }
/// Checks if two tensors are compatible. Compatible means:
/// Exactly same element types
/// Compatible shapes: see isCompatibleShape.
bool isCompatible(NGTensorType& other) const;
/// Check if Shapes are of same rank and matching dimensions unless one of them is dynamic.
bool isCompatibleShape(NGTensorType& other) const;
/// create a unique tensor type based on element type and shape. /// create a unique tensor type based on element type and shape.
static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape); static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
/// for llvm RTTI /// for llvm RTTI
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_TENSOR_TYPE_ID; } static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_TENSOR_TYPE_ID; }
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment