Commit e4e3456d authored by nmostafa's avatar nmostafa

Replace NGRAPH_ASSRT/FAIL. Minor fix in Ops.td

parent 4d24d157
...@@ -60,9 +60,9 @@ MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel, ...@@ -60,9 +60,9 @@ MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
: m_compiled_kernel(compiled_kernel) : m_compiled_kernel(compiled_kernel)
, m_external_tensors(external_tensors) , m_external_tensors(external_tensors)
{ {
NGRAPH_ASSERT((m_compiled_kernel->get_arguments().size() + NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size()) m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
<< "Number of arguments and outputs doesn't match number of tensors"; "Number of arguments and outputs doesn't match number of tensors");
} }
void MLIRCompiler::init_mlir() void MLIRCompiler::init_mlir()
...@@ -103,8 +103,8 @@ void MLIRCompiler::build_ng_dialect_module() ...@@ -103,8 +103,8 @@ void MLIRCompiler::build_ng_dialect_module()
// Retrieve input and output tensors. // Retrieve input and output tensors.
const auto& kernel_inputs = m_compiled_kernel->get_arguments(); const auto& kernel_inputs = m_compiled_kernel->get_arguments();
const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs(); const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
NGRAPH_ASSERT(kernel_inputs.size() != 0) << "Cannot have empty inputs list"; NGRAPH_CHECK(kernel_inputs.size() != 0, "Cannot have empty inputs list");
NGRAPH_ASSERT(kernel_outputs.size() != 0) << "Cannot have empty outputs list"; NGRAPH_CHECK(kernel_outputs.size() != 0, "Cannot have empty outputs list");
for (auto input : kernel_inputs) for (auto input : kernel_inputs)
{ {
...@@ -138,7 +138,7 @@ void MLIRCompiler::build_ng_dialect_module() ...@@ -138,7 +138,7 @@ void MLIRCompiler::build_ng_dialect_module()
m_module->getFunctions().push_back(function.release()); m_module->getFunctions().push_back(function.release());
if (failed(m_module->verify())) if (failed(m_module->verify()))
{ {
NGRAPH_FAIL() << "Invalid module after lowering to NG dialect"; NGRAPH_CHECK(false, "Invalid module after lowering to NG dialect");
} }
dump_mlir_module("nGraph Dialect Dump:"); dump_mlir_module("nGraph Dialect Dump:");
...@@ -170,7 +170,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type) ...@@ -170,7 +170,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{ {
case ngraph::element::Type_t::undefined: case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic: case ngraph::element::Type_t::dynamic:
default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break; default: NGRAPH_CHECK(false, "MLIR: Unsupported NGraph types"); break;
case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context); case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(&m_context); case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(&m_context);
case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context); case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);
...@@ -185,7 +185,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type) ...@@ -185,7 +185,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context); case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context); case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
} }
NGRAPH_FAIL() << "Unreachable"; NGRAPH_CHECK(false, "Unreachable");
return mlir::Type(); return mlir::Type();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8)) #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
...@@ -195,8 +195,8 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type) ...@@ -195,8 +195,8 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value) void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{ {
NGRAPH_ASSERT(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end()) NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
<< "tensor value already defined"; "tensor value already defined");
TensorInfo tensor_info{value}; TensorInfo tensor_info{value};
m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info)); m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
} }
...@@ -205,7 +205,7 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens ...@@ -205,7 +205,7 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
{ {
auto it = m_tensor_to_value_map.find(tensor); auto it = m_tensor_to_value_map.find(tensor);
NGRAPH_ASSERT(it != m_tensor_to_value_map.end()) << "Undefined tensor"; NGRAPH_CHECK(it != m_tensor_to_value_map.end(), "Undefined tensor");
return it->second; return it->second;
} }
...@@ -221,7 +221,7 @@ void MLIRCompiler::lower_ng_dialect() ...@@ -221,7 +221,7 @@ void MLIRCompiler::lower_ng_dialect()
if (failed(m_module->verify())) if (failed(m_module->verify()))
{ {
NGRAPH_FAIL() << "Incorrect module after dialect lowering"; NGRAPH_CHECK(false, "Incorrect module after dialect lowering");
} }
dump_mlir_module("Affine Dialect Dump:"); dump_mlir_module("Affine Dialect Dump:");
...@@ -236,7 +236,7 @@ void MLIRCompiler::optimize() ...@@ -236,7 +236,7 @@ void MLIRCompiler::optimize()
// Lower affine ops // Lower affine ops
pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createLowerAffinePass());
auto rr = pm.run(m_module.get()); auto rr = pm.run(m_module.get());
NGRAPH_ASSERT(succeeded(rr)) << "Affine loop lowering failed"; NGRAPH_CHECK(succeeded(rr), "Affine loop lowering failed");
dump_mlir_module("Standard Dialect Dump:"); dump_mlir_module("Standard Dialect Dump:");
} }
...@@ -309,10 +309,10 @@ void MLIRCompiler::create_return() ...@@ -309,10 +309,10 @@ void MLIRCompiler::create_return()
// helpers to be used inside the function. // helpers to be used inside the function.
void MLIRCompiler::bind_arguments() void MLIRCompiler::bind_arguments()
{ {
NGRAPH_ASSERT(m_module && "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
mlir::Function* func = m_module->getNamedFunction("main"); mlir::Function* func = m_module->getNamedFunction("main");
NGRAPH_ASSERT(func && !func->getBlocks().empty()) << "Function not found"; NGRAPH_CHECK(func && !func->getBlocks().empty(), "Function not found");
// Create list with a type-erased double pointer for each invocation arguments. // Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemRefArguments', which creates a // We currently use 'allocateMemRefArguments', which creates a
...@@ -321,11 +321,11 @@ void MLIRCompiler::bind_arguments() ...@@ -321,11 +321,11 @@ void MLIRCompiler::bind_arguments()
// create MemRef args // create MemRef args
auto expected_arguments = allocate_memref_args(func); auto expected_arguments = allocate_memref_args(func);
NGRAPH_ASSERT(expected_arguments.size()) << "Arguments can't be created"; NGRAPH_CHECK(expected_arguments.size(), "Arguments can't be created");
m_invoke_args = std::move(expected_arguments); m_invoke_args = std::move(expected_arguments);
NGRAPH_ASSERT(m_invoke_args.size() == m_external_tensors.size()) NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors.size(),
<< "Number of external tensors doesn't match number of function arguments"; "Number of external tensors doesn't match number of function arguments");
// Assign external tensor pointers to invocation arguments. // Assign external tensor pointers to invocation arguments.
for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i) for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)
...@@ -339,20 +339,20 @@ void MLIRCompiler::bind_arguments() ...@@ -339,20 +339,20 @@ void MLIRCompiler::bind_arguments()
MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*))); MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
*mem_mgr_arg = &get_mem_mgr(); *mem_mgr_arg = &get_mem_mgr();
// inserting memory manager ptr in right location ? // inserting memory manager ptr in right location ?
NGRAPH_ASSERT(m_invoke_args.size() == get_mem_mgr_arg_id(func)); NGRAPH_CHECK(m_invoke_args.size() == get_mem_mgr_arg_id(func));
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg)); m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
} }
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code. // Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCompiler::execute() void MLIRCompiler::execute()
{ {
NGRAPH_ASSERT(m_module && "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect. // Lower Standard dialect to LLVM dialect.
auto converter = mlir::createStdToLLVMConverter(); auto converter = mlir::createStdToLLVMConverter();
auto r = converter->convert(m_module.get()); auto r = converter->convert(m_module.get());
(void)r; (void)r;
NGRAPH_ASSERT(succeeded(r)) << "second conversion failed"; NGRAPH_CHECK(succeeded(r), "second conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:"); dump_mlir_module("LLVM-IR Dialect Dump:");
...@@ -365,7 +365,7 @@ void MLIRCompiler::execute() ...@@ -365,7 +365,7 @@ void MLIRCompiler::execute()
// LLVM optimizations at level 3. // LLVM optimizations at level 3.
auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/); auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer); auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
NGRAPH_ASSERT(maybeEngine) << "failed to construct an execution engine"; NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get()); m_engine = std::move(maybeEngine.get());
// Invoke the JIT-compiled function with the arguments. Note that, for API // Invoke the JIT-compiled function with the arguments. Note that, for API
...@@ -373,7 +373,7 @@ void MLIRCompiler::execute() ...@@ -373,7 +373,7 @@ void MLIRCompiler::execute()
// Please, note that 'invoke' method is overloaded with a parameter pack version. // Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked. // Make sure the MutableArrayRef version is invoked.
auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args)); auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
NGRAPH_ASSERT(!invocationResult) << "JIT invocation of 'main' failed\n"; NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
} }
void MLIRCompiler::cleanup() void MLIRCompiler::cleanup()
...@@ -418,7 +418,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type typ ...@@ -418,7 +418,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type typ
{ {
return nullptr; return nullptr;
} }
NGRAPH_ASSERT(memRefType.getNumDynamicDims() == 0) << "No support for dynamic shapes"; NGRAPH_CHECK(memRefType.getNumDynamicDims() == 0, "No support for dynamic shapes");
// We only use StaticFloatMemRef because that's what MLIR currently offers. // We only use StaticFloatMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs // We should expand this with different types and dynamic MemRefs
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "dialect.hpp" #include "dialect.hpp"
#include "ngraph/check.hpp"
#include "ops.hpp" #include "ops.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -66,7 +67,7 @@ void NGDialect::printType(mlir::Type type, raw_ostream& os) const ...@@ -66,7 +67,7 @@ void NGDialect::printType(mlir::Type type, raw_ostream& os) const
os << "bool"; os << "bool";
return; return;
} }
default: { NGRAPH_ASSERT(0) << "Incorrect type to print?"; default: { NGRAPH_CHECK(false, "Incorrect type to print?");
} }
} }
} }
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h" #include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "ngraph/assertion.hpp" #include "ngraph/check.hpp"
namespace mlir namespace mlir
{ {
class NGDialect : public mlir::Dialect class NGDialect : public mlir::Dialect
...@@ -31,7 +31,7 @@ namespace mlir ...@@ -31,7 +31,7 @@ namespace mlir
explicit NGDialect(mlir::MLIRContext* ctx); explicit NGDialect(mlir::MLIRContext* ctx);
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
{ {
NGRAPH_ASSERT(0) << "Unsupported type parsing."; NGRAPH_CHECK(false, "Unsupported type parsing.");
return mlir::Type(); return mlir::Type();
} }
void printType(mlir::Type type, llvm::raw_ostream& os) const override; void printType(mlir::Type type, llvm::raw_ostream& os) const override;
......
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// //
// This is the NGraph Dialect operation definition file. // This is the nGraph Dialect operation definition file.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
// NGraph Dialect operations definitions // nGraph Dialect operations definitions
// //
// This files declares NGraph operations that table-gen uses to create C++ code // This files declares nGraph operations that table-gen uses to create C++ code
// For more information about tablegen. See https://llvm.org/docs/TableGen/index.html // For more information about tablegen. See https://llvm.org/docs/TableGen/index.html
// //
// The output files are ops.h.inc and ops.cpp.inc and are generated at build time // The output files are ops.h.inc and ops.cpp.inc and are generated at build time
...@@ -44,17 +44,17 @@ def NG_Dialect : Dialect { ...@@ -44,17 +44,17 @@ def NG_Dialect : Dialect {
} }
// NGraph Types // nGraph Types
// This defines records equivalent to NGraph types. It doesn't generate code. // This defines records equivalent to nGraph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs. // This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification // Constraints (CPred) are used to type-check args/results of that type during op verification
def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">, def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
"NGraph Tensor Type">; "nGraph Tensor Type">;
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering // A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">; def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;
// NGraph operation base class. // nGraph operation base class.
// Prepends "ng." to operation name // Prepends "ng." to operation name
class NG_Op<string mnemonic, list<OpTrait> traits = []> : class NG_Op<string mnemonic, list<OpTrait> traits = []> :
Op<NG_Dialect, mnemonic, traits> {} Op<NG_Dialect, mnemonic, traits> {}
...@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$arg)> Arguments<(ins NG_TensorType:$arg)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyUnaryArithOp(this); }]; let verifier = [{ return verifyUnaryArithOp(this); }];
} }
...@@ -89,7 +89,7 @@ class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -89,7 +89,7 @@ class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)> Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
} }
// Base class for arithmetic binary operations with verifier. // Base class for arithmetic binary operations with verifier.
...@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)> Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyBinaryArithOp(this); }]; let verifier = [{ return verifyBinaryArithOp(this); }];
} }
...@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)> Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyCmpOp(this); }]; let verifier = [{ return verifyCmpOp(this); }];
} }
...@@ -120,7 +120,7 @@ class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -120,7 +120,7 @@ class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)> Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
} }
...@@ -189,7 +189,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -189,7 +189,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
"across the axes of a single tensor."; "across the axes of a single tensor.";
let description = "Axes are represented as an array of I64 attributes."; let description = "Axes are represented as an array of I64 attributes.";
let parser = [{ NGRAPH_FAIL() << "Parser not implemented"; return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
// TODO // TODO
let verifier = [{ return verifyAxisReductionOp(this); }]; let verifier = [{ return verifyAxisReductionOp(this); }];
......
...@@ -45,7 +45,7 @@ unsigned NGIntegerType::getWidth() const ...@@ -45,7 +45,7 @@ unsigned NGIntegerType::getWidth() const
case NG_U32_TYPE_ID: return 32; case NG_U32_TYPE_ID: return 32;
case NG_I64_TYPE_ID: case NG_I64_TYPE_ID:
case NG_U64_TYPE_ID: return 64; case NG_U64_TYPE_ID: return 64;
default: NGRAPH_FAIL() << "Invalid type ID"; default: NGRAPH_CHECK(false, "Invalid type ID");
} }
return 0; return 0;
} }
...@@ -62,7 +62,7 @@ bool NGIntegerType::isSigned() const ...@@ -62,7 +62,7 @@ bool NGIntegerType::isSigned() const
case NG_U16_TYPE_ID: case NG_U16_TYPE_ID:
case NG_U32_TYPE_ID: case NG_U32_TYPE_ID:
case NG_U64_TYPE_ID: return false; case NG_U64_TYPE_ID: return false;
default: NGRAPH_FAIL() << "Invalid type ID"; default: NGRAPH_CHECK(false, "Invalid type ID");
} }
return false; return false;
} }
...@@ -97,8 +97,8 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const ...@@ -97,8 +97,8 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
for (auto i = 0; i < shape.size(); i++) for (auto i = 0; i < shape.size(); i++)
{ {
NGRAPH_ASSERT(shape[i] >= -1) << "Invalid tensor shape"; NGRAPH_CHECK(shape[i] >= -1, "Invalid tensor shape", shape[i]);
NGRAPH_ASSERT(otherShape[i] >= -1) << "Invalid tensor shape"; NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]);
if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i]) if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
continue; continue;
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
//***************************************************************************** //*****************************************************************************
#pragma once #pragma once
#include "assertion.hpp"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
...@@ -23,6 +22,7 @@ ...@@ -23,6 +22,7 @@
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h" #include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "ngraph/check.hpp"
namespace mlir namespace mlir
{ {
using llvm::raw_ostream; using llvm::raw_ostream;
...@@ -60,7 +60,7 @@ namespace mlir ...@@ -60,7 +60,7 @@ namespace mlir
static NGIntegerType get(NGTypeKind kind, mlir::MLIRContext* context) static NGIntegerType get(NGTypeKind kind, mlir::MLIRContext* context)
{ {
NGRAPH_ASSERT(kindof(kind)) << "Not an integer kind."; NGRAPH_CHECK(kindof(kind), "Not an integer kind.");
return Base::get(context, kind); return Base::get(context, kind);
} }
/// Create signed Int8 /// Create signed Int8
...@@ -154,7 +154,7 @@ namespace mlir ...@@ -154,7 +154,7 @@ namespace mlir
using Base::Base; using Base::Base;
static NGBoolType get(NGTypeKind kind, mlir::MLIRContext* context) static NGBoolType get(NGTypeKind kind, mlir::MLIRContext* context)
{ {
NGRAPH_ASSERT(kindof(kind)) << "Not a bool type."; NGRAPH_CHECK(kindof(kind), "Not a bool type.");
return Base::get(context, kind); return Base::get(context, kind);
} }
......
...@@ -133,8 +133,8 @@ namespace ...@@ -133,8 +133,8 @@ namespace
op->setAttr("graphOutputIdx", op->setAttr("graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i)); mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i));
} }
NGRAPH_ASSERT(outputCount == 0 || outputCount == ret.getNumOperands()) NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(),
<< "Inconsistent returns in function"; "Inconsistent returns in function");
outputCount = ret.getNumOperands(); outputCount = ret.getNumOperands();
}); });
// will be populated with lowered output values later // will be populated with lowered output values later
...@@ -232,7 +232,7 @@ namespace ...@@ -232,7 +232,7 @@ namespace
for (auto value : m_loweredOutputValues) for (auto value : m_loweredOutputValues)
{ {
auto op = value->getDefiningOp(); auto op = value->getDefiningOp();
NGRAPH_ASSERT(isa<NGFakeInputOp>(op)) << "output value not defined by fake output?"; NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i)); value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase(); op->erase();
i++; i++;
...@@ -289,7 +289,7 @@ namespace ...@@ -289,7 +289,7 @@ namespace
return mlir::IntegerType::get(1 /* width */, bool_type.getContext()); return mlir::IntegerType::get(1 /* width */, bool_type.getContext());
} }
NGRAPH_FAIL() << "Unsupported type to lower"; NGRAPH_CHECK(false, "Unsupported type to lower");
return type; return type;
} }
...@@ -305,7 +305,7 @@ namespace ...@@ -305,7 +305,7 @@ namespace
auto loc = add.getLoc(); auto loc = add.getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0]; auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_ASSERT(result->getType().isa<MemRefType>()); NGRAPH_CHECK(result->getType().isa<MemRefType>());
// Note that builder's current function is still the original function body. // Note that builder's current function is still the original function body.
// use getBlock to get the new block instead. // use getBlock to get the new block instead.
...@@ -346,18 +346,18 @@ namespace ...@@ -346,18 +346,18 @@ namespace
Value* lhs = operands[0]; Value* lhs = operands[0];
Value* rhs = operands[1]; Value* rhs = operands[1];
Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in DotOp"; NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");
auto result_ty = result->getType().dyn_cast<MemRefType>(); auto result_ty = result->getType().dyn_cast<MemRefType>();
auto lhs_ty = lhs->getType().dyn_cast<MemRefType>(); auto lhs_ty = lhs->getType().dyn_cast<MemRefType>();
auto rhs_ty = rhs->getType().dyn_cast<MemRefType>(); auto rhs_ty = rhs->getType().dyn_cast<MemRefType>();
NGRAPH_ASSERT(result_ty) << "Unexpected non-memref result type"; NGRAPH_CHECK(result_ty, "Unexpected non-memref result type");
NGRAPH_ASSERT(lhs_ty) << "Unexpected non-memref LHS type"; NGRAPH_CHECK(lhs_ty, "Unexpected non-memref LHS type");
NGRAPH_ASSERT(rhs_ty) << "Unexpected non-memref RHS type"; NGRAPH_CHECK(rhs_ty, "Unexpected non-memref RHS type");
Type elem_ty = result_ty.getElementType(); Type elem_ty = result_ty.getElementType();
NGRAPH_ASSERT(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType()) NGRAPH_CHECK(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType(),
<< "Types mismatch in DotOp"; "Types mismatch in DotOp");
// Create the following loop nest for matmul operation: // Create the following loop nest for matmul operation:
// for(n, N, 1) // for(n, N, 1)
...@@ -368,8 +368,8 @@ namespace ...@@ -368,8 +368,8 @@ namespace
MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs); MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs);
NGRAPH_ASSERT(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2) NGRAPH_CHECK(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2,
<< "Dot operation is only supported for 2D tensors"; "Dot operation is only supported for 2D tensors");
// Create induction variables, lower bounds, upper bounds and steps of the loop nest. // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// It's important to note that MemRefView priovides lb/ub/step info is "reverse order", // It's important to note that MemRefView priovides lb/ub/step info is "reverse order",
......
...@@ -65,7 +65,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func) ...@@ -65,7 +65,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for (size_t i = 0, end = ck_outputs.size(); i < end; ++i) for (size_t i = 0, end = ck_outputs.size(); i < end; ++i)
{ {
auto& output_descs = ck_outputs[i]->get_outputs(); auto& output_descs = ck_outputs[i]->get_outputs();
NGRAPH_ASSERT(output_descs.size() == 1) << "Unexpected multiple output descriptors"; NGRAPH_CHECK(output_descs.size() == 1, "Unexpected multiple output descriptors");
auto& out_desc = output_descs[0]; auto& out_desc = output_descs[0];
// 'replace_output' invalidates iterator of the original container. Use a copy instead. // 'replace_output' invalidates iterator of the original container. Use a copy instead.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment