Commit 8ef5b0ca authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] New Core Ops (V0) and Ops Versioning in NG dialect (#3764)

* Init commit to implement interface

*  Add two op interfaces for v0 and v1. Add a unit-test

* Add missing files

* Move test to separate file

* Add Fused Op interface

* Missing files

* style

* fused ops

* Remove V1 ops for now

* Added enum attributes. WIP

* Completed non-experimental non-fused-ops

* Add ops_attributes

* Minor fixes

* Minor fixes

* Added enum setting/reading test

* style-apply

* Added attributes tests

* Fix dialect init

* style

* fix typo

* Fix merge errors

* Include file with MLIR on
parent 4cecf6e4
......@@ -97,14 +97,29 @@ function(ngraph_tablegen ofn)
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} PARENT_SCOPE)
endfunction()
set(MLIR_TABLEGEN_EXE mlir-tblgen)

# table-gen ops.td
# (diff residue removed: LLVM_TARGET_DEFINITIONS was set twice, and
# mlir_backend's dependency on ngraph_ops_gen was declared twice — once
# alone and once in the combined add_dependencies below.)
set(LLVM_TARGET_DEFINITIONS core/ngraph_dialect/ops.td)
ngraph_tablegen(ops.h.inc -gen-op-decls)
ngraph_tablegen(ops.cpp.inc -gen-op-defs)
add_public_tablegen_target(ngraph_ops_gen)

# table-gen ops_interfaces.td
set(LLVM_TARGET_DEFINITIONS core/ngraph_dialect/ops_interfaces.td)
ngraph_tablegen(ops_interfaces.h.inc -gen-op-interface-decls)
ngraph_tablegen(ops_interfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(ngraph_ops_interfaces_gen)

# table-gen ops_attributes.td
set(LLVM_TARGET_DEFINITIONS core/ngraph_dialect/ops_attributes.td)
ngraph_tablegen(ops_attributes.h.inc -gen-enum-decls)
ngraph_tablegen(ops_attributes.cpp.inc -gen-enum-defs)
add_public_tablegen_target(ngraph_ops_attributes_gen)

# mlir_backend consumes all generated .inc files, so it must depend on
# every tablegen target, and the binary dir (where the .inc files land)
# must be on its include path.
add_dependencies(mlir_backend ngraph_ops_gen ngraph_ops_interfaces_gen ngraph_ops_attributes_gen)
target_include_directories(mlir_backend PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

install(TARGETS mlir_backend DESTINATION ${NGRAPH_INSTALL_LIB})
......@@ -97,6 +97,7 @@ void MLIRCompiler::init()
if (!initialized)
{
// TODO: Remove this as it is not part of compiler init
initializeNGraphMLIR();
// Register MLIR command line options in the pool of supported flags and process flags
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//
// This is the nGraph Dialect Fused Ops definition file
// All Operations in this file implement FusedOp interface.
//===----------------------------------------------------------------------===//
// TableGen has no #ifndef, so the #ifdef/#else/#define sequence below is
// the conventional include guard for .td files.
#ifdef NG_FUSED_OPS
#else
#define NG_FUSED_OPS
// Squeeze Op
// Implements the FusedOp interface: decompose() is expected to lower the
// op into core nGraph ops (currently a stub — see TODO below).
def NGSqueezeOp :
NG_OneResult_Op<"squeeze", [NoSideEffect, FusedOp]>,
Arguments<(ins NG_TensorType:$data, NG_TensorType:$axes)>
{
let summary = "Squeeze Op";
let description = [{
Squeeze Op
}];
// Textual parsing is unsupported; the op is only constructed
// programmatically, so the parser hook simply fails.
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let extraClassDeclaration = [{
void decompose() {
//TODO: Call a templatized helper: decompose(this) to do the actual decomposition
}
}];
}
#endif //NG_FUSED_OPS
......@@ -19,6 +19,7 @@
#include "ops.hpp"
#include "assertion.hpp"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
......@@ -31,6 +32,8 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
using namespace mlir;
#include "ops_attributes.cpp.inc"
// TODO:
// - Move verifiers and other OP helpers (e.g. getSomeAttribute()) to separate files
//
......@@ -330,6 +333,8 @@ mlir::IntegerAttr getBufferId(mlir::Operation* op)
namespace mlir
{
#include "ops_interfaces.cpp.inc"
#define GET_OP_CLASSES
#include "ops.cpp.inc"
}
......@@ -26,8 +26,16 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
// attributes
// Currently table-gen dictates that enum attributes are in global namespace
#include "ops_attributes.h.inc"
namespace mlir
{
// interfaces
#include "ops_interfaces.h.inc"
// ops
#define GET_OP_CLASSES
#include "ops.h.inc"
#undef GET_OP_CLASSES
......
......@@ -22,7 +22,7 @@
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
include "mlir/IR/OpBase.td"
include "core/ngraph_dialect/ops_interfaces.td"
// nGraph Dialect operations definitions
//
// This files declares nGraph operations that table-gen uses to create C++ code
......@@ -38,7 +38,6 @@ include "mlir/IR/OpBase.td"
//
// Each def will corresponding to a C++ class
def NG_Dialect : Dialect {
let name = "ng";
// TODO: Have the dialect under its own mlir::ngraph namespace
......@@ -46,7 +45,6 @@ def NG_Dialect : Dialect {
let cppNamespace = "";
}
// nGraph Types
// This defines records equivalent to nGraph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs.
......@@ -123,76 +121,6 @@ class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
}
// Base class for terminator operations.
class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, !listconcat(traits, [Terminator])>,
Arguments<(ins Variadic<NG_TensorType>:$args)>, Results<(outs)> {}
// Unary Operations
def NGAbsOp : NG_Unary_Arith_Op<"abs">;
def NGACosOp : NG_Unary_Arith_Op<"acos">;
def NGASinOp : NG_Unary_Arith_Op<"asin">;
def NGATanOp : NG_Unary_Arith_Op<"atan">;
def NGCeilOp : NG_Unary_Arith_Op<"ceil">;
def NGConvertOp : NG_Unary_Arith_Op<"conv">;
def NGCosOp : NG_Unary_Arith_Op<"cos">;
def NGCoshOp : NG_Unary_Arith_Op<"cosh">;
def NGExpOp : NG_Unary_Arith_Op<"exp">;
def NGFloorOp : NG_Unary_Arith_Op<"floor">;
def NGLogOp : NG_Unary_Arith_Op<"log">;
def NGNegOp : NG_Unary_Arith_Op<"neg">;
def NGNotOp : NG_Unary_Arith_Op<"not">;
def NGSignOp : NG_Unary_Arith_Op<"sign">;
def NGSinOp : NG_Unary_Arith_Op<"sin">;
def NGSinhOp : NG_Unary_Arith_Op<"sinh">;
def NGTanOp : NG_Unary_Arith_Op<"tan">;
def NGTanhOp : NG_Unary_Arith_Op<"tanh">;
def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">;
def NGReluOp : NG_Unary_Arith_Op<"relu">;
// Binary Operations
def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>;
def NGAndOp : NG_Binary_Arith_Op<"and", [Commutative]>;
def NGSubOp : NG_Binary_Arith_Op<"sub">;
def NGDivOp : NG_Binary_Arith_Op<"div">;
def NGMaxOp : NG_Binary_Arith_Op<"max", [Commutative]>;
def NGMinOp : NG_Binary_Arith_Op<"min", [Commutative]>;
def NGMulOp : NG_Binary_Arith_Op<"mul", [Commutative]>;
def NGPowOp : NG_Binary_Arith_Op<"pow">;
// Comparison
def NGEqOp : NG_Cmp_Op<"equal">;
def NGGreaterOp : NG_Cmp_Op<"greater">;
def NGGreaterEqOp : NG_Cmp_Op<"greater.eq">;
def NGLessOp : NG_Cmp_Op<"less">;
def NGLessEqOp : NG_Cmp_Op<"less.eq">;
def NGNotEqOp : NG_Cmp_Op<"not.equal">;
// Other
def NGSelectOp : NG_Ternary_Op<"select">
{
let verifier = [{ return verifyOp(this); }];
}
// Dot Product
def NGDotOp : NG_Binary_Op<"dot">
{
// TODO: Add reduction axis attribute when needed.
let verifier = [{ return verifyOp(this); }];
}
// TODO(amprocte): Might be nice to rebase this on some sort of NG_Variadic_Op
// class, but I'm not sure how to add concatenation_axis into the args if we
// do that.
def NGConcatOp :
NG_OneResult_Op<"concat", [NoSideEffect]>,
Arguments<(ins Variadic<NG_TensorType>:$args, I64Attr:$concatenation_axis)>
{
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
}
class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
NG_OneResult_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
Arguments<(ins NG_TensorType:$operand, I64ArrayAttr:$axes)>
......@@ -207,101 +135,24 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
let verifier = [{ return verifyAxisReductionOp(this); }];
}
// Axis reduction operations.
def NGSumRedOp : NG_Axis_Reduction_Op<"sum.red">
{
let summary = "Axis sum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGProdRedOp : NG_Axis_Reduction_Op<"prod.red">
{
let summary = "Axis product reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGMinRedOp : NG_Axis_Reduction_Op<"min.red">
{
let summary = "Axis minimum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGMaxRedOp : NG_Axis_Reduction_Op<"max.red">
{
let summary = "Axis maximum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
}
def NGArgMinRedOp : NG_Axis_Reduction_Op<"argmin.red">
{
let summary = "Axis minimum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }];
}
def NGArgMaxRedOp : NG_Axis_Reduction_Op<"argmax.red">
{
let summary = "Axis maximum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }];
}
def NGAllRedOp : NG_Axis_Reduction_Op<"all.red">
{
let summary = "Axis logical AND reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }];
}
// Base class for terminator operations.
class NG_Terminator_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, !listconcat(traits, [Terminator])>,
Arguments<(ins Variadic<NG_TensorType>:$args)>, Results<(outs)> {}
def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
{
let summary = "Axis logical OR reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }];
}
// Gather
def NGGatherOp :
NG_OneResult_Op<"gather", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$params, NG_TensorType:$indices, I64Attr:$axis)>
{
let summary = "Gather slices from params along the specified axis according to indices";
let description = [{
Gather slices from axis of params according to indices
params The tensor from which slices are gathered
indices Index tensor. Data type must be `element::i32` or `element::i64`
axis Axis in params to gather
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
let verifier = [{ return verifyOp(this); }];
}
// ops attributes
include "core/ngraph_dialect/ops_attributes.td"
// Convolution
def NGConvolutionOp :
NG_OneResult_Op<"convolution", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$images, NG_TensorType:$filters,
I64ArrayAttr:$strides,
I64ArrayAttr:$padBelow,
I64ArrayAttr:$padAbove)>
{
let summary = "Convolution of a tensor of filters over a tensor of images with padding support";
let description = [{
Convolution operation with padding and stride support. No dilation supported.
images Input image tensor. Shape is [N, C_IN, D1, ... Df]
filters Set of filters to apply. Shape is [C_OUT, C_IN, F1, ... Ff]
strides Window movement strides. Shape is [f]. Attribute.
padBelow The padding-below sizes. Shape is [f]. Attribute.
padAbove The padding-above sizes. Shape is [f]. Attribute.
Output is of shape [N, C_OUT, R1, ... Rf]
}];
// Version 0 Ops
include "core/ngraph_dialect/ops_v0.td"
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let extraClassDeclaration = [{
void setStrides(ArrayAttr& arrayAttr) { this->setAttr("strides", arrayAttr); }
void setPadBelow(ArrayAttr& arrayAttr) { this->setAttr("padBelow", arrayAttr); }
void setPadAbove(ArrayAttr& arrayAttr) { this->setAttr("padAbove", arrayAttr); }
}];
}
// Version 1 Ops
include "core/ngraph_dialect/ops_v1.td"
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
// Fused Ops
include "core/ngraph_dialect/fused_ops.td"
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//
// This is the nGraph Dialect operation definition file.
//
//===----------------------------------------------------------------------===//
// TableGen has no #ifndef, so the #ifdef/#else/#define sequence below is
// the conventional include guard for .td files.
#ifdef NG_OP_ATTRIBUTES
#else
#define NG_OP_ATTRIBUTES
// Pull in the MLIR attribute base classes unless the including file
// already did so (OP_BASE is defined by OpBase.td itself).
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
// Each I32EnumAttrCase/I32EnumAttr pair below is turned by mlir-tblgen
// (-gen-enum-decls / -gen-enum-defs) into a C++ enum class plus
// string-conversion helpers (emitted as ops_attributes.h.inc/.cpp.inc).
//
// Padding Type used for `Convolution` and `Pooling`
//
// Follows ONNX padding type definitions
// EXPLICIT - Pad dimensions are explicitly specified
// SAME_LOWER - Pad dimensions computed to match input shape
// Ceil(num_dims/2) at the beginning and
// Floor(num_dims/2) at the end
// SAME_UPPER - Pad dimensions computed to match input shape
// Floor(num_dims/2) at the beginning and
// Ceil(num_dims/2) at the end
// VALID - No padding
// NOTE(review): NOT_SET and AUTO are not described above — presumably they
// mirror nGraph's PadType; confirm against the core PadType definition.
def PadTypeExplicit : I32EnumAttrCase<"EXPLICIT", 0>;
def PadTypeNotSet : I32EnumAttrCase<"NOT_SET", 1>;
def PadTypeSameLower : I32EnumAttrCase<"SAME_LOWER", 2>;
def PadTypeSameUpper : I32EnumAttrCase<"SAME_UPPER", 3>;
def PadTypeAuto : I32EnumAttrCase<"AUTO", 4>;
def PadTypeValid : I32EnumAttrCase<"VALID", 5>;
def PadTypeEnumAttr : I32EnumAttr<"MLIRPadType", "Padding Type used for Convolution and pooling",
[PadTypeExplicit, PadTypeNotSet, PadTypeSameLower,
PadTypeSameUpper, PadTypeAuto, PadTypeValid]>;
// Modes for the `Pad` operator
def PadModeConstant : I32EnumAttrCase<"CONSTANT", 0> ;
def PadModeEdge : I32EnumAttrCase<"EDGE", 1> ;
def PadModeReflect : I32EnumAttrCase<"REFLECT", 2> ;
def PadModeSymmetric: I32EnumAttrCase<"SYMMETRIC", 3> ;
def PadModeEnumAttr : I32EnumAttr<"MLIRPadMode", "Padding modes for pad operator",
[PadModeConstant, PadModeEdge, PadModeReflect, PadModeSymmetric]>;
// Sort Types for TopK
def SortTypeNone : I32EnumAttrCase<"NONE", 0>;
def SortTypeIndices : I32EnumAttrCase<"INDICES", 1>;
def SortTypeValues : I32EnumAttrCase<"VALUES", 2>;
def SortTypeEnumAttr : I32EnumAttr<"MLIRSortType", "Sort types for topk operator",
[SortTypeNone, SortTypeIndices, SortTypeValues]>;
#endif // NG_OP_ATTRIBUTES
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//
// This is the nGraph Dialect operation interfaces definitions
//
//===----------------------------------------------------------------------===//
// TableGen has no #ifndef, so the #ifdef/#else/#define sequence below is
// the conventional include guard for .td files.
#ifdef NG_OP_INTERFACES
#else
#define NG_OP_INTERFACES
// Pull in the MLIR interface base classes unless the including file
// already did so (OP_BASE is defined by OpBase.td itself).
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
// Op Interfaces for Op Versions
// They are empty for now. To check the version of an op, we do:
// Operation *op = …;
// if (dyn_cast<OpVersion0Interface>(op))
// mlir-tblgen (-gen-op-interface-decls / -gen-op-interface-defs) emits a
// C++ interface class for each def below (ops_interfaces.h.inc/.cpp.inc).
def OpVersion0 : OpInterface<"OpVersion0"> {
let description=[{
Interface for Version 0 Ops
}];
// Interface is empty for now.
}
def OpVersion1 : OpInterface<"OpVersion1"> {
let description=[{
Interface for Version 1 Ops
}];
// Interface is empty for now.
}
// Ops implementing FusedOp must provide decompose(), which lowers the
// fused op into simpler core ops.
def FusedOp : OpInterface<"FusedOp"> {
let description=[{
Interface for fused ops.
Provides an API to decompose an op
}];
let methods = [
InterfaceMethod<
"Decompose the operation",
"void",
"decompose"
>
];
}
#endif // NG_OP_INTERFACES
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
//
// This is the nGraph Dialect Version 1 operations definition file.
// All Operations in this file implement OpVersion1 interface.
//===----------------------------------------------------------------------===//
// TODO: Add Version1 Ops definitions here
......@@ -24,6 +24,7 @@
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Debug.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
static llvm::cl::opt<bool> clPrintIRAfterAll(
"ngraph-print-ir-after-all",
......@@ -34,7 +35,15 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
/// Registers the nGraph dialect with MLIR.
///
/// MLIR currently offers no way to query whether a dialect has already
/// been registered, so registration is guarded to run exactly once.
/// A function-local static initialized by a lambda (C++11 "magic
/// static") makes the one-time registration thread-safe as well,
/// replacing the mutable `static bool init` flag.
/// (Also drops the diff-residue duplicate unconditional registration
/// that preceded the guarded block.)
void ngraph::runtime::ngmlir::initializeNGraphMLIR()
{
    static const bool initialized = [] {
        mlir::registerDialect<mlir::NGraphOpsDialect>();
        return true;
    }();
    (void)initialized;
}
void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module)
......
......@@ -468,6 +468,7 @@ endif()
if (NGRAPH_MLIR_ENABLE)
list(APPEND MULTI_TEST_SRC backend/mlir.in.cpp)
list(APPEND SRC mlir/ops_test.cpp)
endif()
if(NGRAPH_DISTRIBUTED_ENABLE)
......@@ -597,6 +598,10 @@ if (NGRAPH_ONNXIFI_ENABLE)
target_link_libraries(unit-test PRIVATE onnxifi-ngraph)
endif()
if (NGRAPH_MLIR_ENABLE)
target_include_directories(unit-test PRIVATE ${CMAKE_BINARY_DIR}/src/contrib/mlir)
endif()
# If all the runtime libraries are installed into one location, that will make life easier.
if (MSVS)
add_custom_target(unit-test-check
......
......@@ -17,6 +17,9 @@
#include <chrono>
#include <iostream>
#ifdef NGRAPH_MLIR_ENABLE
#include "contrib/mlir/utils.hpp"
#endif
#include "gtest/gtest.h"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
......@@ -54,6 +57,12 @@ int main(int argc, char** argv)
#ifdef NGRAPH_INTERPRETER_ENABLE
ngraph_register_interpreter_backend();
#endif
#ifdef NGRAPH_MLIR_ENABLE
// Initialize MLIR
ngraph::runtime::ngmlir::initializeNGraphMLIR();
#endif
auto start = std::chrono::system_clock::now();
int rc = RUN_ALL_TESTS();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
// ops tests for nGraph MLIR dialect
// Test certain invariants about
#include "gtest/gtest.h"
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include "contrib/mlir/core/ngraph_dialect/ops.hpp"
#include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include "contrib/mlir/utils.hpp"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
// Builds an empty "main" function with an entry block and returns an
// OpBuilder positioned inside its body, ready to create test ops.
// NOTE(review): `module` is created but the function is never inserted
// into it — ops built via the returned builder live only in `function`;
// confirm this is intentional for these unit tests.
OpBuilder createBuilder(MLIRContext* context)
{
    const auto loc = UnknownLoc::get(context);
    auto module = ModuleOp::create(loc);
    (void)module;
    auto function = FuncOp::create(loc, "main", FunctionType::get({}, {}, context));
    function.addEntryBlock();
    return OpBuilder(function.getBody());
}
// Checks interface-based op versioning: "ng.gather" must implement
// OpVersion0 and must NOT implement OpVersion1, so dyn_cast against the
// tablegen-generated interface classes distinguishes op versions.
// (Fix: removed the unused local `OpBuilder builder(&context);` — the op
// is created directly via Operation::create.)
TEST(MLIR, op_version_interface)
{
    MLIRContext context;
    llvm::SmallVector<mlir::Type, 1> resultTypes;

    // Arbitrary 2x2 f16 tensor result type; the value is irrelevant to
    // the interface query.
    resultTypes.push_back(
        mlir::NGTensorType::get(&context, mlir::NGFloatType::getF16(&context), {2, 2}));

    auto operation = Operation::create(mlir::UnknownLoc::get(&context),
                                       OperationName("ng.gather", &context),
                                       resultTypes,
                                       llvm::None,
                                       llvm::None,
                                       llvm::None,
                                       0,
                                       false);

    EXPECT_TRUE(llvm::dyn_cast<OpVersion0>(operation) != nullptr);
    EXPECT_TRUE(llvm::dyn_cast<OpVersion1>(operation) == nullptr);
}
// Checks that "ng.squeeze" implements the FusedOp interface and that
// its decompose() hook is callable through the interface (the hook is
// currently a stub, so the call only verifies dispatch wiring).
// (Fix: removed the unused local `OpBuilder builder(&context);` — the op
// is created directly via Operation::create.)
TEST(MLIR, fused_ops_interface)
{
    MLIRContext context;
    llvm::SmallVector<mlir::Type, 1> resultTypes;

    // Arbitrary 2x2 f16 tensor result type; the value is irrelevant to
    // the interface query.
    resultTypes.push_back(
        mlir::NGTensorType::get(&context, mlir::NGFloatType::getF16(&context), {2, 2}));

    auto operation = Operation::create(mlir::UnknownLoc::get(&context),
                                       OperationName("ng.squeeze", &context),
                                       resultTypes,
                                       llvm::None,
                                       llvm::None,
                                       llvm::None,
                                       0,
                                       false);

    EXPECT_TRUE(llvm::dyn_cast<FusedOp>(operation) != nullptr);
    if (auto fusedOp = llvm::dyn_cast<FusedOp>(operation))
    {
        fusedOp.decompose();
    }
}
// Round-trips enum attributes through a tablegen-generated op:
// 1) builds an NGAvgPoolOp with an explicit padType of SAME_LOWER and
//    reads it back through the generated accessor;
// 2) builds one omitting padType/ceilMode and checks the defaults read
//    back as EXPLICIT / false (presumably ODS default-valued attributes
//    — confirm against the NGAvgPoolOp definition in ops.td).
TEST(MLIR, ops_attributes)
{
MLIRContext context;
// 2x2 f16 tensor type shared by all ops below.
auto resultType =
mlir::NGTensorType::get(&context, mlir::NGFloatType::getF16(&context), {2, 2});
auto builder = createBuilder(&context);
// Constant op used only as the pooling input argument.
auto def = builder.create<NGConstantOp>(UnknownLoc::get(&context),
resultType,
builder.getI64ArrayAttr({2, 3, 4}),
builder.getF32ArrayAttr({1.0, 2.3, 5.6}));
// AvgPool with every attribute supplied, including padType=SAME_LOWER
// (stored as an I64 integer attribute holding the enum value).
auto operation =
builder
.create<NGAvgPoolOp>(
UnknownLoc::get(&context),
resultType,
def.getResult(), // arg
builder.getI64ArrayAttr({2, 3, 4}), // windowShape
builder.getI64ArrayAttr({2, 3, 4}), // windowMovementStrides
builder.getI64ArrayAttr({0, 0, 0}), // padBelow
builder.getI64ArrayAttr({0, 0, 0}), // padAbove
builder.getBoolAttr(false), // includePadding
builder.getI64IntegerAttr(static_cast<int64_t>(MLIRPadType::SAME_LOWER)), // padType
builder.getBoolAttr(false)) // ceilMode
.getOperation();
auto avgPool = cast<NGAvgPoolOp>(operation);
// Read the attribute back and convert the raw integer to the enum.
auto padType = static_cast<MLIRPadType>(avgPool.padType().getSExtValue());
EXPECT_TRUE(padType == MLIRPadType::SAME_LOWER);
// Second AvgPool omits padType and ceilMode to exercise their defaults.
operation =
builder
.create<NGAvgPoolOp>(UnknownLoc::get(&context),
resultType,
def.getResult(), // arg
builder.getI64ArrayAttr({2, 3, 4}), // windowShape
builder.getI64ArrayAttr({2, 3, 4}), // windowMovementStrides
builder.getI64ArrayAttr({0, 0, 0}), // padBelow
builder.getI64ArrayAttr({0, 0, 0}), // padAbove
builder.getBoolAttr(false)) // includePadding
.getOperation();
avgPool = cast<NGAvgPoolOp>(operation);
padType = static_cast<MLIRPadType>(avgPool.padType().getSExtValue());
EXPECT_TRUE(padType == MLIRPadType::EXPLICIT);
auto ceilMode = avgPool.ceilMode();
EXPECT_TRUE(ceilMode == false);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment