Commit 86bc31cc authored by Nagy Mostafa, committed by nmostafa

[MLIR] Move mlir related classes to MLIR namespace (#23)

* Move dialect and types to mlir namespace

* PR fixes and some cleanup

* Merge fix
parent ea441a6e
@@ -51,37 +51,35 @@ using namespace ngraph::runtime::ngmlir;
#define COMPILE_OP_DECL(op_name) \
create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)
-namespace ngraph
-{
MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors)
: m_compiled_kernel(compiled_kernel)
, m_external_tensors(external_tensors)
{
NGRAPH_ASSERT((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size())
<< "Number of arguments and outputs doesn't match number of tensors";
}
void MLIRCompiler::init_mlir()
{
-    mlir::registerDialect<NGDialect>();
+    mlir::registerDialect<mlir::NGDialect>();
// Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
}
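As an aside, llvm::cl::ParseEnvironmentOptions parses command-line-style flags out of the named environment variable, so MLIR_LLVM_OPTIONS can carry any option registered with LLVM's cl framework. A minimal standalone sketch of the mechanism (the flag here is hypothetical):

#include "llvm/Support/CommandLine.h"

// Hypothetical option, just to show that anything registered with llvm::cl
// becomes settable via the environment variable parsed below.
static llvm::cl::opt<bool> dumpIr("my-dump-ir", llvm::cl::desc("Dump IR"), llvm::cl::init(false));

int main()
{
    // With MLIR_LLVM_OPTIONS="-my-dump-ir" in the environment, dumpIr is true.
    llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
    return dumpIr ? 0 : 1;
}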
void MLIRCompiler::compile_and_run()
{
build_module(); // MLIR gen
lower_dialect();
optimize();
bind_arguments();
execute();
cleanup();
}
void MLIRCompiler::build_module()
{
// initialize an empty module
m_module = make_unique<mlir::Module>(&m_context);
@@ -131,76 +129,77 @@ namespace ngraph
{
m_module->dump();
}
}
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> shape;
for (auto d : tensor->get_shape())
{
shape.push_back(d);
}
-return NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
+return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
}
mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{
switch (type.get_type_enum())
{
case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic:
default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break;
-case ngraph::element::Type_t::bf16: return NGFloatType::getBF16(&m_context);
+case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
-case ngraph::element::Type_t::f32: return NGFloatType::getF32(&m_context);
+case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);
-case ngraph::element::Type_t::f64: return NGFloatType::getF64(&m_context);
+case ngraph::element::Type_t::f64: return mlir::NGFloatType::getF64(&m_context);
-case ngraph::element::Type_t::i8: return NGIntegerType::getInt8(&m_context);
+case ngraph::element::Type_t::i8: return mlir::NGIntegerType::getInt8(&m_context);
case ngraph::element::Type_t::u8:
-case ngraph::element::Type_t::boolean: return NGIntegerType::getUInt8(&m_context);
+case ngraph::element::Type_t::boolean: return mlir::NGIntegerType::getUInt8(&m_context);
-case ngraph::element::Type_t::i16: return NGIntegerType::getInt16(&m_context);
+case ngraph::element::Type_t::i16: return mlir::NGIntegerType::getInt16(&m_context);
-case ngraph::element::Type_t::u16: return NGIntegerType::getInt16(&m_context);
+case ngraph::element::Type_t::u16: return mlir::NGIntegerType::getInt16(&m_context);
-case ngraph::element::Type_t::i32: return NGIntegerType::getInt32(&m_context);
+case ngraph::element::Type_t::i32: return mlir::NGIntegerType::getInt32(&m_context);
-case ngraph::element::Type_t::u32: return NGIntegerType::getUInt32(&m_context);
+case ngraph::element::Type_t::u32: return mlir::NGIntegerType::getUInt32(&m_context);
-case ngraph::element::Type_t::i64: return NGIntegerType::getInt64(&m_context);
+case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
-case ngraph::element::Type_t::u64: return NGIntegerType::getUInt64(&m_context);
+case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
}
NGRAPH_FAIL(); // Unreachable
return mlir::Type();
}
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{
NGRAPH_ASSERT(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end())
<< "tensor value already defined";
TensorInfo tensor_info{value};
m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
}
MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tensor)
{
auto it = m_tensor_to_value_map.find(tensor);
NGRAPH_ASSERT(it != m_tensor_to_value_map.end()) << "Undefined tensor";
return it->second;
}
void MLIRCompiler::lower_dialect()
{
mlir::PassManager pm;
-pm.addPass(createDialectLoweringPass(this));
+pm.addPass(mlir::createDialectLoweringPass(this));
pm.addPass(mlir::createCanonicalizerPass());
pm.run(m_module.get());
if (failed(m_module->verify()))
@@ -211,23 +210,23 @@ namespace ngraph
{
m_module->dump();
}
}
void MLIRCompiler::optimize()
{
mlir::PassManager pm;
// Lower affine ops
pm.addPass(mlir::createLowerAffinePass());
auto rr = pm.run(m_module.get());
(void)rr;
assert(succeeded(rr) && "affine loop lowering failed");
}
// MLIR builders
#define TI(x) std::type_index(typeid(x))
void MLIRCompiler::build_ng_dialect()
{
const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
NGRAPH_ASSERT(sub_graph.size() == 1) << "Supporting code-gen for a single node for now";
@@ -247,52 +246,51 @@
}
create_return();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{
-    return compiler.create_binary_op<NGAddOp>(ng_node);
+    return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::MatmulBias)
{
// TODO(dcab): Implement all the variants of a Matmul/MatmulBias op.
// Keeping it simple for now.
NGRAPH_ASSERT(ng_node->get_arguments().size() == 2)
<< "Bias is not supported in MatmulBias operation";
-    return compiler.create_binary_op<NGMatMulBiasOp>(ng_node);
+    return compiler.create_binary_op<mlir::NGMatMulBiasOp>(ng_node);
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
{TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
{TI(ngraph::op::MatmulBias), &MLIRCompiler::create_op<ngraph::op::MatmulBias>}};
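The dispatcher above maps the dynamic type of an nGraph node (via the TI macro, i.e. std::type_index) to the matching create_op specialization. A self-contained sketch of the same dispatch pattern, with stand-in types rather than the real nGraph classes:

#include <functional>
#include <iostream>
#include <stdexcept>
#include <typeindex>
#include <unordered_map>

struct Node { virtual ~Node() = default; };
struct Add : Node {};
struct MatmulBias : Node {};

using CreateFn = std::function<void(const Node*)>;

// Keyed by dynamic type, mirroring MLIRCompiler::op_dispatcher.
static const std::unordered_map<std::type_index, CreateFn> dispatcher{
    {std::type_index(typeid(Add)), [](const Node*) { std::cout << "create NGAddOp\n"; }},
    {std::type_index(typeid(MatmulBias)), [](const Node*) { std::cout << "create NGMatMulBiasOp\n"; }}};

void compileNode(const Node* n)
{
    // typeid(*n) yields the most-derived type, so the lookup picks the right entry.
    auto it = dispatcher.find(std::type_index(typeid(*n)));
    if (it == dispatcher.end())
        throw std::runtime_error("unsupported op");
    it->second(n);
}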
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto rhs = ng_node->get_argument(1)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto rhs_v = get_tensor_value(rhs.get()).m_value;
-    return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), lhs_v, rhs_v)
-        .getResult();
+    return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), lhs_v, rhs_v).getResult();
}
void MLIRCompiler::create_return()
{
std::vector<mlir::Value*> value_list;
for (auto output : m_compiled_kernel->get_kernel_outputs())
{
value_list.push_back(get_tensor_value(output->get_output_tensor_ptr().get()).m_value);
}
-    m_builder->create<NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
+    m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
void MLIRCompiler::bind_arguments()
{
NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
mlir::Function* func = m_module->getNamedFunction("main");
@@ -325,10 +323,10 @@ namespace ngraph
// inserting memory manager ptr in right location ?
NGRAPH_ASSERT(m_invoke_args.size() == get_mem_mgr_arg_id(func));
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
}
void MLIRCompiler::execute()
{
NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
@@ -354,13 +352,12 @@ namespace ngraph
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked.
-    auto invocationResult =
-        m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
+    auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
NGRAPH_ASSERT(!invocationResult) << "JIT invocation of 'main' failed\n";
}
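The invoke call hands the JIT'ed function a type-erased argument list: one void* per parameter, each pointing at the actual argument. A standalone sketch of that calling convention, with a plain function standing in for the JIT'ed "main":

#include <cassert>
#include <vector>

// Stand-in for JIT'ed code that unpacks one void* per parameter.
static void jittedMain(void** args)
{
    float* in = *static_cast<float**>(args[0]);
    float* out = *static_cast<float**>(args[1]);
    *out = *in * 2.0f;
}

int main()
{
    float in = 3.0f, out = 0.0f;
    float* inPtr = &in;
    float* outPtr = &out;
    // Each entry points at the argument itself (here, at the data pointer),
    // mirroring how m_invoke_args is populated above.
    std::vector<void*> invokeArgs{&inPtr, &outPtr};
    jittedMain(invokeArgs.data());
    assert(out == 6.0f);
    return 0;
}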
void MLIRCompiler::cleanup()
{
// Free void double pointer arguments without freeing external tensor data.
for (auto* arg : m_invoke_args)
{
@@ -373,10 +370,10 @@ namespace ngraph
// Free allocated memory for JIT'ed code temps
m_mem_mgr.freeAll();
}
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args(mlir::Function* func)
{
SmallVector<void*, 8> args;
args.reserve(func->getNumArguments());
for (const auto& arg : func->getArguments())
@@ -388,10 +385,10 @@ namespace ngraph
args.push_back(descriptor);
}
return args;
}
mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type type)
{
auto memRefType = type.dyn_cast<mlir::MemRefType>();
if (!memRefType)
return nullptr;
@@ -403,5 +400,4 @@ namespace ngraph
reinterpret_cast<mlir::StaticFloatMemRef*>(malloc(sizeof(mlir::StaticFloatMemRef)));
descriptor->data = nullptr;
return descriptor;
}
-}
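allocate_memref_descriptor only allocates the descriptor; its data pointer is bound to an external tensor later, and cleanup() then frees the descriptor without touching that tensor. A simplified standalone sketch of the two-phase pattern (mlir::StaticFloatMemRef of this era was essentially a struct holding a single float* data member, which the stand-in below assumes):

#include <cstdlib>

// Stand-in for mlir::StaticFloatMemRef: a descriptor holding only a data pointer.
struct StaticFloatMemRef
{
    float* data;
};

static StaticFloatMemRef* allocateDescriptor()
{
    // Phase 1: allocate the descriptor only; no tensor storage is created.
    auto* descriptor = static_cast<StaticFloatMemRef*>(malloc(sizeof(StaticFloatMemRef)));
    descriptor->data = nullptr;
    return descriptor;
}

int main()
{
    float tensor[4] = {0.0f};
    StaticFloatMemRef* d = allocateDescriptor();
    d->data = tensor; // Phase 2: bind external tensor data before invocation.
    // ... pass &d among the type-erased JIT arguments ...
    free(d);          // free the descriptor, never the external tensor data
    return 0;
}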
@@ -18,11 +18,8 @@
#include "ops.hpp"
#include "type.hpp"
using namespace ngraph::runtime::ngmlir;
using namespace mlir;
/// Register a dialect and its types
/// Usage:
/// mlir::registerDialect<ngraph::runtime::ngmlir::Dialect>();
NGDialect::NGDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("ng", ctx)
{
......
@@ -23,13 +23,8 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "ngraph/assertion.hpp"
-namespace ngraph
+namespace mlir
{
-namespace runtime
-{
-namespace ngmlir
-{
class NGDialect : public mlir::Dialect
{
public:
@@ -41,6 +36,4 @@ namespace ngraph
}
void printType(mlir::Type type, llvm::raw_ostream& os) const override;
};
}
-}
-}
@@ -26,47 +26,40 @@ using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-namespace ngraph
-{
-namespace runtime
-{
-namespace ngmlir
-{
+using namespace mlir;
// TODO:
// - Move verifiers and other OP helpers (e.g. getSomeAttribute()) to separate files
//
// - Op helpers: Since it is not possible to add arbitrary code (and would complicate the .td file)
// to Ops classes, we will add helper classes with static methods for each Op that needs it
// Additional verification methods
// Tensor type checks are already verified by the caller of these methods
template <typename T>
static mlir::LogicalResult verifyUnaryArithOp(T* op)
{
// TODO: Check matching element types
return mlir::success();
}
// Additional verification methods
// Tensor type checks are already verified by the caller of these methods
template <typename T>
static mlir::LogicalResult verifyBinaryArithOp(T* op)
{
// TODO: Check matching element types
return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifyOp(T* op)
{
return op->emitOpError("Unsupported verifier for this operation");
}
// Per op specializations
template <>
mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op)
{
// Verify that we have 2 operands
// Bias operand must be null for now (not implemented)
if (op->getNumOperands() != 2)
@@ -92,17 +85,10 @@ namespace ngraph
// TODO(dcab): Improve verification: matching types, proper shapes, etc.
return mlir::success();
}
-}
-}
-}
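The pattern here is a generic verifyOp<T> that rejects every op by default, with explicit specializations opting individual ops in. A self-contained sketch of that idiom with stand-in types:

#include <cassert>
#include <string>

struct VerifyResult
{
    bool ok;
    std::string msg;
};

struct AddOpStub {};
struct MatMulBiasOpStub { int numOperands = 2; };

// Default: ops without a dedicated verifier are rejected, as in the diff.
template <typename T>
VerifyResult verifyOp(T*)
{
    return {false, "Unsupported verifier for this operation"};
}

// Per-op specialization, mirroring verifyOp<NGMatMulBiasOp> above.
template <>
VerifyResult verifyOp<MatMulBiasOpStub>(MatMulBiasOpStub* op)
{
    if (op->numOperands != 2)
        return {false, "Expected 2 operands"};
    return {true, ""};
}

int main()
{
    MatMulBiasOpStub matmul;
    AddOpStub add;
    assert(verifyOp(&matmul).ok && !verifyOp(&add).ok);
    return 0;
}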
-using namespace mlir;
-namespace runtime
-{
-namespace ngmlir
-{
+namespace mlir
+{
#define GET_OP_CLASSES
#include "ops.cpp.inc"
}
-}
-}
@@ -22,19 +22,8 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
-namespace ngraph
+namespace mlir
{
-namespace runtime
-{
-namespace ngmlir
-{
-// TODO: We shouldn't have this here, but we need to expose mlir types for the .inc file to use
-// we cannot forward declare the mlir types since they rely on the Ops we are defining (see. Op<NGAddOp, ...>)
-//
-// Other ways to avoid namespace pollution ?
-using namespace mlir;
#define GET_OP_CLASSES
#include "ops.h.inc"
}
-}
-}
@@ -40,7 +40,7 @@ include "mlir/IR/OpBase.td"
// This defines records equivalent to NGraph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification
-def NG_TensorType : Type<CPred<"{0}.isa<ngraph::runtime::ngmlir::NGTensorType>()">,
+def NG_TensorType : Type<CPred<"{0}.isa<mlir::NGTensorType>()">,
"NGraph Tensor Type">;
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
......
@@ -31,12 +31,10 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
-namespace ngraph
-{
-using namespace runtime::ngmlir;
+using namespace mlir;
unsigned NGIntegerType::getWidth() const
{
switch (getKind())
{
case NG_I8_TYPE_ID:
@@ -50,10 +48,10 @@ namespace ngraph
default: NGRAPH_FAIL() << "Invalid type ID";
}
return 0;
}
bool NGIntegerType::isSigned() const
{
switch (getKind())
{
case NG_I8_TYPE_ID:
@@ -67,19 +65,17 @@ namespace ngraph
default: NGRAPH_FAIL() << "Invalid type ID";
}
return false;
}
/// Creates TensorType objects. They all point to the same storage if
/// element type and shape are the same.
-NGTensorType NGTensorType::get(mlir::MLIRContext* context, EltType eltType, Shape shape)
+NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shape)
{
return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
}
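The comment refers to MLIR's type uniquing: Base::get interns the parameter tuple in the context, so equal parameters return the same storage and types compare equal cheaply. A hypothetical check of that contract, assuming the NG types from this diff are registered with the context:

// Hypothetical usage sketch.
mlir::MLIRContext context;
auto f32 = mlir::NGFloatType::getF32(&context);
auto t1 = mlir::NGTensorType::get(&context, f32, {2, 3});
auto t2 = mlir::NGTensorType::get(&context, f32, {2, 3});
assert(t1 == t2 && "same element type and shape -> same uniqued instance");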
-mlir::MemRefType NGTensorType::toMemref()
-{
-    auto memRefType =
-        mlir::MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
+MemRefType NGTensorType::toMemref()
+{
+    auto memRefType = MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
return memRefType;
}
@@ -23,12 +23,8 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
-namespace ngraph
+namespace mlir
{
-namespace runtime
-{
-namespace ngmlir
-{
using llvm::raw_ostream;
enum NGTypeKind
@@ -120,10 +116,7 @@ namespace ngraph
/// Convert to equivalent std type
/// std types are sign-agnostic.
-mlir::Type toStdType() const
-{
-    return mlir::IntegerType::get(getWidth(), getContext());
-}
+mlir::Type toStdType() const { return mlir::IntegerType::get(getWidth(), getContext()); }
/// Check if signed type
bool isSigned() const;
@@ -224,8 +217,7 @@ namespace ngraph
};
/// NGraph Tensor Type
-class NGTensorType
-    : public mlir::Type::TypeBase<NGTensorType, mlir::Type, NGTensorTypeStorage>
+class NGTensorType : public mlir::Type::TypeBase<NGTensorType, mlir::Type, NGTensorTypeStorage>
{
public:
using Base::Base;
@@ -255,6 +247,4 @@ namespace ngraph
/// for llvm RTTI
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_TENSOR_TYPE_ID; }
};
}
-}
-}
@@ -131,7 +131,7 @@ namespace
// we find out output values by looking at returned values
// any return should return all outputs of the subgraph
-f->walk<ngmlir::NGReturnOp>([this, &outputCount](ngmlir::NGReturnOp ret) {
+f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++)
{
this->m_outputValueMap.insert(std::pair<Value*, unsigned>(ret.getOperand(i), i));
@@ -151,7 +151,7 @@ namespace
// however, due to how DialectConversion framework works, new func is only
// materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction()
// will give you the original func). This makes it very convoluted to insert instructions at entry block.
-auto op = rewriter->create<ngmlir::NGFakeInputOp>(rewriter->getUnknownLoc(),
+auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
IndexType::get(getModule().getContext()));
// will be fixed later to read passed arg instead.
m_memMgrDefs.push_back(op.getResult());
@@ -170,7 +170,7 @@ namespace
if (it != outputMap.end())
{
unsigned argId = (*it).second;
-auto fakeOp = rewriter.create<ngmlir::NGFakeInputOp>(
+auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(),
m_dialectLowerer.convertType(
origResult->getType()) /* convert to lowered type */
@@ -183,7 +183,7 @@ namespace
}
else
{
-auto tensorType = origResult->getType().cast<ngmlir::NGTensorType>();
+auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()},
{tensorType.toMemref()},
@@ -237,8 +237,7 @@ namespace
for (auto value : m_loweredOutputValues)
{
auto op = value->getDefiningOp();
-NGRAPH_ASSERT(op->isa<ngmlir::NGFakeInputOp>())
-    << "output value not defined by fake output?";
+NGRAPH_ASSERT(op->isa<NGFakeInputOp>()) << "output value not defined by fake output?";
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase();
i++;
@@ -269,23 +268,23 @@ namespace
// NGDialect converters
Type DialectLowerer::convertType(Type t)
{
-if (auto tensor = t.dyn_cast<ngmlir::NGTensorType>())
+if (auto tensor = t.dyn_cast<NGTensorType>())
{
return tensor.toMemref();
}
// element type
-if (auto type = t.dyn_cast<ngmlir::NGFloatType>())
+if (auto type = t.dyn_cast<NGFloatType>())
{
// Float
// float types are already std type
return type;
}
-if (auto type = t.dyn_cast<ngmlir::NGIntegerType>())
+if (auto type = t.dyn_cast<NGIntegerType>())
{
// map it to std type
return type.toStdType();
}
-if (auto type = t.dyn_cast<ngmlir::NGBoolType>())
+if (auto type = t.dyn_cast<NGBoolType>())
{
return type.toStdType();
}
@@ -298,7 +297,7 @@ namespace
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{
-auto add = op->cast<ngmlir::NGAddOp>();
+auto add = op->cast<NGAddOp>();
auto loc = add.getLoc();
Value *origResult, *newResult;
@@ -335,7 +334,7 @@ namespace
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{
-auto matmul = op->cast<ngmlir::NGMatMulBiasOp>();
+auto matmul = op->cast<NGMatMulBiasOp>();
auto loc = matmul.getLoc();
NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";
@@ -406,16 +405,10 @@ namespace
}
}
-namespace ngraph
+namespace mlir
{
-namespace runtime
-{
-namespace ngmlir
-{
-Pass* createDialectLoweringPass(MLIRCompiler* compiler)
+Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler)
{
return new DialectLoweringPass(*compiler);
}
}
-}
-}
@@ -16,9 +16,9 @@
#pragma once
#include "contrib/mlir/compiler.hpp"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
namespace ngraph
{
namespace runtime
@@ -26,8 +26,10 @@ namespace ngraph
namespace ngmlir
{
class MLIRCompiler;
-mlir::Pass* createDialectLoweringPass(MLIRCompiler* compiler);
}
}
}
+namespace mlir
+{
+    mlir::Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler);
+}
@@ -22,7 +22,7 @@ class OP##Conversion : public mlir::DialectOpConversion \
{\
public:\
explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
-: mlir::DialectOpConversion(ngraph::runtime::ngmlir::OP::getOperationName(), 1, context),\
+: mlir::DialectOpConversion(mlir::OP::getOperationName(), 1, context),\
m_pass(pass)\
{} \
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override; \
......