Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
33ec9a8b
Commit
33ec9a8b
authored
Jul 20, 2019
by
nmostafa
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Lowering to affine. CPU.gather_* pass
parent
67536ddf
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
10 deletions
+123
-10
ops.cpp
src/contrib/mlir/dialect/ops.cpp
+1
-1
ops.td
src/contrib/mlir/dialect/ops.td
+6
-6
lowerer.cpp
src/contrib/mlir/lowerer.cpp
+116
-3
No files found.
src/contrib/mlir/dialect/ops.cpp
View file @
33ec9a8b
...
...
@@ -171,7 +171,7 @@ static mlir::LogicalResult verifyCmpOp(T* op)
template
<>
mlir
::
LogicalResult
verifyOp
(
NGGatherOp
*
op
)
{
Type
ty
=
op
->
input
()
->
getType
();
Type
ty
=
op
->
params
()
->
getType
();
NGTensorType
inputType
=
ty
.
cast
<
NGTensorType
>
();
ty
=
op
->
indices
()
->
getType
();
...
...
src/contrib/mlir/dialect/ops.td
View file @
33ec9a8b
...
...
@@ -260,14 +260,14 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
// Gather
def NGGatherOp :
NG_OneResult_Op<"gather", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$
input
, NG_TensorType:$indices, I64Attr:$axis)>
Arguments<(ins NG_TensorType:$
params
, NG_TensorType:$indices, I64Attr:$axis)>
{
let summary = "Gather slices from
input
along the specified axis according to indices";
let summary = "Gather slices from
params
along the specified axis according to indices";
let description = [{
Gather slices from axis of
input
according to indices
input
The tensor from which slices are gathered
indices Index tensor
:
Data type must be `element::i32` or `element::i64`
axis Axis in
input
to gather
Gather slices from axis of
params
according to indices
params
The tensor from which slices are gathered
indices Index tensor
.
Data type must be `element::i32` or `element::i64`
axis Axis in
params
to gather
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
...
...
src/contrib/mlir/lowerer.cpp
View file @
33ec9a8b
...
...
@@ -647,14 +647,127 @@ namespace
return
matchSuccess
();
}
REWRITER
(
NG
Return
Op
)
REWRITER
(
NG
Gather
Op
)
{
rewriter
.
replaceOpWithNewOp
<
ReturnOp
>
(
op
);
auto
gatherOp
=
cast
<
NGGatherOp
>
(
op
);
auto
loc
=
gatherOp
.
getLoc
();
ScopedContext
scope
(
rewriter
,
loc
);
// Get operands
Value
*
result
=
m_pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
NGRAPH_CHECK
(
result
,
"Unexpected null result in GatherOp"
);
auto
resultTy
=
result
->
getType
().
cast
<
MemRefType
>
();
Value
*
params
=
operands
[
0
];
Value
*
indices
=
operands
[
1
];
auto
axis
=
gatherOp
.
axis
().
getSExtValue
();
// Create view to write into result.
MemRefView
vRes
(
result
),
vParams
(
params
),
vIndices
(
indices
);
// Indexed Values
IndexedValue
iRes
(
result
),
iParams
(
params
),
iIndices
(
indices
);
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector
<
ValueHandle
,
4
>
paramsLbs
,
paramsUbs
;
SmallVector
<
IndexHandle
,
4
>
paramsIVs
;
SmallVector
<
int64_t
,
4
>
paramsSteps
;
SmallVector
<
ValueHandle
*
,
4
>
paramsIVPtrs
;
for
(
auto
i
=
0
;
i
<
vParams
.
rank
();
i
++
)
{
// skip gather axis
if
(
i
==
axis
)
continue
;
paramsLbs
.
push_back
(
IndexHandle
(
vParams
.
lb
(
i
)));
paramsUbs
.
push_back
(
IndexHandle
(
vParams
.
ub
(
i
)));
paramsSteps
.
push_back
(
vParams
.
step
(
i
));
}
NGRAPH_CHECK
(
paramsLbs
.
size
()
==
vParams
.
rank
()
-
1
&&
paramsUbs
.
size
()
==
paramsLbs
.
size
()
&&
paramsSteps
.
size
()
==
paramsLbs
.
size
(),
"Incorrect loop nest bounds size for gather params"
);
paramsIVs
=
IndexHandle
::
makeIndexHandles
(
vParams
.
rank
()
-
1
);
paramsIVPtrs
=
IndexHandle
::
makeIndexHandlePointers
(
paramsIVs
);
auto
indicesLbs
=
vIndices
.
getLbs
();
auto
indicesUbs
=
vIndices
.
getUbs
();
auto
indicesSteps
=
vIndices
.
getSteps
();
auto
indicesIVs
=
IndexHandle
::
makeIndexHandles
(
vIndices
.
rank
());
auto
indicesIVPtrs
=
IndexHandle
::
makeIndexHandlePointers
(
indicesIVs
);
SmallVector
<
IndexHandle
,
8
>
paramsIndices
,
resIndices
;
// Make sure we are going to create loops
NGRAPH_CHECK
(
vParams
.
rank
()
>
0
,
"Invalid size for indices steps"
);
// Let params rank : N
// Let indices rank : M
// Let axis be A
// Generate
// params loops
// for P_0: 0 -> params.dim[0]
// for P_1: 0 -> params.dim[1]
// for P_2: 0 -> params.dim[2]
// ...
// for P_(A-1):0 -> params.dim[A-1]
// for P_(A+1):0 -> params.dim[A+1]
// ...
// for P_(N-1):0 -> params.dim[N-1]
// indices loops
// for I_0:0 -> indices.dim[0]
// ...
// for I_(M-1):0 -> indices.dim[M-1]
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), ... P_(N-1)];
LoopNestBuilder
(
paramsIVPtrs
,
paramsLbs
,
paramsUbs
,
paramsSteps
)([
&
]
{
LoopNestBuilder
(
indicesIVPtrs
,
indicesLbs
,
indicesUbs
,
indicesSteps
)([
&
]
{
// Load axis value from indices array and cast it to Index Type
ValueHandle
axisIdx
=
ValueHandle
::
create
<
IndexCastOp
>
(
(
ValueHandle
)
iIndices
(
indicesIVs
),
rewriter
.
getIndexType
());
// construct indices for param
// [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
for
(
auto
i
=
0
,
j
=
0
;
i
<
vParams
.
rank
();
i
++
)
{
if
(
i
==
axis
)
{
paramsIndices
.
push_back
(
IndexHandle
(
axisIdx
));
}
else
{
paramsIndices
.
push_back
(
paramsIVs
[
j
++
]);
}
}
// construct indices for result
// [P_0, P_1, .. P_axis-1, I0, I1, .. I_k-1, P_axis+1, P_axis+2, .. P_n-1]
for
(
auto
i
=
0
,
j
=
0
;
i
<
vParams
.
rank
()
+
vIndices
.
rank
()
-
1
;)
{
if
(
i
==
axis
&&
indicesIVs
.
size
()
>
0
)
{
resIndices
.
append
(
indicesIVs
.
begin
(),
indicesIVs
.
end
());
i
+=
indicesIVs
.
size
();
}
else
{
resIndices
.
push_back
(
paramsIVs
[
j
++
]);
i
++
;
}
}
// Store into result
iRes
(
resIndices
)
=
iParams
(
paramsIndices
);
});
});
rewriter
.
replaceOp
(
op
,
{
result
});
return
matchSuccess
();
}
REWRITER
(
NG
Gather
Op
)
REWRITER
(
NG
Return
Op
)
{
rewriter
.
replaceOpWithNewOp
<
ReturnOp
>
(
op
);
return
matchSuccess
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment