Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
e4e3456d
Commit
e4e3456d
authored
Jun 13, 2019
by
nmostafa
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Replace NGRAPH_ASSRT/FAIL. Minor fix in Ops.td
parent
4d24d157
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
62 additions
and
61 deletions
+62
-61
compiler.cpp
src/contrib/mlir/compiler.cpp
+24
-24
dialect.cpp
src/contrib/mlir/dialect/dialect.cpp
+2
-1
dialect.hpp
src/contrib/mlir/dialect/dialect.hpp
+2
-2
ops.td
src/contrib/mlir/dialect/ops.td
+13
-13
type.cpp
src/contrib/mlir/dialect/type.cpp
+4
-4
type.hpp
src/contrib/mlir/dialect/type.hpp
+3
-3
lowerer.cpp
src/contrib/mlir/lowerer.cpp
+13
-13
mlir_subgraph_extraction.cpp
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp
+1
-1
No files found.
src/contrib/mlir/compiler.cpp
View file @
e4e3456d
...
@@ -60,9 +60,9 @@ MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
...
@@ -60,9 +60,9 @@ MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
:
m_compiled_kernel
(
compiled_kernel
)
:
m_compiled_kernel
(
compiled_kernel
)
,
m_external_tensors
(
external_tensors
)
,
m_external_tensors
(
external_tensors
)
{
{
NGRAPH_
ASSERT
((
m_compiled_kernel
->
get_arguments
().
size
()
+
NGRAPH_
CHECK
((
m_compiled_kernel
->
get_arguments
().
size
()
+
m_compiled_kernel
->
get_kernel_outputs
().
size
())
==
external_tensors
.
size
())
m_compiled_kernel
->
get_kernel_outputs
().
size
())
==
external_tensors
.
size
(),
<<
"Number of arguments and outputs doesn't match number of tensors"
;
"Number of arguments and outputs doesn't match number of tensors"
)
;
}
}
void
MLIRCompiler
::
init_mlir
()
void
MLIRCompiler
::
init_mlir
()
...
@@ -103,8 +103,8 @@ void MLIRCompiler::build_ng_dialect_module()
...
@@ -103,8 +103,8 @@ void MLIRCompiler::build_ng_dialect_module()
// Retrieve input and output tensors.
// Retrieve input and output tensors.
const
auto
&
kernel_inputs
=
m_compiled_kernel
->
get_arguments
();
const
auto
&
kernel_inputs
=
m_compiled_kernel
->
get_arguments
();
const
auto
&
kernel_outputs
=
m_compiled_kernel
->
get_kernel_outputs
();
const
auto
&
kernel_outputs
=
m_compiled_kernel
->
get_kernel_outputs
();
NGRAPH_
ASSERT
(
kernel_inputs
.
size
()
!=
0
)
<<
"Cannot have empty inputs list"
;
NGRAPH_
CHECK
(
kernel_inputs
.
size
()
!=
0
,
"Cannot have empty inputs list"
)
;
NGRAPH_
ASSERT
(
kernel_outputs
.
size
()
!=
0
)
<<
"Cannot have empty outputs list"
;
NGRAPH_
CHECK
(
kernel_outputs
.
size
()
!=
0
,
"Cannot have empty outputs list"
)
;
for
(
auto
input
:
kernel_inputs
)
for
(
auto
input
:
kernel_inputs
)
{
{
...
@@ -138,7 +138,7 @@ void MLIRCompiler::build_ng_dialect_module()
...
@@ -138,7 +138,7 @@ void MLIRCompiler::build_ng_dialect_module()
m_module
->
getFunctions
().
push_back
(
function
.
release
());
m_module
->
getFunctions
().
push_back
(
function
.
release
());
if
(
failed
(
m_module
->
verify
()))
if
(
failed
(
m_module
->
verify
()))
{
{
NGRAPH_
FAIL
()
<<
"Invalid module after lowering to NG dialect"
;
NGRAPH_
CHECK
(
false
,
"Invalid module after lowering to NG dialect"
)
;
}
}
dump_mlir_module
(
"nGraph Dialect Dump:"
);
dump_mlir_module
(
"nGraph Dialect Dump:"
);
...
@@ -170,7 +170,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
...
@@ -170,7 +170,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{
{
case
ngraph
:
:
element
::
Type_t
::
undefined
:
case
ngraph
:
:
element
::
Type_t
::
undefined
:
case
ngraph
:
:
element
::
Type_t
::
dynamic
:
case
ngraph
:
:
element
::
Type_t
::
dynamic
:
default
:
NGRAPH_
FAIL
()
<<
"MLIR: Unsupported NGraph types"
;
break
;
default
:
NGRAPH_
CHECK
(
false
,
"MLIR: Unsupported NGraph types"
)
;
break
;
case
ngraph
:
:
element
::
Type_t
::
bf16
:
return
mlir
::
NGFloatType
::
getBF16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
bf16
:
return
mlir
::
NGFloatType
::
getBF16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f16
:
return
mlir
::
NGFloatType
::
getF16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f16
:
return
mlir
::
NGFloatType
::
getF16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f32
:
return
mlir
::
NGFloatType
::
getF32
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f32
:
return
mlir
::
NGFloatType
::
getF32
(
&
m_context
);
...
@@ -185,7 +185,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
...
@@ -185,7 +185,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
case
ngraph
:
:
element
::
Type_t
::
i64
:
return
mlir
::
NGIntegerType
::
getInt64
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i64
:
return
mlir
::
NGIntegerType
::
getInt64
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
u64
:
return
mlir
::
NGIntegerType
::
getUInt64
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
u64
:
return
mlir
::
NGIntegerType
::
getUInt64
(
&
m_context
);
}
}
NGRAPH_
FAIL
()
<<
"Unreachable"
;
NGRAPH_
CHECK
(
false
,
"Unreachable"
)
;
return
mlir
::
Type
();
return
mlir
::
Type
();
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
...
@@ -195,8 +195,8 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
...
@@ -195,8 +195,8 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
void
MLIRCompiler
::
update_tensor_value
(
descriptor
::
Tensor
*
tensor
,
mlir
::
Value
*
value
)
void
MLIRCompiler
::
update_tensor_value
(
descriptor
::
Tensor
*
tensor
,
mlir
::
Value
*
value
)
{
{
NGRAPH_
ASSERT
(
m_tensor_to_value_map
.
find
(
tensor
)
==
m_tensor_to_value_map
.
end
())
NGRAPH_
CHECK
(
m_tensor_to_value_map
.
find
(
tensor
)
==
m_tensor_to_value_map
.
end
(),
<<
"tensor value already defined"
;
"tensor value already defined"
)
;
TensorInfo
tensor_info
{
value
};
TensorInfo
tensor_info
{
value
};
m_tensor_to_value_map
.
insert
(
TensorToInfo
(
tensor
,
tensor_info
));
m_tensor_to_value_map
.
insert
(
TensorToInfo
(
tensor
,
tensor_info
));
}
}
...
@@ -205,7 +205,7 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
...
@@ -205,7 +205,7 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
{
{
auto
it
=
m_tensor_to_value_map
.
find
(
tensor
);
auto
it
=
m_tensor_to_value_map
.
find
(
tensor
);
NGRAPH_
ASSERT
(
it
!=
m_tensor_to_value_map
.
end
())
<<
"Undefined tensor"
;
NGRAPH_
CHECK
(
it
!=
m_tensor_to_value_map
.
end
(),
"Undefined tensor"
)
;
return
it
->
second
;
return
it
->
second
;
}
}
...
@@ -221,7 +221,7 @@ void MLIRCompiler::lower_ng_dialect()
...
@@ -221,7 +221,7 @@ void MLIRCompiler::lower_ng_dialect()
if
(
failed
(
m_module
->
verify
()))
if
(
failed
(
m_module
->
verify
()))
{
{
NGRAPH_
FAIL
()
<<
"Incorrect module after dialect lowering"
;
NGRAPH_
CHECK
(
false
,
"Incorrect module after dialect lowering"
)
;
}
}
dump_mlir_module
(
"Affine Dialect Dump:"
);
dump_mlir_module
(
"Affine Dialect Dump:"
);
...
@@ -236,7 +236,7 @@ void MLIRCompiler::optimize()
...
@@ -236,7 +236,7 @@ void MLIRCompiler::optimize()
// Lower affine ops
// Lower affine ops
pm
.
addPass
(
mlir
::
createLowerAffinePass
());
pm
.
addPass
(
mlir
::
createLowerAffinePass
());
auto
rr
=
pm
.
run
(
m_module
.
get
());
auto
rr
=
pm
.
run
(
m_module
.
get
());
NGRAPH_
ASSERT
(
succeeded
(
rr
))
<<
"Affine loop lowering failed"
;
NGRAPH_
CHECK
(
succeeded
(
rr
),
"Affine loop lowering failed"
)
;
dump_mlir_module
(
"Standard Dialect Dump:"
);
dump_mlir_module
(
"Standard Dialect Dump:"
);
}
}
...
@@ -309,10 +309,10 @@ void MLIRCompiler::create_return()
...
@@ -309,10 +309,10 @@ void MLIRCompiler::create_return()
// helpers to be used inside the function.
// helpers to be used inside the function.
void
MLIRCompiler
::
bind_arguments
()
void
MLIRCompiler
::
bind_arguments
()
{
{
NGRAPH_
ASSERT
(
m_module
&&
"MLIR module is not ready."
);
NGRAPH_
CHECK
(
m_module
,
"MLIR module is not ready."
);
mlir
::
Function
*
func
=
m_module
->
getNamedFunction
(
"main"
);
mlir
::
Function
*
func
=
m_module
->
getNamedFunction
(
"main"
);
NGRAPH_
ASSERT
(
func
&&
!
func
->
getBlocks
().
empty
())
<<
"Function not found"
;
NGRAPH_
CHECK
(
func
&&
!
func
->
getBlocks
().
empty
(),
"Function not found"
)
;
// Create list with a type-erased double pointer for each invocation arguments.
// Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemRefArguments', which creates a
// We currently use 'allocateMemRefArguments', which creates a
...
@@ -321,11 +321,11 @@ void MLIRCompiler::bind_arguments()
...
@@ -321,11 +321,11 @@ void MLIRCompiler::bind_arguments()
// create MemRef args
// create MemRef args
auto
expected_arguments
=
allocate_memref_args
(
func
);
auto
expected_arguments
=
allocate_memref_args
(
func
);
NGRAPH_
ASSERT
(
expected_arguments
.
size
())
<<
"Arguments can't be created"
;
NGRAPH_
CHECK
(
expected_arguments
.
size
(),
"Arguments can't be created"
)
;
m_invoke_args
=
std
::
move
(
expected_arguments
);
m_invoke_args
=
std
::
move
(
expected_arguments
);
NGRAPH_
ASSERT
(
m_invoke_args
.
size
()
==
m_external_tensors
.
size
())
NGRAPH_
CHECK
(
m_invoke_args
.
size
()
==
m_external_tensors
.
size
(),
<<
"Number of external tensors doesn't match number of function arguments"
;
"Number of external tensors doesn't match number of function arguments"
)
;
// Assign external tensor pointers to invocation arguments.
// Assign external tensor pointers to invocation arguments.
for
(
size_t
i
=
0
,
num_args
=
m_invoke_args
.
size
();
i
<
num_args
;
++
i
)
for
(
size_t
i
=
0
,
num_args
=
m_invoke_args
.
size
();
i
<
num_args
;
++
i
)
...
@@ -339,20 +339,20 @@ void MLIRCompiler::bind_arguments()
...
@@ -339,20 +339,20 @@ void MLIRCompiler::bind_arguments()
MLIRMemMgr
**
mem_mgr_arg
=
reinterpret_cast
<
MLIRMemMgr
**>
(
malloc
(
sizeof
(
void
*
)));
MLIRMemMgr
**
mem_mgr_arg
=
reinterpret_cast
<
MLIRMemMgr
**>
(
malloc
(
sizeof
(
void
*
)));
*
mem_mgr_arg
=
&
get_mem_mgr
();
*
mem_mgr_arg
=
&
get_mem_mgr
();
// inserting memory manager ptr in right location ?
// inserting memory manager ptr in right location ?
NGRAPH_
ASSERT
(
m_invoke_args
.
size
()
==
get_mem_mgr_arg_id
(
func
));
NGRAPH_
CHECK
(
m_invoke_args
.
size
()
==
get_mem_mgr_arg_id
(
func
));
m_invoke_args
.
push_back
(
static_cast
<
void
*>
(
mem_mgr_arg
));
m_invoke_args
.
push_back
(
static_cast
<
void
*>
(
mem_mgr_arg
));
}
}
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void
MLIRCompiler
::
execute
()
void
MLIRCompiler
::
execute
()
{
{
NGRAPH_
ASSERT
(
m_module
&&
"MLIR module is not ready."
);
NGRAPH_
CHECK
(
m_module
,
"MLIR module is not ready."
);
// Lower Standard dialect to LLVM dialect.
// Lower Standard dialect to LLVM dialect.
auto
converter
=
mlir
::
createStdToLLVMConverter
();
auto
converter
=
mlir
::
createStdToLLVMConverter
();
auto
r
=
converter
->
convert
(
m_module
.
get
());
auto
r
=
converter
->
convert
(
m_module
.
get
());
(
void
)
r
;
(
void
)
r
;
NGRAPH_
ASSERT
(
succeeded
(
r
))
<<
"second conversion failed"
;
NGRAPH_
CHECK
(
succeeded
(
r
),
"second conversion failed"
)
;
dump_mlir_module
(
"LLVM-IR Dialect Dump:"
);
dump_mlir_module
(
"LLVM-IR Dialect Dump:"
);
...
@@ -365,7 +365,7 @@ void MLIRCompiler::execute()
...
@@ -365,7 +365,7 @@ void MLIRCompiler::execute()
// LLVM optimizations at level 3.
// LLVM optimizations at level 3.
auto
llvm_transformer
=
mlir
::
makeOptimizingTransformer
(
3
/*optLevel*/
,
0
/*sizeLevel*/
);
auto
llvm_transformer
=
mlir
::
makeOptimizingTransformer
(
3
/*optLevel*/
,
0
/*sizeLevel*/
);
auto
maybeEngine
=
mlir
::
ExecutionEngine
::
create
(
m_module
.
get
(),
llvm_transformer
);
auto
maybeEngine
=
mlir
::
ExecutionEngine
::
create
(
m_module
.
get
(),
llvm_transformer
);
NGRAPH_
ASSERT
(
maybeEngine
)
<<
"failed to construct an execution engine"
;
NGRAPH_
CHECK
(
maybeEngine
,
"failed to construct an execution engine"
)
;
m_engine
=
std
::
move
(
maybeEngine
.
get
());
m_engine
=
std
::
move
(
maybeEngine
.
get
());
// Invoke the JIT-compiled function with the arguments. Note that, for API
// Invoke the JIT-compiled function with the arguments. Note that, for API
...
@@ -373,7 +373,7 @@ void MLIRCompiler::execute()
...
@@ -373,7 +373,7 @@ void MLIRCompiler::execute()
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked.
// Make sure the MutableArrayRef version is invoked.
auto
invocationResult
=
m_engine
->
invoke
(
"main"
,
llvm
::
MutableArrayRef
<
void
*>
(
m_invoke_args
));
auto
invocationResult
=
m_engine
->
invoke
(
"main"
,
llvm
::
MutableArrayRef
<
void
*>
(
m_invoke_args
));
NGRAPH_
ASSERT
(
!
invocationResult
)
<<
"JIT invocation of 'main' failed
\n
"
;
NGRAPH_
CHECK
(
!
invocationResult
,
"JIT invocation of 'main' failed
\n
"
)
;
}
}
void
MLIRCompiler
::
cleanup
()
void
MLIRCompiler
::
cleanup
()
...
@@ -418,7 +418,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type typ
...
@@ -418,7 +418,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type typ
{
{
return
nullptr
;
return
nullptr
;
}
}
NGRAPH_
ASSERT
(
memRefType
.
getNumDynamicDims
()
==
0
)
<<
"No support for dynamic shapes"
;
NGRAPH_
CHECK
(
memRefType
.
getNumDynamicDims
()
==
0
,
"No support for dynamic shapes"
)
;
// We only use StaticFloatMemRef because that's what MLIR currently offers.
// We only use StaticFloatMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs
// We should expand this with different types and dynamic MemRefs
...
...
src/contrib/mlir/dialect/dialect.cpp
View file @
e4e3456d
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
//*****************************************************************************
//*****************************************************************************
#include "dialect.hpp"
#include "dialect.hpp"
#include "ngraph/check.hpp"
#include "ops.hpp"
#include "ops.hpp"
#include "type.hpp"
#include "type.hpp"
...
@@ -66,7 +67,7 @@ void NGDialect::printType(mlir::Type type, raw_ostream& os) const
...
@@ -66,7 +67,7 @@ void NGDialect::printType(mlir::Type type, raw_ostream& os) const
os
<<
"bool"
;
os
<<
"bool"
;
return
;
return
;
}
}
default
:
{
NGRAPH_
ASSERT
(
0
)
<<
"Incorrect type to print?"
;
default
:
{
NGRAPH_
CHECK
(
false
,
"Incorrect type to print?"
)
;
}
}
}
}
}
}
src/contrib/mlir/dialect/dialect.hpp
View file @
e4e3456d
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Types.h"
#include "ngraph/
assertion
.hpp"
#include "ngraph/
check
.hpp"
namespace
mlir
namespace
mlir
{
{
class
NGDialect
:
public
mlir
::
Dialect
class
NGDialect
:
public
mlir
::
Dialect
...
@@ -31,7 +31,7 @@ namespace mlir
...
@@ -31,7 +31,7 @@ namespace mlir
explicit
NGDialect
(
mlir
::
MLIRContext
*
ctx
);
explicit
NGDialect
(
mlir
::
MLIRContext
*
ctx
);
mlir
::
Type
parseType
(
llvm
::
StringRef
tyData
,
mlir
::
Location
loc
)
const
override
mlir
::
Type
parseType
(
llvm
::
StringRef
tyData
,
mlir
::
Location
loc
)
const
override
{
{
NGRAPH_
ASSERT
(
0
)
<<
"Unsupported type parsing."
;
NGRAPH_
CHECK
(
false
,
"Unsupported type parsing."
)
;
return
mlir
::
Type
();
return
mlir
::
Type
();
}
}
void
printType
(
mlir
::
Type
type
,
llvm
::
raw_ostream
&
os
)
const
override
;
void
printType
(
mlir
::
Type
type
,
llvm
::
raw_ostream
&
os
)
const
override
;
...
...
src/contrib/mlir/dialect/ops.td
View file @
e4e3456d
...
@@ -14,15 +14,15 @@
...
@@ -14,15 +14,15 @@
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
//
//
// This is the
N
Graph Dialect operation definition file.
// This is the
n
Graph Dialect operation definition file.
//
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
include "mlir/IR/OpBase.td"
include "mlir/IR/OpBase.td"
//
N
Graph Dialect operations definitions
//
n
Graph Dialect operations definitions
//
//
// This files declares
N
Graph operations that table-gen uses to create C++ code
// This files declares
n
Graph operations that table-gen uses to create C++ code
// For more information about tablegen. See https://llvm.org/docs/TableGen/index.html
// For more information about tablegen. See https://llvm.org/docs/TableGen/index.html
//
//
// The output files are ops.h.inc and ops.cpp.inc and are generated at build time
// The output files are ops.h.inc and ops.cpp.inc and are generated at build time
...
@@ -44,17 +44,17 @@ def NG_Dialect : Dialect {
...
@@ -44,17 +44,17 @@ def NG_Dialect : Dialect {
}
}
//
N
Graph Types
//
n
Graph Types
// This defines records equivalent to
N
Graph types. It doesn't generate code.
// This defines records equivalent to
n
Graph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs.
// This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification
// Constraints (CPred) are used to type-check args/results of that type during op verification
def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
"
N
Graph Tensor Type">;
"
n
Graph Tensor Type">;
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;
def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;
//
N
Graph operation base class.
//
n
Graph operation base class.
// Prepends "ng." to operation name
// Prepends "ng." to operation name
class NG_Op<string mnemonic, list<OpTrait> traits = []> :
class NG_Op<string mnemonic, list<OpTrait> traits = []> :
Op<NG_Dialect, mnemonic, traits> {}
Op<NG_Dialect, mnemonic, traits> {}
...
@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
...
@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$arg)>
Arguments<(ins NG_TensorType:$arg)>
{
{
// TODO: Implement
// TODO: Implement
let parser = [{ NGRAPH_
FAIL() << "No parser support"
; return mlir::failure(); }];
let parser = [{ NGRAPH_
CHECK(false, "No parser support")
; return mlir::failure(); }];
let verifier = [{ return verifyUnaryArithOp(this); }];
let verifier = [{ return verifyUnaryArithOp(this); }];
}
}
...
@@ -89,7 +89,7 @@ class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
...
@@ -89,7 +89,7 @@ class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
{
// TODO: Implement
// TODO: Implement
let parser = [{ NGRAPH_
FAIL() << "No parser support"
; return mlir::failure(); }];
let parser = [{ NGRAPH_
CHECK(false, "No parser support")
; return mlir::failure(); }];
}
}
// Base class for arithmetic binary operations with verifier.
// Base class for arithmetic binary operations with verifier.
...
@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
...
@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
{
// TODO: Implement
// TODO: Implement
let parser = [{ NGRAPH_
FAIL() << "No parser support"
; return mlir::failure(); }];
let parser = [{ NGRAPH_
CHECK(false, "No parser support")
; return mlir::failure(); }];
let verifier = [{ return verifyBinaryArithOp(this); }];
let verifier = [{ return verifyBinaryArithOp(this); }];
}
}
...
@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
...
@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
{
// TODO: Implement
// TODO: Implement
let parser = [{ NGRAPH_
FAIL() << "No parser support"
; return mlir::failure(); }];
let parser = [{ NGRAPH_
CHECK(false, "No parser support")
; return mlir::failure(); }];
let verifier = [{ return verifyCmpOp(this); }];
let verifier = [{ return verifyCmpOp(this); }];
}
}
...
@@ -120,7 +120,7 @@ class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
...
@@ -120,7 +120,7 @@ class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
{
{
// TODO: Implement
// TODO: Implement
let parser = [{ NGRAPH_
FAIL() << "No parser support"
; return mlir::failure(); }];
let parser = [{ NGRAPH_
CHECK(false, "No parser support")
; return mlir::failure(); }];
}
}
...
@@ -189,7 +189,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
...
@@ -189,7 +189,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
"across the axes of a single tensor.";
"across the axes of a single tensor.";
let description = "Axes are represented as an array of I64 attributes.";
let description = "Axes are represented as an array of I64 attributes.";
let parser = [{ NGRAPH_
FAIL() << "Parser not implemented"
; return mlir::failure(); }];
let parser = [{ NGRAPH_
CHECK(false, "No parser support")
; return mlir::failure(); }];
// TODO
// TODO
let verifier = [{ return verifyAxisReductionOp(this); }];
let verifier = [{ return verifyAxisReductionOp(this); }];
...
...
src/contrib/mlir/dialect/type.cpp
View file @
e4e3456d
...
@@ -45,7 +45,7 @@ unsigned NGIntegerType::getWidth() const
...
@@ -45,7 +45,7 @@ unsigned NGIntegerType::getWidth() const
case
NG_U32_TYPE_ID
:
return
32
;
case
NG_U32_TYPE_ID
:
return
32
;
case
NG_I64_TYPE_ID
:
case
NG_I64_TYPE_ID
:
case
NG_U64_TYPE_ID
:
return
64
;
case
NG_U64_TYPE_ID
:
return
64
;
default
:
NGRAPH_
FAIL
()
<<
"Invalid type ID"
;
default
:
NGRAPH_
CHECK
(
false
,
"Invalid type ID"
)
;
}
}
return
0
;
return
0
;
}
}
...
@@ -62,7 +62,7 @@ bool NGIntegerType::isSigned() const
...
@@ -62,7 +62,7 @@ bool NGIntegerType::isSigned() const
case
NG_U16_TYPE_ID
:
case
NG_U16_TYPE_ID
:
case
NG_U32_TYPE_ID
:
case
NG_U32_TYPE_ID
:
case
NG_U64_TYPE_ID
:
return
false
;
case
NG_U64_TYPE_ID
:
return
false
;
default
:
NGRAPH_
FAIL
()
<<
"Invalid type ID"
;
default
:
NGRAPH_
CHECK
(
false
,
"Invalid type ID"
)
;
}
}
return
false
;
return
false
;
}
}
...
@@ -97,8 +97,8 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
...
@@ -97,8 +97,8 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
for
(
auto
i
=
0
;
i
<
shape
.
size
();
i
++
)
for
(
auto
i
=
0
;
i
<
shape
.
size
();
i
++
)
{
{
NGRAPH_
ASSERT
(
shape
[
i
]
>=
-
1
)
<<
"Invalid tensor shape"
;
NGRAPH_
CHECK
(
shape
[
i
]
>=
-
1
,
"Invalid tensor shape"
,
shape
[
i
])
;
NGRAPH_
ASSERT
(
otherShape
[
i
]
>=
-
1
)
<<
"Invalid tensor shape"
;
NGRAPH_
CHECK
(
otherShape
[
i
]
>=
-
1
,
"Invalid tensor shape"
,
otherShape
[
i
])
;
if
(
shape
[
i
]
==
-
1
||
otherShape
[
i
]
==
-
1
||
shape
[
i
]
==
otherShape
[
i
])
if
(
shape
[
i
]
==
-
1
||
otherShape
[
i
]
==
-
1
||
shape
[
i
]
==
otherShape
[
i
])
continue
;
continue
;
...
...
src/contrib/mlir/dialect/type.hpp
View file @
e4e3456d
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
//*****************************************************************************
//*****************************************************************************
#pragma once
#pragma once
#include "assertion.hpp"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpDefinition.h"
...
@@ -23,6 +22,7 @@
...
@@ -23,6 +22,7 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Types.h"
#include "ngraph/check.hpp"
namespace
mlir
namespace
mlir
{
{
using
llvm
::
raw_ostream
;
using
llvm
::
raw_ostream
;
...
@@ -60,7 +60,7 @@ namespace mlir
...
@@ -60,7 +60,7 @@ namespace mlir
static
NGIntegerType
get
(
NGTypeKind
kind
,
mlir
::
MLIRContext
*
context
)
static
NGIntegerType
get
(
NGTypeKind
kind
,
mlir
::
MLIRContext
*
context
)
{
{
NGRAPH_
ASSERT
(
kindof
(
kind
))
<<
"Not an integer kind."
;
NGRAPH_
CHECK
(
kindof
(
kind
),
"Not an integer kind."
)
;
return
Base
::
get
(
context
,
kind
);
return
Base
::
get
(
context
,
kind
);
}
}
/// Create signed Int8
/// Create signed Int8
...
@@ -154,7 +154,7 @@ namespace mlir
...
@@ -154,7 +154,7 @@ namespace mlir
using
Base
::
Base
;
using
Base
::
Base
;
static
NGBoolType
get
(
NGTypeKind
kind
,
mlir
::
MLIRContext
*
context
)
static
NGBoolType
get
(
NGTypeKind
kind
,
mlir
::
MLIRContext
*
context
)
{
{
NGRAPH_
ASSERT
(
kindof
(
kind
))
<<
"Not a bool type."
;
NGRAPH_
CHECK
(
kindof
(
kind
),
"Not a bool type."
)
;
return
Base
::
get
(
context
,
kind
);
return
Base
::
get
(
context
,
kind
);
}
}
...
...
src/contrib/mlir/lowerer.cpp
View file @
e4e3456d
...
@@ -133,8 +133,8 @@ namespace
...
@@ -133,8 +133,8 @@ namespace
op
->
setAttr
(
"graphOutputIdx"
,
op
->
setAttr
(
"graphOutputIdx"
,
mlir
::
IntegerAttr
::
get
(
IntegerType
::
get
(
8
,
op
->
getContext
()),
i
));
mlir
::
IntegerAttr
::
get
(
IntegerType
::
get
(
8
,
op
->
getContext
()),
i
));
}
}
NGRAPH_
ASSERT
(
outputCount
==
0
||
outputCount
==
ret
.
getNumOperands
())
NGRAPH_
CHECK
(
outputCount
==
0
||
outputCount
==
ret
.
getNumOperands
(),
<<
"Inconsistent returns in function"
;
"Inconsistent returns in function"
)
;
outputCount
=
ret
.
getNumOperands
();
outputCount
=
ret
.
getNumOperands
();
});
});
// will be populated with lowered output values later
// will be populated with lowered output values later
...
@@ -232,7 +232,7 @@ namespace
...
@@ -232,7 +232,7 @@ namespace
for
(
auto
value
:
m_loweredOutputValues
)
for
(
auto
value
:
m_loweredOutputValues
)
{
{
auto
op
=
value
->
getDefiningOp
();
auto
op
=
value
->
getDefiningOp
();
NGRAPH_
ASSERT
(
isa
<
NGFakeInputOp
>
(
op
))
<<
"output value not defined by fake output?"
;
NGRAPH_
CHECK
(
isa
<
NGFakeInputOp
>
(
op
),
"output value not defined by fake output?"
)
;
value
->
replaceAllUsesWith
(
entryBlock
->
getArgument
(
oldFuncType
.
getNumInputs
()
+
i
));
value
->
replaceAllUsesWith
(
entryBlock
->
getArgument
(
oldFuncType
.
getNumInputs
()
+
i
));
op
->
erase
();
op
->
erase
();
i
++
;
i
++
;
...
@@ -289,7 +289,7 @@ namespace
...
@@ -289,7 +289,7 @@ namespace
return
mlir
::
IntegerType
::
get
(
1
/* width */
,
bool_type
.
getContext
());
return
mlir
::
IntegerType
::
get
(
1
/* width */
,
bool_type
.
getContext
());
}
}
NGRAPH_
FAIL
()
<<
"Unsupported type to lower"
;
NGRAPH_
CHECK
(
false
,
"Unsupported type to lower"
)
;
return
type
;
return
type
;
}
}
...
@@ -305,7 +305,7 @@ namespace
...
@@ -305,7 +305,7 @@ namespace
auto
loc
=
add
.
getLoc
();
auto
loc
=
add
.
getLoc
();
auto
result
=
m_pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
auto
result
=
m_pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
NGRAPH_
ASSERT
(
result
->
getType
().
isa
<
MemRefType
>
());
NGRAPH_
CHECK
(
result
->
getType
().
isa
<
MemRefType
>
());
// Note that builder's current function is still the original function body.
// Note that builder's current function is still the original function body.
// use getBlock to get the new block instead.
// use getBlock to get the new block instead.
...
@@ -346,18 +346,18 @@ namespace
...
@@ -346,18 +346,18 @@ namespace
Value
*
lhs
=
operands
[
0
];
Value
*
lhs
=
operands
[
0
];
Value
*
rhs
=
operands
[
1
];
Value
*
rhs
=
operands
[
1
];
Value
*
result
=
m_pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
Value
*
result
=
m_pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
NGRAPH_
ASSERT
(
lhs
&&
rhs
&&
result
)
<<
"Unexpected null values in DotOp"
;
NGRAPH_
CHECK
(
lhs
&&
rhs
&&
result
,
"Unexpected null values in DotOp"
)
;
auto
result_ty
=
result
->
getType
().
dyn_cast
<
MemRefType
>
();
auto
result_ty
=
result
->
getType
().
dyn_cast
<
MemRefType
>
();
auto
lhs_ty
=
lhs
->
getType
().
dyn_cast
<
MemRefType
>
();
auto
lhs_ty
=
lhs
->
getType
().
dyn_cast
<
MemRefType
>
();
auto
rhs_ty
=
rhs
->
getType
().
dyn_cast
<
MemRefType
>
();
auto
rhs_ty
=
rhs
->
getType
().
dyn_cast
<
MemRefType
>
();
NGRAPH_
ASSERT
(
result_ty
)
<<
"Unexpected non-memref result type"
;
NGRAPH_
CHECK
(
result_ty
,
"Unexpected non-memref result type"
)
;
NGRAPH_
ASSERT
(
lhs_ty
)
<<
"Unexpected non-memref LHS type"
;
NGRAPH_
CHECK
(
lhs_ty
,
"Unexpected non-memref LHS type"
)
;
NGRAPH_
ASSERT
(
rhs_ty
)
<<
"Unexpected non-memref RHS type"
;
NGRAPH_
CHECK
(
rhs_ty
,
"Unexpected non-memref RHS type"
)
;
Type
elem_ty
=
result_ty
.
getElementType
();
Type
elem_ty
=
result_ty
.
getElementType
();
NGRAPH_
ASSERT
(
elem_ty
==
lhs_ty
.
getElementType
()
&&
elem_ty
==
rhs_ty
.
getElementType
())
NGRAPH_
CHECK
(
elem_ty
==
lhs_ty
.
getElementType
()
&&
elem_ty
==
rhs_ty
.
getElementType
(),
<<
"Types mismatch in DotOp"
;
"Types mismatch in DotOp"
)
;
// Create the following loop nest for matmul operation:
// Create the following loop nest for matmul operation:
// for(n, N, 1)
// for(n, N, 1)
...
@@ -368,8 +368,8 @@ namespace
...
@@ -368,8 +368,8 @@ namespace
MemRefView
v_res
(
result
),
v_lhs
(
lhs
),
v_rhs
(
rhs
);
MemRefView
v_res
(
result
),
v_lhs
(
lhs
),
v_rhs
(
rhs
);
NGRAPH_
ASSERT
(
v_lhs
.
rank
()
==
2
&&
v_rhs
.
rank
()
==
2
&&
v_res
.
rank
()
==
2
)
NGRAPH_
CHECK
(
v_lhs
.
rank
()
==
2
&&
v_rhs
.
rank
()
==
2
&&
v_res
.
rank
()
==
2
,
<<
"Dot operation is only supported for 2D tensors"
;
"Dot operation is only supported for 2D tensors"
)
;
// Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// It's important to note that MemRefView priovides lb/ub/step info is "reverse order",
// It's important to note that MemRefView priovides lb/ub/step info is "reverse order",
...
...
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp
View file @
e4e3456d
...
@@ -65,7 +65,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
...
@@ -65,7 +65,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
for
(
size_t
i
=
0
,
end
=
ck_outputs
.
size
();
i
<
end
;
++
i
)
for
(
size_t
i
=
0
,
end
=
ck_outputs
.
size
();
i
<
end
;
++
i
)
{
{
auto
&
output_descs
=
ck_outputs
[
i
]
->
get_outputs
();
auto
&
output_descs
=
ck_outputs
[
i
]
->
get_outputs
();
NGRAPH_
ASSERT
(
output_descs
.
size
()
==
1
)
<<
"Unexpected multiple output descriptors"
;
NGRAPH_
CHECK
(
output_descs
.
size
()
==
1
,
"Unexpected multiple output descriptors"
)
;
auto
&
out_desc
=
output_descs
[
0
];
auto
&
out_desc
=
output_descs
[
0
];
// 'replace_output' invalidates iterator of the original container. Use a copy instead.
// 'replace_output' invalidates iterator of the original container. Use a copy instead.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment