Commit e5436889 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fused Ops dialect declaration (#3860)

* WIP

* WIP

* WIP

* All ops

* Fix layernorm backprop op name

* WIP: Adding tests

* WIP: Adding LIT parsing/printing tests

* WIP

* Added LSTM cells. Fixed some ops

* All builder tests

* PR fixes

* Fix spacing. Add missing setter to SpaceToDepth

* Update spaceToDepth lit test

* PR fixes

* Build fix

* Another fix

* Fixed optional args
parent 3ee833b7
...@@ -335,3 +335,100 @@ namespace mlir ...@@ -335,3 +335,100 @@ namespace mlir
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "ops.cpp.inc" #include "ops.cpp.inc"
} }
// Fused Ops decompose
// Stubs for now
// TODO: Implement and move to another file
void mlir::NGSpaceToDepthOp::decompose()
{
}
void mlir::NGSplitOp::decompose()
{
}
void mlir::NGScaleShiftOp::decompose()
{
}
void mlir::NGUnSqueezeOp::decompose()
{
}
void mlir::NGSquaredDiffOp::decompose()
{
}
void mlir::NGSqueezeOp::decompose()
{
}
void mlir::NGShuffleChannelsOp::decompose()
{
}
void mlir::NGRNNCellOp::decompose()
{
}
void mlir::NGFakeQuantOp::decompose()
{
}
void mlir::NGMVN::decompose()
{
}
void mlir::NGHardSigmoid::decompose()
{
}
void mlir::NGGRNOp::decompose()
{
}
void mlir::NGNormalizeL2Op::decompose()
{
}
void mlir::NGConvBiasBackpropFiltersBias::decompose()
{
}
void mlir::NGPrelu::decompose()
{
}
void mlir::NGLayerNormBackpropOp::decompose()
{
}
void mlir::NGGemmOp::decompose()
{
}
void mlir::NGClampOp::decompose()
{
}
void mlir::NGGroupConvTransposeOp::decompose()
{
}
void mlir::NGConvBiasOp::decompose()
{
}
void mlir::NGConvBiasAddOp::decompose()
{
}
void mlir::NGGRUCellOp::decompose()
{
}
void mlir::NGGroupConvOp::decompose()
{
}
void mlir::NGGeluOp::decompose()
{
}
void mlir::NGGeluBackpropFactorOp::decompose()
{
}
void mlir::NGLSTMCellOp::decompose()
{
}
void mlir::NGLSTMSequenceOp::decompose()
{
}
void mlir::NGMatMul::decompose()
{
}
void mlir::NGLayerNormOp::decompose()
{
}
void mlir::NGDepthToSpaceOp::decompose()
{
}
void mlir::NGEluOp::decompose()
{
}
...@@ -140,7 +140,9 @@ class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -140,7 +140,9 @@ class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, !listconcat(traits, [Terminator])>, NG_Op<mnemonic, !listconcat(traits, [Terminator])>,
Arguments<(ins Variadic<NG_TensorType>:$args)>, Results<(outs)> {} Arguments<(ins Variadic<NG_TensorType>:$args)>, Results<(outs)> {}
class NG_Variadic_Result_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, !listconcat(traits, [])>,
Results<(outs Variadic<NG_TensorType>:$args)> {}
// Terminator Ops // Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">; def NGReturnOp : NG_Terminator_Op<"return">;
......
...@@ -56,7 +56,8 @@ def PadModeEdge : I32EnumAttrCase<"EDGE", 1> ; ...@@ -56,7 +56,8 @@ def PadModeEdge : I32EnumAttrCase<"EDGE", 1> ;
def PadModeReflect : I32EnumAttrCase<"REFLECT", 2> ; def PadModeReflect : I32EnumAttrCase<"REFLECT", 2> ;
def PadModeSymmetric: I32EnumAttrCase<"SYMMETRIC", 3> ; def PadModeSymmetric: I32EnumAttrCase<"SYMMETRIC", 3> ;
def PadModeEnumAttr : I32EnumAttr<"MLIRPadMode", "Padding modes for pad operator", def PadModeEnumAttr : I32EnumAttr<"MLIRPadMode",
"Padding modes for pad operator",
[PadModeConstant, PadModeEdge, PadModeReflect, PadModeSymmetric]>; [PadModeConstant, PadModeEdge, PadModeReflect, PadModeSymmetric]>;
// Sort Types for TopK // Sort Types for TopK
...@@ -67,4 +68,51 @@ def SortTypeValues : I32EnumAttrCase<"VALUES", 2>; ...@@ -67,4 +68,51 @@ def SortTypeValues : I32EnumAttrCase<"VALUES", 2>;
def SortTypeEnumAttr : I32EnumAttr<"MLIRSortType", "Sort types for topk operator", def SortTypeEnumAttr : I32EnumAttr<"MLIRSortType", "Sort types for topk operator",
[SortTypeNone, SortTypeIndices, SortTypeValues]>; [SortTypeNone, SortTypeIndices, SortTypeValues]>;
// Modes for normalizeL2
def EpsModeAdd : I32EnumAttrCase<"ADD", 0>;
def EpsModeMax : I32EnumAttrCase<"MAX", 1>;
def EpsModeEnumAttr : I32EnumAttr<"MLIREpsMode",
"Specifies how eps is combined with L2 value",
[EpsModeAdd, EpsModeMax]>;
def AutoBroadcastNone : I32EnumAttrCase<"NONE", 0>;
def AutoBroadcastExplicit : I32EnumAttrCase<"EXPLICIT", 1>;
def AutoBroadcastNumPy : I32EnumAttrCase<"NUMPY", 2>;
def AutoBroadcastPDPD : I32EnumAttrCase<"PDPD", 3>;
def AutoBroadcastEnumAttr : I32EnumAttr<"MLIRAutoBroadcastMode",
"Specifies auto-broadcast for an op",
[AutoBroadcastNone, AutoBroadcastExplicit,
AutoBroadcastNumPy, AutoBroadcastPDPD]>;
def DepthSpaceModeBlocks : I32EnumAttrCase<"BLOCKS_FIRST", 0>;
def DepthSpaceModeDepth : I32EnumAttrCase<"DEPTH_FIRST", 1>;
def DepthSpaceModeEnumAttr: I32EnumAttr<"MLIRDepthToSpaceMode",
"Specifies how the input depth dimension is split to block coordinates",
[DepthSpaceModeBlocks, DepthSpaceModeDepth]>;
def LSTMWeightsFormatFICO : I32EnumAttrCase<"FICO", 0>; // IE
def LSTMWeightsFormatICOF : I32EnumAttrCase<"ICOF", 1>; // PyTorch
def LSTMWeightsFormatIFCO : I32EnumAttrCase<"IFCO", 2>; // DNNL, TF, MxNet
def LSTMWeightsFormatIFOC : I32EnumAttrCase<"IFOC", 3>; // Caffe
def LSTMWeightsFormatIOFC : I32EnumAttrCase<"IOFC", 4>; // ONNX
def LSTMWeightsFormatEnumAttr: I32EnumAttr<"MLIRLSTMWeightsFormat",
"LSTM Cell Weights Format",
[LSTMWeightsFormatFICO, LSTMWeightsFormatICOF,
LSTMWeightsFormatIFCO, LSTMWeightsFormatIFOC,
LSTMWeightsFormatIOFC]>;
def LSTMSeqDirectionFWD : I32EnumAttrCase<"FORWARD", 0>;
def LSTMSeqDirectionRVS : I32EnumAttrCase<"REVERSE", 1>;
def LSTMSeqDirectionBID : I32EnumAttrCase<"BIDIRECTIONAL", 2>;
def LSTMSeqDirectionsEnumAttr: I32EnumAttr<"MLIRLSTMSeqDirection",
"LSTM Sequence Direction",
[LSTMSeqDirectionFWD, LSTMSeqDirectionRVS,
LSTMSeqDirectionBID]>;
#endif // NG_OP_ATTRIBUTES #endif // NG_OP_ATTRIBUTES
...@@ -500,7 +500,7 @@ def NGMaxPoolBackPropOp : ...@@ -500,7 +500,7 @@ def NGMaxPoolBackPropOp :
} }
// OneHot // OneHot
def NGOneHOtOp : def NGOneHotOp :
NG_OneResult_Op<"oneHot", [NoSideEffect, OpVersion0]>, NG_OneResult_Op<"oneHot", [NoSideEffect, OpVersion0]>,
Arguments<(ins NG_TensorType :$arg, Arguments<(ins NG_TensorType :$arg,
I64ArrayAttr :$shape, I64ArrayAttr :$shape,
...@@ -552,7 +552,7 @@ def NGPadOp : ...@@ -552,7 +552,7 @@ def NGPadOp :
} }
// ReplaceSlice // ReplaceSlice
def NGReplaceSlice : def NGReplaceSliceOp :
NG_OneResult_Op<"replaceSlice", [NoSideEffect, OpVersion0]>, NG_OneResult_Op<"replaceSlice", [NoSideEffect, OpVersion0]>,
Arguments<(ins NG_TensorType:$arg0, Arguments<(ins NG_TensorType:$arg0,
NG_TensorType :$arg1, NG_TensorType :$arg1,
...@@ -583,7 +583,7 @@ def NGReplaceSlice : ...@@ -583,7 +583,7 @@ def NGReplaceSlice :
} }
// slice // slice
def NGSlice : def NGSliceOp :
NG_OneResult_Op<"slice", [NoSideEffect, OpVersion0]>, NG_OneResult_Op<"slice", [NoSideEffect, OpVersion0]>,
Arguments<(ins NG_TensorType:$arg, Arguments<(ins NG_TensorType:$arg,
I64ArrayAttr :$lowerBounds, I64ArrayAttr :$lowerBounds,
...@@ -611,7 +611,7 @@ def NGSlice : ...@@ -611,7 +611,7 @@ def NGSlice :
} }
// reshape // reshape
def NGReshape : def NGReshapeOp :
NG_OneResult_Op<"reshape", [NoSideEffect, OpVersion0]>, NG_OneResult_Op<"reshape", [NoSideEffect, OpVersion0]>,
Arguments<(ins NG_TensorType:$arg, Arguments<(ins NG_TensorType:$arg,
I64ArrayAttr :$axisOrder, I64ArrayAttr :$axisOrder,
...@@ -636,7 +636,7 @@ def NGReshape : ...@@ -636,7 +636,7 @@ def NGReshape :
} }
// softmax // softmax
def NGSoftMax : def NGSoftMaxOp :
NG_OneResult_Op<"softmax", [NoSideEffect, OpVersion0]>, NG_OneResult_Op<"softmax", [NoSideEffect, OpVersion0]>,
Arguments<(ins NG_TensorType :$arg, Arguments<(ins NG_TensorType :$arg,
I64ArrayAttr :$axes)> I64ArrayAttr :$axes)>
...@@ -655,7 +655,7 @@ def NGSoftMax : ...@@ -655,7 +655,7 @@ def NGSoftMax :
} }
// topk // topk
def NGTopK : def NGTopKOp :
NG_OneResult_Op<"topk", [NoSideEffect, OpVersion0]>, NG_OneResult_Op<"topk", [NoSideEffect, OpVersion0]>,
Arguments<(ins NG_TensorType :$arg, Arguments<(ins NG_TensorType :$arg,
NG_TensorType :$k, NG_TensorType :$k,
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment