Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
9bb2fad3
Commit
9bb2fad3
authored
May 20, 2019
by
Nagy Mostafa
Committed by
nmostafa
Jun 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MLIR] Add NG integer type. Map float types to std types
parent
3bd00e23
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
282 additions
and
58 deletions
+282
-58
compiler.cpp
src/contrib/mlir/compiler.cpp
+26
-16
dialect.cpp
src/contrib/mlir/dialect/dialect.cpp
+47
-25
ops.cpp
src/contrib/mlir/dialect/ops.cpp
+10
-10
type.cpp
src/contrib/mlir/dialect/type.cpp
+36
-1
type.hpp
src/contrib/mlir/dialect/type.hpp
+143
-3
lowerer.cpp
src/contrib/mlir/lowerer.cpp
+19
-2
memory_manager.cpp
src/contrib/mlir/memory_manager.cpp
+1
-1
No files found.
src/contrib/mlir/compiler.cpp
View file @
9bb2fad3
...
...
@@ -148,21 +148,32 @@ namespace ngraph
{
case
ngraph
:
:
element
::
Type_t
::
undefined
:
case
ngraph
:
:
element
::
Type_t
::
dynamic
:
case
ngraph
:
:
element
::
Type_t
::
boolean
:
case
ngraph
:
:
element
::
Type_t
::
bf16
:
default
:
NGRAPH_ASSERT
(
false
)
<<
"MLIR: Unsupported NGraph types"
;
break
;
case
ngraph
:
:
element
::
Type_t
::
f32
:
return
mlir
::
FloatType
::
getF32
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f64
:
return
mlir
::
FloatType
::
getF64
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i8
:
case
ngraph
:
:
element
::
Type_t
::
u8
:
return
mlir
::
IntegerType
::
get
(
8
,
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i16
:
case
ngraph
:
:
element
::
Type_t
::
u16
:
return
mlir
::
IntegerType
::
get
(
16
,
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i32
:
case
ngraph
:
:
element
::
Type_t
::
u32
:
return
mlir
::
IntegerType
::
get
(
32
,
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i64
:
case
ngraph
:
:
element
::
Type_t
::
u64
:
return
mlir
::
IntegerType
::
get
(
64
,
&
m_context
);
default
:
NGRAPH_FAIL
()
<<
"MLIR: Unsupported NGraph types"
;
break
;
case
ngraph
:
:
element
::
Type_t
::
bf16
:
return
NGFloatType
::
getBF16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f32
:
return
NGFloatType
::
getF32
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
f64
:
return
NGFloatType
::
getF64
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i8
:
return
NGIntegerType
::
getInt8
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
u8
:
case
ngraph
:
:
element
::
Type_t
::
boolean
:
return
NGIntegerType
::
getUInt8
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i16
:
return
NGIntegerType
::
getInt16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
u16
:
return
NGIntegerType
::
getInt16
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i32
:
return
NGIntegerType
::
getInt32
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
u32
:
return
NGIntegerType
::
getUInt32
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
i64
:
return
NGIntegerType
::
getInt64
(
&
m_context
);
case
ngraph
:
:
element
::
Type_t
::
u64
:
return
NGIntegerType
::
getUInt64
(
&
m_context
);
}
NGRAPH_
ASSERT
(
false
)
<<
"Unreachable"
;
NGRAPH_
FAIL
();
// Unreachable
return
mlir
::
Type
();
}
...
...
@@ -378,8 +389,7 @@ namespace ngraph
auto
memRefType
=
type
.
dyn_cast
<
mlir
::
MemRefType
>
();
if
(
!
memRefType
)
return
nullptr
;
if
(
memRefType
.
getNumDynamicDims
()
!=
0
)
NGRAPH_FAIL
();
NGRAPH_ASSERT
(
memRefType
.
getNumDynamicDims
()
==
0
)
<<
"No support for dynamic shapes"
;
// We only use StaticFloatMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs
...
...
src/contrib/mlir/dialect/dialect.cpp
View file @
9bb2fad3
...
...
@@ -17,36 +17,58 @@
#include "dialect.hpp"
#include "ops.hpp"
#include "type.hpp"
namespace
ngraph
using
namespace
ngraph
::
runtime
::
ngmlir
;
/// Register a dialect and its types
/// Usage:
/// mlir::registerDialect<ngraph::runtime::ngmlir::Dialect>();
NGDialect
::
NGDialect
(
mlir
::
MLIRContext
*
ctx
)
:
mlir
::
Dialect
(
"ng"
,
ctx
)
{
using
namespace
runtime
::
ngmlir
;
addTypes
<
NGTensorType
>
();
addTypes
<
NGIntegerType
>
();
addTypes
<
NGBoolType
>
();
addOperations
<
NG_AddOp
>
();
addOperations
<
NG_MatmulBiasOp
>
();
addOperations
<
NG_ReturnOp
>
();
addOperations
<
NG_FakeInput
>
();
}
/// Register a dialect and its types
/// Usage:
/// mlir::registerDialect<ngraph::runtime::ngmlir::Dialect>();
NGDialect
::
NGDialect
(
mlir
::
MLIRContext
*
ctx
)
:
mlir
::
Dialect
(
"ng"
,
ctx
)
void
NGDialect
::
printType
(
mlir
::
Type
type
,
raw_ostream
&
os
)
const
{
switch
(
type
.
getKind
())
{
addTypes
<
NGTensorType
>
();
addOperations
<
NG_AddOp
>
();
addOperations
<
NG_MatmulBiasOp
>
();
addOperations
<
NG_ReturnOp
>
();
addOperations
<
NG_FakeInput
>
();
}
void
NGDialect
::
printType
(
mlir
::
Type
type
,
raw_ostream
&
os
)
const
case
NG_TENSOR_TYPE_ID
:
{
auto
arrayTy
=
type
.
dyn_cast
<
NGTensorType
>
();
if
(
!
arrayTy
)
{
NGRAPH_ASSERT
(
0
)
<<
"Incorrect type to print?"
;
}
os
<<
"tensor"
;
if
(
!
arrayTy
.
getShape
().
empty
())
os
<<
"tensor<"
;
auto
tensor_ty
=
type
.
cast
<
NGTensorType
>
();
for
(
auto
dim
:
tensor_ty
.
getShape
())
{
os
<<
"<"
;
mlir
::
interleaveComma
(
arrayTy
.
getShape
(),
os
);
os
<<
">"
;
os
<<
dim
<<
'x'
;
}
os
<<
tensor_ty
.
getElementType
()
<<
'>'
;
return
;
}
case
NG_I8_TYPE_ID
:
case
NG_I16_TYPE_ID
:
case
NG_I32_TYPE_ID
:
case
NG_I64_TYPE_ID
:
case
NG_U8_TYPE_ID
:
case
NG_U16_TYPE_ID
:
case
NG_U32_TYPE_ID
:
case
NG_U64_TYPE_ID
:
{
auto
int_ty
=
type
.
cast
<
NGIntegerType
>
();
os
<<
"i"
<<
int_ty
.
getWidth
();
return
;
}
case
NG_BOOL_TYPE_ID
:
{
os
<<
"bool"
;
return
;
}
default
:
{
NGRAPH_ASSERT
(
0
)
<<
"Incorrect type to print?"
;
}
}
}
src/contrib/mlir/dialect/ops.cpp
View file @
9bb2fad3
...
...
@@ -70,8 +70,8 @@ namespace ngraph
}
void
runtime
::
ngmlir
::
NG_FakeInput
::
build
(
mlir
::
Builder
*
builder
,
mlir
::
OperationState
*
state
,
mlir
::
Type
resultType
)
mlir
::
OperationState
*
state
,
mlir
::
Type
resultType
)
{
state
->
types
.
push_back
(
std
::
move
(
resultType
));
}
...
...
@@ -83,9 +83,9 @@ namespace ngraph
}
void
runtime
::
ngmlir
::
NG_AddOp
::
build
(
mlir
::
Builder
*
builder
,
mlir
::
OperationState
*
state
,
mlir
::
Value
*
lhs
,
mlir
::
Value
*
rhs
)
mlir
::
OperationState
*
state
,
mlir
::
Value
*
lhs
,
mlir
::
Value
*
rhs
)
{
state
->
types
.
push_back
(
lhs
->
getType
());
state
->
operands
.
push_back
(
lhs
);
...
...
@@ -100,9 +100,9 @@ namespace ngraph
}
void
runtime
::
ngmlir
::
NG_MatmulBiasOp
::
build
(
mlir
::
Builder
*
builder
,
mlir
::
OperationState
*
state
,
mlir
::
Value
*
lhs
,
mlir
::
Value
*
rhs
)
mlir
::
OperationState
*
state
,
mlir
::
Value
*
lhs
,
mlir
::
Value
*
rhs
)
{
state
->
types
.
push_back
(
lhs
->
getType
());
state
->
operands
.
push_back
(
lhs
);
...
...
@@ -147,8 +147,8 @@ namespace ngraph
}
void
runtime
::
ngmlir
::
NG_ReturnOp
::
build
(
mlir
::
Builder
*
builder
,
mlir
::
OperationState
*
state
,
std
::
vector
<
mlir
::
Value
*>
value_list
)
mlir
::
OperationState
*
state
,
std
::
vector
<
mlir
::
Value
*>
value_list
)
{
for
(
auto
value
:
value_list
)
{
...
...
src/contrib/mlir/dialect/type.cpp
View file @
9bb2fad3
...
...
@@ -34,11 +34,46 @@ using llvm::Twine;
namespace
ngraph
{
using
namespace
runtime
::
ngmlir
;
unsigned
NGIntegerType
::
getWidth
()
const
{
switch
(
getKind
())
{
case
NG_I8_TYPE_ID
:
case
NG_U8_TYPE_ID
:
return
8
;
case
NG_I16_TYPE_ID
:
case
NG_U16_TYPE_ID
:
return
16
;
case
NG_I32_TYPE_ID
:
case
NG_U32_TYPE_ID
:
return
32
;
case
NG_I64_TYPE_ID
:
case
NG_U64_TYPE_ID
:
return
64
;
default
:
NGRAPH_FAIL
()
<<
"Invalid type ID"
;
}
return
0
;
}
bool
NGIntegerType
::
isSigned
()
const
{
switch
(
getKind
())
{
case
NG_I8_TYPE_ID
:
case
NG_I16_TYPE_ID
:
case
NG_I32_TYPE_ID
:
case
NG_I64_TYPE_ID
:
return
true
;
case
NG_U8_TYPE_ID
:
case
NG_U16_TYPE_ID
:
case
NG_U32_TYPE_ID
:
case
NG_U64_TYPE_ID
:
return
false
;
default
:
NGRAPH_FAIL
()
<<
"Invalid type ID"
;
}
return
false
;
}
/// Creates TensorType objects. They all point to the same storage if
/// element type and shape are the same.
NGTensorType
NGTensorType
::
get
(
mlir
::
MLIRContext
*
context
,
EltType
eltType
,
Shape
shape
)
{
return
Base
::
get
(
context
,
NGTypeKind
::
TENSOR_TYPE_ID
,
eltType
,
shape
);
return
Base
::
get
(
context
,
NGTypeKind
::
NG_
TENSOR_TYPE_ID
,
eltType
,
shape
);
}
mlir
::
MemRefType
NGTensorType
::
toMemref
()
...
...
src/contrib/mlir/dialect/type.hpp
View file @
9bb2fad3
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#pragma once
#include "assertion.hpp"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
...
...
@@ -22,7 +23,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
namespace
ngraph
{
namespace
runtime
...
...
@@ -36,9 +36,146 @@ namespace ngraph
// The enum starts at the range reserved for this dialect.
// These values are pre-defined in MLIR lib and not configurable from here.
NG_TYPE
=
mlir
::
Type
::
Kind
::
FIRST_PRIVATE_EXPERIMENTAL_0_TYPE
,
TENSOR_TYPE_ID
// Element types that are added by the dialect.
// Other types are just re-use of std dialect types.
NG_FIRST_INT_TYPE_ID
,
NG_I8_TYPE_ID
=
NG_FIRST_INT_TYPE_ID
,
NG_I16_TYPE_ID
,
NG_I32_TYPE_ID
,
NG_I64_TYPE_ID
,
NG_U8_TYPE_ID
,
NG_U16_TYPE_ID
,
NG_U32_TYPE_ID
,
NG_U64_TYPE_ID
,
NG_LAST_INT_TYPE_ID
=
NG_U64_TYPE_ID
,
NG_BOOL_TYPE_ID
,
// Tensor type
NG_TENSOR_TYPE_ID
};
// reuse std float types as-is
using
NGFloatType
=
mlir
::
FloatType
;
/// Integer type. It represents an integer of width 8,16,32,64. Signed or not.
class
NGIntegerType
:
public
mlir
::
Type
::
TypeBase
<
NGIntegerType
,
mlir
::
Type
>
{
public
:
using
Base
::
Base
;
static
NGIntegerType
get
(
NGTypeKind
kind
,
mlir
::
MLIRContext
*
context
)
{
NGRAPH_ASSERT
(
kindof
(
kind
))
<<
"Not an integer kind."
;
return
Base
::
get
(
context
,
kind
);
}
/// Create signed Int8
static
NGIntegerType
getInt8
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_I8_TYPE_ID
,
ctx
);
}
/// Create signed Int16
static
NGIntegerType
getInt16
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_I16_TYPE_ID
,
ctx
);
}
/// Create signed Int32
static
NGIntegerType
getInt32
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_I32_TYPE_ID
,
ctx
);
}
/// Create signed Int64
static
NGIntegerType
getInt64
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_I64_TYPE_ID
,
ctx
);
}
/// Create unsigned Int8
static
NGIntegerType
getUInt8
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_U8_TYPE_ID
,
ctx
);
}
/// Create unsigned Int16
static
NGIntegerType
getUInt16
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_U16_TYPE_ID
,
ctx
);
}
/// Create unsigned Int32
static
NGIntegerType
getUInt32
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_U32_TYPE_ID
,
ctx
);
}
/// Create unsigned Int64
static
NGIntegerType
getUInt64
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NGTypeKind
::
NG_U64_TYPE_ID
,
ctx
);
}
/// RTTI support. So we can do obj->isa<NGIntegerType>()
static
bool
kindof
(
unsigned
kind
)
{
return
kind
>=
NGTypeKind
::
NG_FIRST_INT_TYPE_ID
&&
kind
<=
NGTypeKind
::
NG_LAST_INT_TYPE_ID
;
}
/// Return the bitwidth of this integer type.
unsigned
getWidth
()
const
;
/// Convert to equivalent std type
/// std types are sign-agnostic.
mlir
::
Type
toStdType
()
const
{
return
mlir
::
IntegerType
::
get
(
getWidth
(),
getContext
());
}
/// Check if signed type
bool
isSigned
()
const
;
/// Check if Int8
bool
isInt8
()
const
{
return
getKind
()
==
NG_I8_TYPE_ID
;
}
/// Check if UInt8
bool
isUInt8
()
const
{
return
getKind
()
==
NG_U8_TYPE_ID
;
}
/// Check if Int16
bool
isInt16
()
const
{
return
getKind
()
==
NG_I16_TYPE_ID
;
}
/// Check if UInt16
bool
isUInt16
()
const
{
return
getKind
()
==
NG_U16_TYPE_ID
;
}
/// Check if Int32
bool
isInt32
()
const
{
return
getKind
()
==
NG_I32_TYPE_ID
;
}
/// Check if UInt32
bool
isUInt32
()
const
{
return
getKind
()
==
NG_U32_TYPE_ID
;
}
/// Check if Int64
bool
isInt64
()
const
{
return
getKind
()
==
NG_I64_TYPE_ID
;
}
/// Check if UInt64
bool
isUInt64
()
const
{
return
getKind
()
==
NG_U64_TYPE_ID
;
}
// Delete convenience methods inherited from MLIR Type class.
// This would avoid confusion if we do something like this and get false.
//
// if (type->cast<NGIntegerType>()->isInteger(32)) {}
//
// Those helpers use type id, and since we have our own Integer type id, they
// don't apply.
bool
isInteger
(
unsigned
width
)
const
=
delete
;
unsigned
getIntOrFloatBitWidth
()
const
=
delete
;
bool
isIntOrIndex
()
const
=
delete
;
bool
isIntOrIndexOrFloat
()
const
=
delete
;
bool
isIntOrFloat
()
const
=
delete
;
};
/// Boolean Type.
class
NGBoolType
:
public
mlir
::
Type
::
TypeBase
<
NGBoolType
,
mlir
::
Type
>
{
public
:
using
Base
::
Base
;
static
NGBoolType
get
(
NGTypeKind
kind
,
mlir
::
MLIRContext
*
context
)
{
NGRAPH_ASSERT
(
kindof
(
kind
))
<<
"Not a bool type."
;
return
Base
::
get
(
context
,
kind
);
}
static
bool
kindof
(
unsigned
kind
)
{
return
kind
==
NGTypeKind
::
NG_BOOL_TYPE_ID
;
}
static
NGBoolType
get
(
mlir
::
MLIRContext
*
ctx
)
{
return
get
(
NG_BOOL_TYPE_ID
,
ctx
);
}
/// Convert to equivalent std type. Integer of width 1 in that case
mlir
::
Type
toStdType
()
const
{
return
mlir
::
IntegerType
::
get
(
1
,
getContext
());
}
};
// Note that dialect types don't add new data members, so always possible
// to use NG or std types here
using
EltType
=
mlir
::
Type
;
// TODO: Can we use ngraph::shape here (given the hashing requirements)
using
Shape
=
llvm
::
ArrayRef
<
int64_t
>
;
...
...
@@ -86,6 +223,7 @@ namespace ngraph
Shape
m_shape
;
};
/// NGraph Tensor Type
class
NGTensorType
:
public
mlir
::
Type
::
TypeBase
<
NGTensorType
,
mlir
::
Type
,
NGTensorTypeStorage
>
{
...
...
@@ -93,7 +231,9 @@ namespace ngraph
using
Base
::
Base
;
EltType
getElementType
()
const
{
return
getImpl
()
->
getElementType
();
}
Shape
getShape
()
const
{
return
getImpl
()
->
getShape
();
}
/// Tensor Rank. Static shape only for now
int
getRank
()
{
return
getShape
().
size
();
}
/// Computes tensor size in bytes
size_t
getSizeInBytes
()
{
size_t
s
=
1
;
...
...
@@ -113,7 +253,7 @@ namespace ngraph
/// create a unique tensor type based on element type and shape.
static
NGTensorType
get
(
mlir
::
MLIRContext
*
context
,
EltType
eltType
,
Shape
shape
);
/// for llvm RTTI
static
bool
kindof
(
unsigned
kind
)
{
return
kind
==
NGTypeKind
::
TENSOR_TYPE_ID
;
}
static
bool
kindof
(
unsigned
kind
)
{
return
kind
==
NGTypeKind
::
NG_
TENSOR_TYPE_ID
;
}
};
}
}
...
...
src/contrib/mlir/lowerer.cpp
View file @
9bb2fad3
...
...
@@ -17,6 +17,8 @@
#include "lowerer.hpp"
#include <map>
#include "compiler.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
#include "llvm/ADT/DenseSet.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
...
...
@@ -25,8 +27,6 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "ngraph/assertion.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
using
namespace
ngraph
::
runtime
::
ngmlir
;
// anonymous namespace
...
...
@@ -272,6 +272,23 @@ namespace
{
return
tensor
.
toMemref
();
}
// element type
if
(
auto
type
=
t
.
dyn_cast
<
NGFloatType
>
())
{
// Float
// float types are already std type
return
type
;
}
if
(
auto
type
=
t
.
dyn_cast
<
NGIntegerType
>
())
{
// map it to std type
return
type
.
toStdType
();
}
if
(
auto
type
=
t
.
dyn_cast
<
NGBoolType
>
())
{
return
type
.
toStdType
();
}
NGRAPH_FAIL
()
<<
"Unsupported type to lower"
;
return
t
;
}
...
...
src/contrib/mlir/memory_manager.cpp
View file @
9bb2fad3
...
...
@@ -14,8 +14,8 @@
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "memory_manager.hpp"
#include <memory>
#include "ngraph/ngraph_visibility.hpp"
using
namespace
ngraph
::
runtime
::
ngmlir
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment