Commit e4e3456d, authored Jun 13, 2019 by nmostafa
Replace NGRAPH_ASSERT/NGRAPH_FAIL with NGRAPH_CHECK. Minor fix in ops.td.
parent 4d24d157
Showing 8 changed files with 62 additions and 61 deletions.
src/contrib/mlir/compiler.cpp                        +24  -24
src/contrib/mlir/dialect/dialect.cpp                  +2   -1
src/contrib/mlir/dialect/dialect.hpp                  +2   -2
src/contrib/mlir/dialect/ops.td                      +13  -13
src/contrib/mlir/dialect/type.cpp                     +4   -4
src/contrib/mlir/dialect/type.hpp                     +3   -3
src/contrib/mlir/lowerer.cpp                         +13  -13
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp    +1   -1
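The change is mechanical across all eight files: every streaming-style assertion (NGRAPH_ASSERT(cond) << "message"; or NGRAPH_FAIL() << "message";) becomes a call-style NGRAPH_CHECK(cond, "message");, with NGRAPH_FAIL() turning into NGRAPH_CHECK(false, ...). The real macro comes from ngraph/check.hpp, which this commit adds to the affected headers; the standalone C++17 sketch below is only a hypothetical stand-in (check_impl and the local NGRAPH_CHECK are invented here) to make the semantics assumed by the new call sites concrete.

    // Minimal stand-in for NGRAPH_CHECK -- an assumption for illustration,
    // not the actual ngraph/check.hpp implementation.
    #include <iostream>
    #include <sstream>
    #include <stdexcept>

    template <typename... Args>
    void check_impl(const char* file, int line, bool cond, Args&&... args)
    {
        if (cond)
            return;
        std::ostringstream os;
        os << "Check failed at " << file << ":" << line << ": ";
        (os << ... << args); // C++17 fold: stream every trailing message fragment
        throw std::runtime_error(os.str());
    }

    #define NGRAPH_CHECK(...) check_impl(__FILE__, __LINE__, __VA_ARGS__)

    int main()
    {
        int num_args = 2, num_outputs = 1, num_tensors = 3;
        // New style, as in the compiler.cpp hunks below:
        NGRAPH_CHECK(num_args + num_outputs == num_tensors,
                     "Number of arguments and outputs doesn't match number of tensors");
        try
        {
            NGRAPH_CHECK(false, "Unreachable"); // the NGRAPH_FAIL() replacement pattern
        }
        catch (const std::runtime_error& e)
        {
            std::cout << e.what() << "\n"; // condition failed: report and continue
        }
    }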
src/contrib/mlir/compiler.cpp

@@ -60,9 +60,9 @@ MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
     : m_compiled_kernel(compiled_kernel)
     , m_external_tensors(external_tensors)
 {
-    NGRAPH_ASSERT((m_compiled_kernel->get_arguments().size() +
-                   m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size())
-        << "Number of arguments and outputs doesn't match number of tensors";
+    NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
+                  m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
+                 "Number of arguments and outputs doesn't match number of tensors");
 }

 void MLIRCompiler::init_mlir()

@@ -103,8 +103,8 @@ void MLIRCompiler::build_ng_dialect_module()
     // Retrieve input and output tensors.
     const auto& kernel_inputs = m_compiled_kernel->get_arguments();
     const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
-    NGRAPH_ASSERT(kernel_inputs.size() != 0) << "Cannot have empty inputs list";
-    NGRAPH_ASSERT(kernel_outputs.size() != 0) << "Cannot have empty outputs list";
+    NGRAPH_CHECK(kernel_inputs.size() != 0, "Cannot have empty inputs list");
+    NGRAPH_CHECK(kernel_outputs.size() != 0, "Cannot have empty outputs list");

     for (auto input : kernel_inputs)
     {

@@ -138,7 +138,7 @@ void MLIRCompiler::build_ng_dialect_module()
     m_module->getFunctions().push_back(function.release());
     if (failed(m_module->verify()))
     {
-        NGRAPH_FAIL() << "Invalid module after lowering to NG dialect";
+        NGRAPH_CHECK(false, "Invalid module after lowering to NG dialect");
     }

     dump_mlir_module("nGraph Dialect Dump:");

@@ -170,7 +170,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
     {
     case ngraph::element::Type_t::undefined:
     case ngraph::element::Type_t::dynamic:
-    default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break;
+    default: NGRAPH_CHECK(false, "MLIR: Unsupported NGraph types"); break;

     case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
     case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(&m_context);
     case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);

@@ -185,7 +185,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
     case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
     case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
     }
-    NGRAPH_FAIL() << "Unreachable";
+    NGRAPH_CHECK(false, "Unreachable");
     return mlir::Type();

 #if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))

@@ -195,8 +195,8 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
 void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
 {
-    NGRAPH_ASSERT(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end())
-        << "tensor value already defined";
+    NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
+                 "tensor value already defined");
     TensorInfo tensor_info{value};
     m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
 }

@@ -205,7 +205,7 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
 {
     auto it = m_tensor_to_value_map.find(tensor);
-    NGRAPH_ASSERT(it != m_tensor_to_value_map.end()) << "Undefined tensor";
+    NGRAPH_CHECK(it != m_tensor_to_value_map.end(), "Undefined tensor");
     return it->second;
 }

@@ -221,7 +221,7 @@ void MLIRCompiler::lower_ng_dialect()
     if (failed(m_module->verify()))
     {
-        NGRAPH_FAIL() << "Incorrect module after dialect lowering";
+        NGRAPH_CHECK(false, "Incorrect module after dialect lowering");
     }

     dump_mlir_module("Affine Dialect Dump:");

@@ -236,7 +236,7 @@ void MLIRCompiler::optimize()
     // Lower affine ops
     pm.addPass(mlir::createLowerAffinePass());
     auto rr = pm.run(m_module.get());
-    NGRAPH_ASSERT(succeeded(rr)) << "Affine loop lowering failed";
+    NGRAPH_CHECK(succeeded(rr), "Affine loop lowering failed");

     dump_mlir_module("Standard Dialect Dump:");
 }

@@ -309,10 +309,10 @@ void MLIRCompiler::create_return()
 // helpers to be used inside the function.
 void MLIRCompiler::bind_arguments()
 {
-    NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
+    NGRAPH_CHECK(m_module, "MLIR module is not ready.");

     mlir::Function* func = m_module->getNamedFunction("main");
-    NGRAPH_ASSERT(func && !func->getBlocks().empty()) << "Function not found";
+    NGRAPH_CHECK(func && !func->getBlocks().empty(), "Function not found");

     // Create list with a type-erased double pointer for each invocation arguments.
     // We currently use 'allocateMemRefArguments', which creates a

@@ -321,11 +321,11 @@ void MLIRCompiler::bind_arguments()
     // create MemRef args
     auto expected_arguments = allocate_memref_args(func);
-    NGRAPH_ASSERT(expected_arguments.size()) << "Arguments can't be created";
+    NGRAPH_CHECK(expected_arguments.size(), "Arguments can't be created");
     m_invoke_args = std::move(expected_arguments);
-    NGRAPH_ASSERT(m_invoke_args.size() == m_external_tensors.size())
-        << "Number of external tensors doesn't match number of function arguments";
+    NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors.size(),
+                 "Number of external tensors doesn't match number of function arguments");

     // Assign external tensor pointers to invocation arguments.
     for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)

@@ -339,20 +339,20 @@
     MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
     *mem_mgr_arg = &get_mem_mgr();
     // inserting memory manager ptr in right location ?
-    NGRAPH_ASSERT(m_invoke_args.size() == get_mem_mgr_arg_id(func));
+    NGRAPH_CHECK(m_invoke_args.size() == get_mem_mgr_arg_id(func));
     m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
 }

 // Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
 void MLIRCompiler::execute()
 {
-    NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
+    NGRAPH_CHECK(m_module, "MLIR module is not ready.");

     // Lower Standard dialect to LLVM dialect.
     auto converter = mlir::createStdToLLVMConverter();
     auto r = converter->convert(m_module.get());
     (void)r;
-    NGRAPH_ASSERT(succeeded(r)) << "second conversion failed";
+    NGRAPH_CHECK(succeeded(r), "second conversion failed");

     dump_mlir_module("LLVM-IR Dialect Dump:");

@@ -365,7 +365,7 @@ void MLIRCompiler::execute()
     // LLVM optimizations at level 3.
     auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
     auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
-    NGRAPH_ASSERT(maybeEngine) << "failed to construct an execution engine";
+    NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
     m_engine = std::move(maybeEngine.get());

     // Invoke the JIT-compiled function with the arguments. Note that, for API

@@ -373,7 +373,7 @@ void MLIRCompiler::execute()
     // Please, note that 'invoke' method is overloaded with a parameter pack version.
     // Make sure the MutableArrayRef version is invoked.
     auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
-    NGRAPH_ASSERT(!invocationResult) << "JIT invocation of 'main' failed\n";
+    NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
 }

 void MLIRCompiler::cleanup()

@@ -418,7 +418,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type typ
     {
         return nullptr;
     }
-    NGRAPH_ASSERT(memRefType.getNumDynamicDims() == 0) << "No support for dynamic shapes";
+    NGRAPH_CHECK(memRefType.getNumDynamicDims() == 0, "No support for dynamic shapes");

     // We only use StaticFloatMemRef because that's what MLIR currently offers.
     // We should expand this with different types and dynamic MemRefs
src/contrib/mlir/dialect/dialect.cpp

@@ -15,6 +15,7 @@
 //*****************************************************************************

 #include "dialect.hpp"
+#include "ngraph/check.hpp"
 #include "ops.hpp"
 #include "type.hpp"

@@ -66,7 +67,7 @@ void NGDialect::printType(mlir::Type type, raw_ostream& os) const
         os << "bool";
         return;
     }
-    default: { NGRAPH_ASSERT(0) << "Incorrect type to print?";
+    default: { NGRAPH_CHECK(false, "Incorrect type to print?");
     }
     }
 }
src/contrib/mlir/dialect/dialect.hpp

@@ -22,7 +22,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
-#include "ngraph/assertion.hpp"
+#include "ngraph/check.hpp"

 namespace mlir
 {
     class NGDialect : public mlir::Dialect

@@ -31,7 +31,7 @@ namespace mlir
         explicit NGDialect(mlir::MLIRContext* ctx);
         mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
         {
-            NGRAPH_ASSERT(0) << "Unsupported type parsing.";
+            NGRAPH_CHECK(false, "Unsupported type parsing.");
             return mlir::Type();
         }
         void printType(mlir::Type type, llvm::raw_ostream& os) const override;
src/contrib/mlir/dialect/ops.td

@@ -14,15 +14,15 @@
 // limitations under the License.
 //*****************************************************************************

 //
-// This is the NGraph Dialect operation definition file.
+// This is the nGraph Dialect operation definition file.
 //
 //===----------------------------------------------------------------------===//

 include "mlir/IR/OpBase.td"

-// NGraph Dialect operations definitions
+// nGraph Dialect operations definitions
 //
-// This files declares NGraph operations that table-gen uses to create C++ code
+// This files declares nGraph operations that table-gen uses to create C++ code
 // For more information about tablegen. See https://llvm.org/docs/TableGen/index.html
 //
 // The output files are ops.h.inc and ops.cpp.inc and are generated at build time

@@ -44,17 +44,17 @@ def NG_Dialect : Dialect {
 }

-// NGraph Types
-// This defines records equivalent to NGraph types. It doesn't generate code.
+// nGraph Types
+// This defines records equivalent to nGraph types. It doesn't generate code.
 // This is used as a type in the DAG input/outputs.
 // Constraints (CPred) are used to type-check args/results of that type during op verification
 def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
-                         "NGraph Tensor Type">;
+                         "nGraph Tensor Type">;

 // A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
 def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;

-// NGraph operation base class.
+// nGraph operation base class.
 // Prepends "ng." to operation name
 class NG_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<NG_Dialect, mnemonic, traits> {}

@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
     Arguments<(ins NG_TensorType:$arg)>
 {
     // TODO: Implement
-    let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

     let verifier = [{ return verifyUnaryArithOp(this); }];
 }

@@ -89,7 +89,7 @@ class NG_Binary_Op<string mnemonic, list<OpTrait> traits = []> :
     Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
 {
     // TODO: Implement
-    let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
 }

 // Base class for arithmetic binary operations with verifier.

@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
     Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
 {
     // TODO: Implement
-    let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

     let verifier = [{ return verifyBinaryArithOp(this); }];
 }

@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
     Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
 {
     // TODO: Implement
-    let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

     let verifier = [{ return verifyCmpOp(this); }];
 }

@@ -120,7 +120,7 @@ class NG_Ternary_Op<string mnemonic, list<OpTrait> traits = []> :
     Arguments<(ins NG_TensorType:$op0, NG_TensorType:$op1, NG_TensorType:$op2)>
 {
     // TODO: Implement
-    let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
 }

@@ -189,7 +189,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
         "across the axes of a single tensor.";
     let description = "Axes are represented as an array of I64 attributes.";

-    let parser = [{ NGRAPH_FAIL() << "Parser not implemented"; return mlir::failure(); }];
+    let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];

     // TODO
     let verifier = [{ return verifyAxisReductionOp(this); }];
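A note on the parser hunks above: the text inside [{ ... }] is verbatim C++ that TableGen splices into the generated op classes (ops.h.inc / ops.cpp.inc), which is why this macro migration has to reach into the .td file as well. The self-contained sketch below is a hand-written analogue of what each stub amounts to; ParseResult, failure(), and parseHook() are invented stand-ins for the MLIR types and the generated hook, not actual TableGen output.

    #include <stdexcept>

    struct ParseResult { bool isFailure; };  // stand-in for mlir::ParseResult
    ParseResult failure() { return {true}; } // stand-in for mlir::failure()

    #define NGRAPH_CHECK(cond, msg) \
        do { if (!(cond)) throw std::runtime_error(msg); } while (0)

    // TableGen emits one parser hook per op; these ops all get this fail-fast body.
    ParseResult parseHook()
    {
        NGRAPH_CHECK(false, "No parser support"); // textual parsing is still a TODO
        return failure();                         // unreachable; keeps the hook well-typed
    }

    int main()
    {
        try { parseHook(); }
        catch (const std::runtime_error&) { } // expected: the stub always throws
    }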
src/contrib/mlir/dialect/type.cpp

@@ -45,7 +45,7 @@ unsigned NGIntegerType::getWidth() const
     case NG_U32_TYPE_ID: return 32;
     case NG_I64_TYPE_ID:
     case NG_U64_TYPE_ID: return 64;
-    default: NGRAPH_FAIL() << "Invalid type ID";
+    default: NGRAPH_CHECK(false, "Invalid type ID");
     }
     return 0;
 }

@@ -62,7 +62,7 @@ bool NGIntegerType::isSigned() const
     case NG_U16_TYPE_ID:
     case NG_U32_TYPE_ID:
     case NG_U64_TYPE_ID: return false;
-    default: NGRAPH_FAIL() << "Invalid type ID";
+    default: NGRAPH_CHECK(false, "Invalid type ID");
     }
     return false;
 }

@@ -97,8 +97,8 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
     for (auto i = 0; i < shape.size(); i++)
     {
-        NGRAPH_ASSERT(shape[i] >= -1) << "Invalid tensor shape";
-        NGRAPH_ASSERT(otherShape[i] >= -1) << "Invalid tensor shape";
+        NGRAPH_CHECK(shape[i] >= -1, "Invalid tensor shape", shape[i]);
+        NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]);
         if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
             continue;
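One detail in the isCompatibleShape hunk above: unlike the old streaming form, NGRAPH_CHECK accepts extra arguments after the message string,

    NGRAPH_CHECK(shape[i] >= -1, "Invalid tensor shape", shape[i]);

so, on the assumption that it behaves like the stand-in macro sketched under the file list (every trailing argument streamed into the failure text), the diagnostic now carries the offending dimension value, which the old NGRAPH_ASSERT(...) << "Invalid tensor shape" message dropped.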
src/contrib/mlir/dialect/type.hpp

@@ -15,7 +15,6 @@
 //*****************************************************************************

 #pragma once

-#include "assertion.hpp"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"

@@ -23,6 +22,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeSupport.h"
 #include "mlir/IR/Types.h"
+#include "ngraph/check.hpp"

 namespace mlir
 {
     using llvm::raw_ostream;

@@ -60,7 +60,7 @@ namespace mlir
         static NGIntegerType get(NGTypeKind kind, mlir::MLIRContext* context)
         {
-            NGRAPH_ASSERT(kindof(kind)) << "Not an integer kind.";
+            NGRAPH_CHECK(kindof(kind), "Not an integer kind.");
             return Base::get(context, kind);
         }

         /// Create signed Int8

@@ -154,7 +154,7 @@ namespace mlir
         using Base::Base;
         static NGBoolType get(NGTypeKind kind, mlir::MLIRContext* context)
         {
-            NGRAPH_ASSERT(kindof(kind)) << "Not a bool type.";
+            NGRAPH_CHECK(kindof(kind), "Not a bool type.");
             return Base::get(context, kind);
         }
src/contrib/mlir/lowerer.cpp

@@ -133,8 +133,8 @@ namespace
                 op->setAttr("graphOutputIdx",
                             mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i));
             }
-            NGRAPH_ASSERT(outputCount == 0 || outputCount == ret.getNumOperands())
-                << "Inconsistent returns in function";
+            NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(),
+                         "Inconsistent returns in function");
             outputCount = ret.getNumOperands();
         });
         // will be populated with lowered output values later

@@ -232,7 +232,7 @@ namespace
         for (auto value : m_loweredOutputValues)
         {
             auto op = value->getDefiningOp();
-            NGRAPH_ASSERT(isa<NGFakeInputOp>(op)) << "output value not defined by fake output?";
+            NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
             value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
             op->erase();
             i++;

@@ -289,7 +289,7 @@ namespace
             return mlir::IntegerType::get(1 /* width */, bool_type.getContext());
         }

-        NGRAPH_FAIL() << "Unsupported type to lower";
+        NGRAPH_CHECK(false, "Unsupported type to lower");
         return type;
     }

@@ -305,7 +305,7 @@ namespace
         auto loc = add.getLoc();

         auto result = m_pass.buildOutputDefs(op, rewriter)[0];
-        NGRAPH_ASSERT(result->getType().isa<MemRefType>());
+        NGRAPH_CHECK(result->getType().isa<MemRefType>());
         // Note that builder's current function is still the original function body.
         // use getBlock to get the new block instead.

@@ -346,18 +346,18 @@ namespace
         Value* lhs = operands[0];
         Value* rhs = operands[1];
         Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
-        NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in DotOp";
+        NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");

         auto result_ty = result->getType().dyn_cast<MemRefType>();
         auto lhs_ty = lhs->getType().dyn_cast<MemRefType>();
         auto rhs_ty = rhs->getType().dyn_cast<MemRefType>();
-        NGRAPH_ASSERT(result_ty) << "Unexpected non-memref result type";
-        NGRAPH_ASSERT(lhs_ty) << "Unexpected non-memref LHS type";
-        NGRAPH_ASSERT(rhs_ty) << "Unexpected non-memref RHS type";
+        NGRAPH_CHECK(result_ty, "Unexpected non-memref result type");
+        NGRAPH_CHECK(lhs_ty, "Unexpected non-memref LHS type");
+        NGRAPH_CHECK(rhs_ty, "Unexpected non-memref RHS type");

         Type elem_ty = result_ty.getElementType();
-        NGRAPH_ASSERT(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType())
-            << "Types mismatch in DotOp";
+        NGRAPH_CHECK(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType(),
+                     "Types mismatch in DotOp");

         // Create the following loop nest for matmul operation:
         //   for(n, N, 1)

@@ -368,8 +368,8 @@ namespace
         MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs);

-        NGRAPH_ASSERT(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2)
-            << "Dot operation is only supported for 2D tensors";
+        NGRAPH_CHECK(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2,
+                     "Dot operation is only supported for 2D tensors");

         // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
         // It's important to note that MemRefView priovides lb/ub/step info is "reverse order",
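For the DotOp hunks above: the checks guard a lowering that emits the loop nest named in the "// Create the following loop nest for matmul operation" comment, restricted to rank-2 memrefs. The plain-C++ sketch below is only a reading aid for that nest (dot_2d and its flat-buffer layout are invented for the example; the real code builds affine loops via MemRefView):

    #include <cstddef>
    #include <vector>

    // res[n][k] = sum over m of lhs[n][m] * rhs[m][k] -- the 2D dot (matmul)
    // nest the lowerer builds; N, M, K mirror the v_lhs/v_rhs/v_res bounds.
    void dot_2d(const std::vector<float>& lhs, const std::vector<float>& rhs,
                std::vector<float>& res, std::size_t N, std::size_t M, std::size_t K)
    {
        for (std::size_t n = 0; n < N; ++n)         // for(n, N, 1)
            for (std::size_t k = 0; k < K; ++k)     // for(k, K, 1)
            {
                float acc = 0.0f;
                for (std::size_t m = 0; m < M; ++m) // for(m, M, 1)
                    acc += lhs[n * M + m] * rhs[m * K + k];
                res[n * K + k] = acc;
            }
    }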
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp

@@ -65,7 +65,7 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
     for (size_t i = 0, end = ck_outputs.size(); i < end; ++i)
     {
         auto& output_descs = ck_outputs[i]->get_outputs();
-        NGRAPH_ASSERT(output_descs.size() == 1) << "Unexpected multiple output descriptors";
+        NGRAPH_CHECK(output_descs.size() == 1, "Unexpected multiple output descriptors");
         auto& out_desc = output_descs[0];
         // 'replace_output' invalidates iterator of the original container. Use a copy instead.
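The context comment in this last hunk ('replace_output' invalidates iterators of the original container) points at a general C++ pitfall: mutating a container while ranging over it. A generic, self-contained illustration of the copy-then-iterate workaround the pass uses (the names here are invented for the example, not ngraph code):

    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> outputs = {10, 20, 30};

        // Mutating `outputs` inside a range-for over `outputs` itself would
        // invalidate the loop's hidden iterators. Iterate a snapshot instead.
        std::vector<int> snapshot = outputs; // copy; later mutation can't touch it
        for (int v : snapshot)
        {
            if (v == 20)
                outputs.erase(outputs.begin() + 1); // safe: loop runs on the copy
        }
        std::cout << outputs.size() << "\n"; // prints 2
    }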