Commit d34cd05a authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Add element-wise reduction ops to nGraph dialect (#37)

This patch introduces tablegen definitions for element-wise reduction
operations (only definitions, not building/lowering code). This includes
argmin, argmax, min, max, sum, prod, all and any.

From the design point of view, a single base class seems to be enough to
cover all the common ground of these operations. A different base class will
be necessary for avg pool reduction operations.
parent 261a1161
......@@ -80,6 +80,26 @@ static mlir::LogicalResult verifyBinaryArithOp(T* op)
return verifyCompatibleOperandsAndResults(op);
}
template <typename T>
static mlir::LogicalResult verifyAxisReductionOp(T* op)
{
return mlir::failure();
}
template <typename T>
static mlir::LogicalResult verifyLogicalReductionOp(T* op)
{
// TODO: verifyAxisReductionOp(op) + input and return element type.
return mlir::failure();
}
template <typename T>
static mlir::LogicalResult verifyIndexReductionOp(T* op)
{
// TODO: verifyAxisReductionOp(op) + return element type + single axis.
return mlir::failure();
}
template <typename T>
static mlir::LogicalResult verifyOp(T* op)
{
......
......@@ -181,6 +181,69 @@ def NGDotOp : NG_Binary_Op<"dot">
let verifier = [{ return verifyOp(this); }];
}
class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$operand, I64ArrayAttr:$axes)>
{
let summary = "Base class for reduction operations that perform a reduction "
"across the axes of a single tensor.";
let description = "Axes are represented as an array of I64 attributes.";
let parser = [{ NGRAPH_FAIL() << "Parser not implemented"; return mlir::failure(); }];
// TODO
let verifier = [{ return verifyAxisReductionOp(this); }];
}
// Axis reduction operations.
def NGSumRedOp : NG_Axis_Reduction_Op<"sum.red">
{
let summary = "Axis sum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGProdRedOp : NG_Axis_Reduction_Op<"prod.red">
{
let summary = "Axis product reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGMinRedOp : NG_Axis_Reduction_Op<"min.red">
{
let summary = "Axis minimum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGMaxRedOp : NG_Axis_Reduction_Op<"max.red">
{
let summary = "Axis maximum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGArgMinRedOp : NG_Axis_Reduction_Op<"argmin.red">
{
let summary = "Axis minimum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }];
}
def NGArgMaxRedOp : NG_Axis_Reduction_Op<"argmax.red">
{
let summary = "Axis maximum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }];
}
def NGAllRedOp : NG_Axis_Reduction_Op<"all.red">
{
let summary = "Axis logical AND reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }];
}
def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
{
let summary = "Axis logical OR reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }];
}
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment