ops.td 11.4 KB
Newer Older
1
//*****************************************************************************
nmostafa's avatar
nmostafa committed
2
// Copyright 2017-2019 Intel Corporation
3 4 5 6 7 8 9 10 11 12 13 14 15 16
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//
17
// This is the nGraph Dialect operation definition file.
18 19 20
//
//===----------------------------------------------------------------------===//

21 22 23
// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.

24 25
include "mlir/IR/OpBase.td"

26
// nGraph Dialect operations definitions
27
//
28
// This file declares nGraph operations that TableGen uses to generate C++ code.
29 30 31 32 33 34 35 36 37 38 39 40 41
// For more information about TableGen, see https://llvm.org/docs/TableGen/index.html
//
// The output files are ops.h.inc and ops.cpp.inc and are generated at build time
// The file declares base classes to ease opcode definitions and hoist common parts out.
// Each class fixes a set of attributes. For example:
// class NG_Unary_Arith_Op defines a base class for all unary arithmetic ops without side-effects
//
// An opcode is a record definition of the form
//      def NGAbsOp    : NG_Unary_Arith_Op<"abs">;
//
// Each def corresponds to a C++ class.


42 43 44 45 46 47 48 49
def NG_Dialect : Dialect {
  // Dialect name; it is used as the "ng." prefix of every operation name.
  let name = "ng";

  // TODO: Have the dialect under its own mlir::ngraph namespace.
  // The namespace is left empty for now (at mlir top-level).
  let cppNamespace = "";
}


50 51
// nGraph Types
// This defines records equivalent to nGraph types. It doesn't generate code.
52 53
// This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification
54
// Type constraint for nGraph tensors: a value satisfies it when
// `$_self.isa<mlir::NGTensorType>()` holds.
def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
                     "nGraph Tensor Type">;
56

57 58 59
// A generic un-typed MemRef, constrained only by IsMemRefTypePred.
// Used for fake instructions inserted during dialect lowering.
def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;

60
// nGraph operation base class.
// Binds the operation to NG_Dialect, which prepends "ng." to the operation
// name, and forwards the mnemonic and trait list unchanged to Op.
class NG_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<NG_Dialect, mnemonic, traits> {}
64 65 66 67 68 69 70 71 72 73

// Base class for operations that produce a single result.
// The result is an NG_TensorType named $res; the OneResult trait is set
// from this Results out dag.
class NG_OneResult_Op<string mnemonic, list<OpTrait> traits = []>
    : NG_Op<mnemonic, traits>,
      Results<(outs NG_TensorType:$res)> {}

// Base class for operations that produce no results (empty outs dag).
class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []>
    : NG_Op<mnemonic, traits>,
      Results<(outs)> {}

74
// Base class for arithmetic unary operations without side effects.
// Takes one tensor operand ($arg) and produces one tensor result.
// NoSideEffect is always prepended to the caller-supplied traits.
class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
      NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
      Arguments<(ins NG_TensorType:$arg)>
{
  // TODO: Implement
  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

  // Verification is delegated to verifyUnaryArithOp.
  let verifier = [{ return verifyUnaryArithOp(this); }];
}

85
// Base class for binary operations without side effects.
// Takes two tensor operands ($lhs, $rhs) and produces one tensor result.
// NoSideEffect is always prepended to the caller-supplied traits.
// No verifier is installed here; derived ops provide their own.
class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
      NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
      Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
  // TODO: Implement
  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
}

// Base class for arithmetic binary operations with verifier.
// Takes two tensor operands ($lhs, $rhs) and produces one tensor result.
// NOTE(review): unlike NG_Binary_Op, NoSideEffect is not added to the traits
// here -- confirm whether that is intentional.
class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
      NG_OneResult_Op<mnemonic, traits>,
      Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
  // TODO: Implement
  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

  // Verification is delegated to verifyBinaryArithOp.
  let verifier = [{ return verifyBinaryArithOp(this); }];
}

105 106 107 108 109 110
// Base class for comparison operations with verifier.
// Takes two tensor operands ($lhs, $rhs) and produces one tensor result.
// NOTE(review): NoSideEffect is not added to the traits here -- confirm
// whether that is intentional.
class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
      NG_OneResult_Op<mnemonic, traits>,
      Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
  // TODO: Implement
  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

  // Verification is delegated to verifyCmpOp.
  let verifier = [{ return verifyCmpOp(this); }];
}

// Base class for ternary operations without side effects.
// Takes three tensor operands ($op0, $op1, $op2) and produces one tensor
// result. NoSideEffect is always prepended to the caller-supplied traits.
// No verifier is installed here; derived ops provide their own.
class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
      NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
      Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
{
  // TODO: Implement
  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
}


126 127 128 129 130 131 132
// Base class for terminator operations.
// Appends the Terminator trait to the caller-supplied traits, accepts a
// variadic list of tensor operands ($args) and produces no results.
class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []>
    : NG_Op<mnemonic, !listconcat(traits, [Terminator])>,
      Arguments<(ins Variadic<NG_TensorType>:$args)>,
      Results<(outs)> {}

// Unary Operations
// Each op takes one tensor operand ($arg) and produces one tensor result;
// all of them share the NG_Unary_Arith_Op base class and its verifier.
// Note that NGConvertOp uses the mnemonic "conv"; the convolution op is
// registered separately under "convolution".
def NGAbsOp      : NG_Unary_Arith_Op<"abs">;
def NGACosOp     : NG_Unary_Arith_Op<"acos">;
def NGASinOp     : NG_Unary_Arith_Op<"asin">;
def NGATanOp     : NG_Unary_Arith_Op<"atan">;
def NGCeilOp     : NG_Unary_Arith_Op<"ceil">;
def NGConvertOp  : NG_Unary_Arith_Op<"conv">;
def NGCosOp      : NG_Unary_Arith_Op<"cos">;
def NGCoshOp     : NG_Unary_Arith_Op<"cosh">;
def NGExpOp      : NG_Unary_Arith_Op<"exp">;
def NGFloorOp    : NG_Unary_Arith_Op<"floor">;
def NGLogOp      : NG_Unary_Arith_Op<"log">;
def NGNegOp      : NG_Unary_Arith_Op<"neg">;
def NGNotOp      : NG_Unary_Arith_Op<"not">;
def NGSignOp     : NG_Unary_Arith_Op<"sign">;
def NGSinOp      : NG_Unary_Arith_Op<"sin">;
def NGSinhOp     : NG_Unary_Arith_Op<"sinh">;
def NGTanOp      : NG_Unary_Arith_Op<"tan">;
def NGTanhOp     : NG_Unary_Arith_Op<"tanh">;
def NGSqrtOp     : NG_Unary_Arith_Op<"sqrt">;
def NGReluOp     : NG_Unary_Arith_Op<"relu">;
152 153

// Binary Operations
154 155 156 157 158 159 160 161
// All of these derive from NG_Binary_Arith_Op ($lhs, $rhs operands, verified
// by verifyBinaryArithOp). add, and, max, min and mul additionally carry the
// Commutative trait; sub, div and pow do not.
def NGAddOp      : NG_Binary_Arith_Op<"add", [Commutative]>;
def NGAndOp      : NG_Binary_Arith_Op<"and", [Commutative]>;
def NGSubOp      : NG_Binary_Arith_Op<"sub">;
def NGDivOp      : NG_Binary_Arith_Op<"div">;
def NGMaxOp      : NG_Binary_Arith_Op<"max", [Commutative]>;
def NGMinOp      : NG_Binary_Arith_Op<"min", [Commutative]>;
def NGMulOp      : NG_Binary_Arith_Op<"mul", [Commutative]>;
def NGPowOp      : NG_Binary_Arith_Op<"pow">;
162 163

// Comparison
164 165 166 167 168 169
// All of these derive from NG_Cmp_Op ($lhs, $rhs operands, verified by
// verifyCmpOp). No extra traits are passed.
def NGEqOp        : NG_Cmp_Op<"equal">;
def NGGreaterOp   : NG_Cmp_Op<"greater">;
def NGGreaterEqOp : NG_Cmp_Op<"greater.eq">;
def NGLessOp      : NG_Cmp_Op<"less">;
def NGLessEqOp    : NG_Cmp_Op<"less.eq">;
def NGNotEqOp     : NG_Cmp_Op<"not.equal">;
170 171

// Other
172 173 174 175
// Select: a ternary op over the $op0, $op1 and $op2 tensor operands.
def NGSelectOp : NG_Ternary_Op<"select"> {
  // NG_Ternary_Op installs no verifier, so hook up the generic verifyOp.
  let verifier = [{ return verifyOp(this); }];
}
176

177 178
// Dot Product
// Binary op over the $lhs and $rhs tensor operands (see NG_Binary_Op).
def NGDotOp : NG_Binary_Op<"dot">
{
  // TODO: Add reduction axis attribute when needed.
  // NG_Binary_Op installs no verifier, so hook up the generic verifyOp.
  let verifier = [{ return verifyOp(this); }];
}

Adam Procter's avatar
Adam Procter committed
184 185 186 187
// Concat
// Takes a variadic list of tensor operands ($args) plus an I64 attribute
// ($concatenation_axis) and produces one tensor result. Has no side effects.
//
// TODO(amprocte): Might be nice to rebase this on some sort of NG_Variadic_Op
// class, but I'm not sure how to add concatenation_axis into the args if we
// do that.
def NGConcatOp :
    NG_OneResult_Op<"concat", [NoSideEffect]>,
    Arguments<(ins Variadic<NG_TensorType>:$args, I64Attr:$concatenation_axis)>
{
  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

  // Verification is delegated to the generic verifyOp.
  let verifier = [{ return verifyOp(this); }];
}

196 197 198 199 200 201
// Base class for axis reduction operations.
// Takes one tensor operand ($operand) and an I64 array attribute ($axes), and
// produces one tensor result. NoSideEffect is always prepended to the
// caller-supplied traits.
class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
      NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
      Arguments<(ins NG_TensorType:$operand, I64ArrayAttr:$axes)>
{
  let summary = "Base class for reduction operations that perform a reduction "
                "across the axes of a  single tensor.";
  let description = [{Axes are represented as an array of I64 attributes.}];

  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

  // TODO
  // Default verifier; derived ops may override it.
  let verifier = [{ return verifyAxisReductionOp(this); }];
}

// Axis reduction operations.
// Sum, product, minimum and maximum reductions. Each one overrides the
// summary and restates the verifier already set by NG_Axis_Reduction_Op
// (verifyAxisReductionOp).
def NGSumRedOp : NG_Axis_Reduction_Op<"sum.red">
{
  let summary = "Axis sum reduction of a tensor.";
  let verifier = [{ return verifyAxisReductionOp(this); }];
}

def NGProdRedOp : NG_Axis_Reduction_Op<"prod.red">
{
  let summary = "Axis product reduction of a tensor.";
  let verifier = [{ return verifyAxisReductionOp(this); }];
}

def NGMinRedOp : NG_Axis_Reduction_Op<"min.red">
{
  let summary = "Axis minimum reduction of a tensor.";
  let verifier = [{ return verifyAxisReductionOp(this); }];
}

def NGMaxRedOp : NG_Axis_Reduction_Op<"max.red">
{
  let summary = "Axis maximum reduction of a tensor.";
  let verifier = [{ return verifyAxisReductionOp(this); }];
}

// Index reductions. These replace the base-class verifier with
// verifyIndexReductionOp.
def NGArgMinRedOp : NG_Axis_Reduction_Op<"argmin.red">
{
  let summary = "Axis minimum index reduction of a tensor.";
  let verifier = [{ return verifyIndexReductionOp(this); }];
}

def NGArgMaxRedOp : NG_Axis_Reduction_Op<"argmax.red">
{
  let summary = "Axis maximum index reduction of a tensor.";
  let verifier = [{ return verifyIndexReductionOp(this); }];
}

// Logical reductions. These replace the base-class verifier with
// verifyLogicalReductionOp.
def NGAllRedOp : NG_Axis_Reduction_Op<"all.red">
{
  let summary = "Axis logical AND reduction of a boolean tensor.";
  let verifier = [{ return verifyLogicalReductionOp(this); }];
}

def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
{
  let summary = "Axis logical OR reduction of a boolean tensor.";
  let verifier = [{ return verifyLogicalReductionOp(this); }];
}

nmostafa's avatar
nmostafa committed
259 260 261
// Gather
// Takes two tensor operands ($params, $indices) and an I64 attribute ($axis),
// and produces one tensor result. Has no side effects.
def NGGatherOp :
    NG_OneResult_Op<"gather", [NoSideEffect]>,
    Arguments<(ins NG_TensorType:$params, NG_TensorType:$indices, I64Attr:$axis)>
{
  let summary = "Gather slices from params along the specified axis according to indices";
  let description = [{
    Gather slices from axis of params according to indices
    params The tensor from which slices are gathered
    indices Index tensor. Data type must be `element::i32` or `element::i64`
    axis Axis in params to gather
  }];

  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

  // Verification is delegated to the generic verifyOp.
  let verifier = [{ return verifyOp(this); }];
}

277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305
// Convolution
// Takes two tensor operands ($images, $filters) and three I64 array
// attributes ($strides, $padBelow, $padAbove), and produces one tensor result.
// Has no side effects.
def NGConvolutionOp :
    NG_OneResult_Op<"convolution", [NoSideEffect]>,
    Arguments<(ins NG_TensorType:$images, NG_TensorType:$filters,
               I64ArrayAttr:$strides,
               I64ArrayAttr:$padBelow,
               I64ArrayAttr:$padAbove)>
{
  let summary = "Convolution of a tensor of filters over a tensor of images with padding support";
  let description = [{
    Convolution operation with padding and stride support. No dilation supported.
    images        Input image tensor. Shape is [N, C_IN, D1, ... Df]
    filters       Set of filters to apply. Shape is [C_OUT, C_IN, F1, ... Ff]
    strides       Window movement strides. Shape is [f]. Attribute.
    padBelow      The padding-below sizes. Shape is [f]. Attribute.
    padAbove      The padding-above sizes. Shape is [f]. Attribute.

    Output is of shape [N, C_OUT, R1, ... Rf]
  }];

  // Parsing is not supported yet: the parser body fails unconditionally.
  let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
  // Verification is delegated to the generic verifyOp.
  let verifier = [{ return verifyOp(this); }];
  // Convenience setters that overwrite the corresponding array attribute by name.
  let extraClassDeclaration = [{
    void setStrides(ArrayAttr& arrayAttr)  { this->setAttr("strides", arrayAttr);  }
    void setPadBelow(ArrayAttr& arrayAttr) { this->setAttr("padBelow", arrayAttr); }
    void setPadAbove(ArrayAttr& arrayAttr) { this->setAttr("padAbove", arrayAttr); }
  }];
}

306 307
// Terminator Ops
// "return" takes a variadic list of tensor operands and produces no results
// (see NG_Terminator_Op).
def NGReturnOp : NG_Terminator_Op<"return">;