Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8bb48c81
Commit
8bb48c81
authored
May 30, 2019
by
Diego Caballero
Committed by
nmostafa
Jun 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MLIR] Fix NG tensor type lowering (#29)
Element type was not lowered.
parent
4df55e63
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
28 deletions
+21
-28
type.cpp
src/contrib/mlir/dialect/type.cpp
+0
-6
type.hpp
src/contrib/mlir/dialect/type.hpp
+0
-7
lowerer.cpp
src/contrib/mlir/lowerer.cpp
+21
-15
No files found.
src/contrib/mlir/dialect/type.cpp
View file @
8bb48c81
...
...
@@ -73,9 +73,3 @@ NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shap
{
return
Base
::
get
(
context
,
NGTypeKind
::
NG_TENSOR_TYPE_ID
,
eltType
,
shape
);
}
MemRefType
NGTensorType
::
toMemref
()
{
auto
memRefType
=
MemRefType
::
get
(
getShape
(),
getElementType
(),
{
/* no map used */
},
0
);
return
memRefType
;
}
src/contrib/mlir/dialect/type.hpp
View file @
8bb48c81
...
...
@@ -114,9 +114,6 @@ namespace mlir
/// Return the bitwidth of this integer type.
unsigned
getWidth
()
const
;
/// Convert to equivalent std type
/// std types are sign-agnostic.
mlir
::
Type
toStdType
()
{
return
mlir
::
IntegerType
::
get
(
getWidth
(),
getContext
());
}
/// Check if signed type
bool
isSigned
()
const
;
...
...
@@ -163,8 +160,6 @@ namespace mlir
static
bool
kindof
(
unsigned
kind
)
{
return
kind
==
NGTypeKind
::
NG_BOOL_TYPE_ID
;
}
static
NGBoolType
get
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NG_BOOL_TYPE_ID
,
ctx
);
}
/// Convert to equivalent std type. Integer of width 1 in that case
mlir
::
Type
toStdType
()
{
return
mlir
::
IntegerType
::
get
(
1
,
getContext
());
}
};
// Note that dialect types don't add new data members, so always possible
...
...
@@ -240,8 +235,6 @@ namespace mlir
// Multiply times element size
return
s
*
llvm
::
divideCeil
(
getElementType
().
getIntOrFloatBitWidth
(),
8
);
}
/// convert to memref native MLIR type. Used for lowering.
mlir
::
MemRefType
toMemref
();
/// create a unique tensor type based on element type and shape.
static
NGTensorType
get
(
mlir
::
MLIRContext
*
context
,
EltType
eltType
,
Shape
shape
);
/// for llvm RTTI
...
...
src/contrib/mlir/lowerer.cpp
View file @
8bb48c81
...
...
@@ -185,7 +185,7 @@ namespace
auto
tensorType
=
origResult
->
getType
().
cast
<
NGTensorType
>
();
auto
callBackFunc
=
getCallDecl
(
"__mlir_allocate"
,
{
rewriter
.
getIndexType
(),
rewriter
.
getIndexType
()},
{
tensorType
.
toMemref
(
)},
{
m_dialectLowerer
.
convertType
(
tensorType
)},
rewriter
);
auto
size
=
tensorType
.
getSizeInBytes
();
...
...
@@ -265,30 +265,36 @@ namespace
return
callBackFuncPtr
;
}
// NGDialect converters
Type
DialectLowerer
::
convertType
(
Type
t
)
Type
DialectLowerer
::
convertType
(
Type
t
ype
)
{
if
(
auto
tensor
=
t
.
dyn_cast
<
NGTensorType
>
())
// We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since DialectLowerer is private.
if
(
auto
tensor_type
=
type
.
dyn_cast
<
NGTensorType
>
())
{
return
tensor
.
toMemref
();
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
return
MemRefType
::
get
(
tensor_type
.
getShape
(),
convertType
(
tensor_type
.
getElementType
()),
{
/* no map used */
},
0
);
}
// element type
if
(
auto
type
=
t
.
dyn_cast
<
NGFloatType
>
())
if
(
auto
float_type
=
type
.
dyn_cast
<
NGFloatType
>
())
{
// Float
// float types are already std type
return
type
;
// Float types are already std type.
return
float_type
;
}
if
(
auto
type
=
t
.
dyn_cast
<
NGIntegerType
>
())
if
(
auto
int_type
=
type
.
dyn_cast
<
NGIntegerType
>
())
{
// map it to std type
return
type
.
toStdType
();
return
mlir
::
IntegerType
::
get
(
int_type
.
getWidth
(),
int_type
.
getContext
());
}
if
(
auto
type
=
t
.
dyn_cast
<
NGBoolType
>
())
if
(
auto
bool_type
=
type
.
dyn_cast
<
NGBoolType
>
())
{
return
type
.
toStdType
(
);
return
mlir
::
IntegerType
::
get
(
1
/* width */
,
bool_type
.
getContext
()
);
}
NGRAPH_FAIL
()
<<
"Unsupported type to lower"
;
return
t
;
return
t
ype
;
}
#define REWRITER(OP) \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment