submodule / ngraph

Commit e0135089 (Unverified)
Authored Feb 21, 2020 by Sang Ik Lee; committed by GitHub on Feb 21, 2020

    Merge branch 'master' into feature/android_fix

Parents: fa2d5225, cbf84017

Showing 35 changed files with 807 additions and 272 deletions (+807, -272)
cmake/external_mkldnn_v1.cmake                                      +3    -3
cmake/external_mlir.cmake                                           +1    -1
src/contrib/mlir/backend/analysis/memory_analysis.cpp               +0    -1
src/contrib/mlir/backend/cpu/cpu_backend.cpp                        +2    -1
src/contrib/mlir/backend/pass/affine_lowerer.cpp                  +170  -188
src/contrib/mlir/core/pass/ng_dialect_fused_ops.cpp                 +1    -3
src/contrib/mlir/runtime/cpu/cpu_callbacks.cpp                      +4    -4
src/contrib/mlir/runtime/cpu/cpu_runtime.cpp                        +4    -3
src/contrib/mlir/tools/ngraph-opt/CMakeLists.txt                    +1    -1
src/contrib/mlir/utils.cpp                                         +54   -11
src/ngraph/frontend/onnx_import/CMakeLists.txt                      +2    -0
src/ngraph/frontend/onnx_import/op/average_pool.cpp                 +1    -2
src/ngraph/frontend/onnx_import/op/max_pool.cpp                     +1    -1
src/ngraph/frontend/onnx_import/op/onehot.cpp                       +2    -2
src/ngraph/frontend/onnx_import/op/pad.cpp                          +7    -2
src/ngraph/frontend/onnx_import/op/round.cpp                       +41    -0
src/ngraph/frontend/onnx_import/op/round.hpp                       +38    -0
src/ngraph/frontend/onnx_import/ops_bridge.cpp                      +2    -0
src/ngraph/frontend/onnx_import/utils/convpool.cpp                 +41   -16
src/ngraph/frontend/onnx_import/utils/convpool.hpp                  +3   -14
src/ngraph/frontend/onnx_import/utils/pooling_factory.cpp          +24    -6
src/ngraph/frontend/onnx_import/utils/pooling_factory.hpp          +14    -2
src/ngraph/pattern/matcher.cpp                                     +10    -2
src/ngraph/runtime/cpu/cpu_runtime_context.hpp                      +1    -1
src/ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.hpp          +1    -1
src/ngraph/runtime/gpu/unit_test.manifest                           +1    -0
src/ngraph/runtime/plaidml/unit_test.manifest                       +1    -0
test/mlir/affine_conversion/callback_ops.mlir                       +7    -7
test/models/onnx/dynamic_shapes/average_pool_2d_dyn.prototxt       +57    -0
test/models/onnx/dynamic_shapes/global_average_pool_dyn.prototxt   +51    -0
test/models/onnx/dynamic_shapes/global_max_pool_dyn.prototxt       +51    -0
test/models/onnx/dynamic_shapes/max_pool_2d_dyn.prototxt           +65    -0
test/models/onnx/round.prototxt                                    +39    -0
test/onnx/onnx_import.in.cpp                                       +27    -0
test/onnx/onnx_import_dyn_shapes.in.cpp                            +80    -0
cmake/external_mkldnn_v1.cmake
@@ -18,12 +18,12 @@ include(ExternalProject)
 # Includes blas 3.8.0 in mkldnn
 set(NGRAPH_MKLDNN_SHORT_VERSION 1)
-set(NGRAPH_MKLDNN_FULL_VERSION 1.1.1.0)
+set(NGRAPH_MKLDNN_FULL_VERSION 1.2.0.0)
 set(NGRAPH_MKLDNN_MKLML_ASSET_VERSION "v0.21")
-set(NGRAPH_MKLDNN_VERSION "v1.1.1")
+set(NGRAPH_MKLDNN_VERSION "v1.2")
 set(NGRAPH_MKLDNN_MKLML_VERSION "2019.0.5.20190502")
 set(NGRAPH_MKLDNN_MKLML_WIN32_VERSION "2020.0.20190813")
-set(NGRAPH_MKLDNN_GIT_TAG "v1.1.1")
+set(NGRAPH_MKLDNN_GIT_TAG "v1.2")
 #------------------------------------------------------------------------------
 # Fetch and install MKL-DNN
cmake/external_mlir.cmake
@@ -19,7 +19,7 @@ include(ExternalProject)
 set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)

 # Change these commit IDs to move to latest stable versions
-set(MLIR_LLVM_COMMIT_ID 96400ae)
+set(MLIR_LLVM_COMMIT_ID 376c6853)

 # MLIR environment variables. Some of them are used by LIT tool.
src/contrib/mlir/backend/analysis/memory_analysis.cpp
@@ -26,7 +26,6 @@
 #include <llvm/ADT/DenseSet.h>
 #include <map>
 #include <mlir/EDSC/Builders.h>
-#include <mlir/EDSC/Helpers.h>
 #include <mlir/EDSC/Intrinsics.h>
 #include <mlir/IR/AffineExpr.h>
 #include <mlir/IR/IntegerSet.h>
src/contrib/mlir/backend/cpu/cpu_backend.cpp
@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect()
 void MLIRCPUBackend::lowerStandardDialect()
 {
     mlir::PassManager pm(&m_context);
-    pm.addPass(mlir::createLowerToLLVMPass());
+    pm.addPass(mlir::createLowerToLLVMPass(
+        /*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true));

     // Apply any generic pass manager command line options.
     mlir::applyPassManagerCLOptions(pm);
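A note on the new arguments (my reading of the change, not something the commit states): asking createLowerToLLVMPass for emitCWrappers=true makes the lowering also emit a C-interface wrapper for each public function, and those wrappers carry the _mlir_ciface_ prefix. That is the thread connecting this hunk to the symbol renames in cpu_runtime.cpp and cpu_callbacks.cpp further down. A minimal sketch of the call side, using the names that appear later in this commit:

    // Sketch only; mirrors what cpu_runtime.cpp switches to below. With
    // emitCWrappers=true the lowered module exposes the C-interface wrapper
    // "_mlir_ciface_main" alongside "main", and the JIT now targets the wrapper.
    auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main");
    auto invocationResult =
        m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));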
src/contrib/mlir/backend/pass/affine_lowerer.cpp
@@ -28,9 +28,9 @@
 #include <llvm/ADT/DenseSet.h>
 #include <llvm/Support/Debug.h>
 #include <mlir/EDSC/Builders.h>
-#include <mlir/EDSC/Helpers.h>
-#include <mlir/EDSC/Intrinsics.h>
+#include <mlir/Dialect/AffineOps/EDSC/Builders.h>
+#include <mlir/Dialect/AffineOps/EDSC/Intrinsics.h>
+#include <mlir/Dialect/StandardOps/EDSC/Intrinsics.h>
 #include <mlir/IR/AffineExpr.h>
 #include <mlir/IR/Function.h>
 #include <mlir/IR/IntegerSet.h>
@@ -51,11 +51,10 @@ namespace
 {
     using namespace mlir;
     using namespace mlir::edsc;
     using namespace mlir::edsc::intrinsics;
     using namespace mlir::edsc::op;
     using namespace ngraph::runtime;
     using namespace ngraph::runtime::ngmlir;

     // Index notation to generate standard (i.e., non-affine) loads and stores.
     using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;

     class DialectLoweringPass;
@@ -215,9 +214,37 @@ namespace
         NGraphTypeConverter()
             : TypeConverter()
         {
+            // TODO(dcaballe): split this into independent conversion patterns when there is a
+            // way to check if a type is valid in Std dialect.
+            addConversion([this](Type type) -> Type {
+                if (auto tensorType = type.dyn_cast<NGTensorType>())
+                {
+                    // Convert NGTensorType to Std MemRefType directly instead of going to Std
+                    // TensorType. This may change in the future.
+                    return MemRefType::get(tensorType.getShape(),
+                                           convertType(tensorType.getElementType()),
+                                           {/* no map used */},
+                                           0);
+                }
+                if (auto floatType = type.dyn_cast<NGFloatType>())
+                {
+                    // Float types are already std type.
+                    return floatType;
+                }
+                if (auto intType = type.dyn_cast<NGIntegerType>())
+                {
+                    return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
+                }
+                if (auto boolType = type.dyn_cast<NGBoolType>())
+                {
+                    return mlir::IntegerType::get(1 /* width */, boolType.getContext());
+                }
+                // Do not assert/NGRAPH_CHECK here. Type convertion infra expects `convertType` to
+                // return the input type if the type is not supported.
+                return type;
+            });
         }
-
-        Type convertType(Type t) override;
     };

     /// Dialect Lowering Pass to affine ops
@@ -317,7 +344,8 @@ namespace
         // TODO: Encode no alias attribute as part of the function signature conversion or as a
         // separate rewrite pattern. Retrieve new function after signature conversion.
-        insertNoAliasArgAttrs();
+        // TODO: To be enabled in follow-up commit.
+        // insertNoAliasArgAttrs();

         opAttrsVec = m_attrsVec;
@@ -492,22 +520,22 @@ namespace
     /// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
     /// by nGraph op semantics.
-    void DialectLoweringPass::insertNoAliasArgAttrs()
-    {
-        FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
-        NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
-
-        unsigned int argIdx = 0;
-        for (auto arg : func.getArguments())
-        {
-            if (arg.getType().isa<MemRefType>())
-            {
-                func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
-            }
-
-            ++argIdx;
-        }
-    }
+    // void DialectLoweringPass::insertNoAliasArgAttrs()
+    // {
+    //     FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
+    //     NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
+    //     unsigned int argIdx = 0;
+    //     for (auto arg : func.getArguments())
+    //     {
+    //         if (arg.getType().isa<MemRefType>())
+    //         {
+    //             func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
+    //         }
+    //         ++argIdx;
+    //     }
+    // }

     void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
     {
@@ -543,40 +571,6 @@ namespace
         return m_attrsVec.size() - 1;
     }

     // NGDialect converters
-    Type NGraphTypeConverter::convertType(Type type)
-    {
-        // We may need to refactor this code to a external utility if type conversion is needed
-        // outside of the lowering context since NGraphTypeConverter is private.
-        if (auto tensorType = type.dyn_cast<NGTensorType>())
-        {
-            // Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
-            // This may change in the future.
-            return MemRefType::get(tensorType.getShape(),
-                                   convertType(tensorType.getElementType()),
-                                   {/* no map used */},
-                                   0);
-        }
-        if (auto floatType = type.dyn_cast<NGFloatType>())
-        {
-            // Float types are already std type.
-            return floatType;
-        }
-        if (auto intType = type.dyn_cast<NGIntegerType>())
-        {
-            return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
-        }
-        if (auto boolType = type.dyn_cast<NGBoolType>())
-        {
-            return mlir::IntegerType::get(1 /* width */, boolType.getContext());
-        }
-
-        // Do not assert/NGRAPH_CHECK here. Type convertion infra expects `convertType` to return
-        // the input type if the type is not supported.
-        return type;
-    }

 #define REWRITER(OP)                                                                               \
     PatternMatchResult OP##Conversion::matchAndRewrite(                                            \
         Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
@@ -680,15 +674,15 @@ namespace
         ScopedContext scope(rewriter, loc);
         // Views
-        MemRefView vRes(result), vLHS(lhs);
+        MemRefBoundsCapture vRes(result), vLHS(lhs);
         // Index Values
-        IndexedValue iRes(result), iLHS(lhs);
+        AffineIndexedValue iRes(result), iLHS(lhs);
         // Bounds Index Handles
         auto lbs = vLHS.getLbs();
         auto ubs = vLHS.getUbs();
         // Loop induction vars
-        auto ivs = makeIndexHandles(vLHS.rank());
-        auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+        auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
+        auto pivs = makeHandlePointers(ivs);
         // Steps
         auto steps = vLHS.getSteps();
@@ -698,7 +692,7 @@ namespace
         AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
             ValueHandle val = iLHS(ivs);
             ValueHandle zero = createZeroConstant(elemTy);
-            iRes(ivs) = intrinsics::select(val > zero, val, zero);
+            iRes(ivs) = std_select(val > zero, val, zero);
         });

         rewriter.replaceOp(op, {result});
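The remaining hunks in this file repeat the same mechanical renames demanded by the newer MLIR EDSC API; the two hunks above are representative. Summarizing the pattern (my own summary of the minus/plus lines that follow): MemRefView becomes MemRefBoundsCapture, IndexedValue becomes AffineIndexedValue, bare IndexHandle induction variables become ValueHandle(indexType), makeIndexHandles(...) becomes ValueHandle::makeIndexHandles(...) and makeHandlePointers no longer needs a MutableArrayRef<IndexHandle> wrapper, LoopBuilder::makeAffine becomes makeAffineLoopBuilder, and the intrinsics::select / constant_index / constant_float / constant_int helpers become std_select / std_constant_index / std_constant_float / std_constant_int. The callback declarations also drop their __mlir_ prefix.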
@@ -742,36 +736,37 @@ namespace
         // res[n, k] += lhs[n, m] * rhs[m, k]
         // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
-        MemRefView vRes(result), vLhs(lhs), vRhs(rhs);
+        MemRefBoundsCapture vRes(result), vLhs(lhs), vRhs(rhs);

         NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
                      "Dot operation is only supported for 2D tensors");

         // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
-        // It's important to note that MemRefView priovides lb/ub/step info is "reverse order",
-        // i.e., fastest varying dimension is the last one, slowest varying dimention is the first
-        // one.
-        IndexHandle n, m, k;
+        // It's important to note that MemRefBoundsCapture priovides lb/ub/step info is "reverse
+        // order", i.e., fastest varying dimension is the last one, slowest varying dimention is the
+        // first one.
+        auto indexType = IndexType::get(rewriter.getContext());
+        ValueHandle n(indexType), m(indexType), k(indexType);
         unsigned nDim = vLhs.fastestVarying() - 1;
         unsigned mDim = vRhs.fastestVarying();
         unsigned kDim = vRhs.fastestVarying();
-        IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
-        IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
+        ValueHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
+        ValueHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
         int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);

         // Constants and indexed values to be used inside the loop nest.
-        IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
+        AffineIndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
         ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));

         {
-            IndexHandle n, k;
-            LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
-                LoopBuilder::makeAffine(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
+            ValueHandle n(indexType), k(indexType);
+            makeAffineLoopBuilder(&n, nLb, nUb, nStep)([&] {
+                makeAffineLoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
             });
         }
-        LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
-            LoopBuilder::makeAffine(&m, mLb, mUb, mStep)([&] {
-                LoopBuilder::makeAffine(&k, kLb, kUb, kStep)(
-                    [&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
+        makeAffineLoopBuilder(&n, nLb, nUb, nStep)([&] {
+            makeAffineLoopBuilder(&m, mLb, mUb, mStep)([&] {
+                makeAffineLoopBuilder(&k, kLb, kUb, kStep)(
+                    [&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
             });
         });
@@ -792,13 +787,13 @@ namespace
         NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");

         // Create view to write into result.
-        MemRefView vRes(result);
+        MemRefBoundsCapture vRes(result);
         auto rank = vRes.rank();

         // For each operand, generate a separate loop to copy into the target slice of "result".
         // We'll keep track of the slice offsets via concatenation_axis_pos.
         auto concatenationAxis = concat.concatenation_axis().getSExtValue();
-        IndexHandle concatenationAxisPos(index_type(0));
+        Value concatenationAxisPos(std_constant_index(0));

         for (auto& operand : operands)
         {
@@ -817,7 +812,7 @@ namespace
             // [i_(r-2)][i_(r-1)]
             // :=
             // operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
-            MemRefView vOperand(operand);
+            MemRefBoundsCapture vOperand(operand);
             NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch");

             llvm::SmallVector<ValueHandle, 5> indexVars;
@@ -825,9 +820,10 @@ namespace
             llvm::SmallVector<ValueHandle, 5> indexVarLbs;
             llvm::SmallVector<ValueHandle, 5> indexVarUbs;
             llvm::SmallVector<int64_t, 5> indexVarSteps;
+            auto indexType = IndexType::get(rewriter.getContext());
             for (int i = 0; i < rank; i++)
             {
-                indexVars.push_back(IndexHandle());
+                indexVars.push_back(ValueHandle(indexType));
                 indexVarPtrs.push_back(&(indexVars.back()));
                 indexVarLbs.push_back(vOperand.lb(i));
                 indexVarUbs.push_back(vOperand.ub(i));
@@ -835,15 +831,15 @@ namespace
             }

             AffineLoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
-                IndexedValue ivRes(result);
-                IndexedValue ivOperand(operand);
+                AffineIndexedValue ivRes(result);
+                AffineIndexedValue ivOperand(operand);

                 // On the LHS of the assignment, adjust the index for the concatenation axis.
                 llvm::SmallVector<ValueHandle, 5> resIndexHandles;
                 for (int i = 0; i < rank; i++)
                 {
                     resIndexHandles.push_back(i == concatenationAxis
-                                                  ? indexVars[i] + concatenationAxisPos
+                                                  ? indexVars[i] + ValueHandle(concatenationAxisPos)
                                                   : indexVars[i]);
                 }
@@ -851,11 +847,11 @@ namespace
             });

             // Move up concatenation_axis_pos for the next operand.
-            concatenationAxisPos = concatenationAxisPos + vOperand.ub(concatenationAxis);
+            concatenationAxisPos = ValueHandle(concatenationAxisPos) + vOperand.ub(concatenationAxis);
         }

         rewriter.replaceOp(op, {result});
         return matchSuccess();
     }
@@ -874,14 +870,13 @@ namespace
         auto axis = gatherOp.axis().getSExtValue();

         // Create view to write into result.
-        MemRefView vRes(result), vParams(params), vIndices(indices);
+        MemRefBoundsCapture vRes(result), vParams(params), vIndices(indices);
         // Indexed Values
-        IndexedValue iRes(result), iIndices(indices);
+        AffineIndexedValue iRes(result), iIndices(indices);
         StdIndexedValue iParams(params);

         // Construct outer loop for params dims. Exclude the axis dim.
-        SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
-        SmallVector<IndexHandle, 4> paramsIVs;
+        SmallVector<ValueHandle, 4> paramsLbs, paramsUbs, paramsIVs;
         SmallVector<int64_t, 4> paramsSteps;
         SmallVector<ValueHandle*, 4> paramsIVPtrs;

         for (auto i = 0; i < vParams.rank(); i++)
@@ -889,8 +884,8 @@ namespace
             // skip gather axis
             if (i == axis)
                 continue;
-            paramsLbs.push_back(IndexHandle(vParams.lb(i)));
-            paramsUbs.push_back(IndexHandle(vParams.ub(i)));
+            paramsLbs.push_back(vParams.lb(i));
+            paramsUbs.push_back(vParams.ub(i));
             paramsSteps.push_back(vParams.step(i));
         }
         NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
@@ -898,17 +893,17 @@ namespace
                          paramsSteps.size() == paramsLbs.size(),
                      "Incorrect loop nest bounds size for gather params");

-        paramsIVs = makeIndexHandles(vParams.rank() - 1);
-        paramsIVPtrs = makeHandlePointers(MutableArrayRef<IndexHandle>(paramsIVs));
+        paramsIVs = ValueHandle::makeIndexHandles(vParams.rank() - 1);
+        paramsIVPtrs = makeHandlePointers(paramsIVs);

         auto indicesLbs = vIndices.getLbs();
         auto indicesUbs = vIndices.getUbs();
         auto indicesSteps = vIndices.getSteps();

-        auto indicesIVs = makeIndexHandles(vIndices.rank());
-        auto indicesIVPtrs = makeHandlePointers(MutableArrayRef<IndexHandle>(indicesIVs));
+        auto indicesIVs = ValueHandle::makeIndexHandles(vIndices.rank());
+        auto indicesIVPtrs = makeHandlePointers(indicesIVs);

-        SmallVector<IndexHandle, 8> paramsIndices, resIndices;
+        SmallVector<ValueHandle, 8> paramsIndices, resIndices;

         // Make sure we are going to create loops
         NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");
@@ -946,7 +941,7 @@ namespace
             {
                 if (i == axis)
                 {
-                    paramsIndices.push_back(IndexHandle(axisIdx));
+                    paramsIndices.push_back(axisIdx);
                 }
                 else
                 {
@@ -1022,10 +1017,10 @@ namespace
         NGRAPH_CHECK(groups > 0, "Invalid number of groups");
         // create outer group convolution loop
         //     for group = 0 to groups
-        IndexHandle iv;
-        ValueHandle lb = intrinsics::constant_index(0);
-        ValueHandle ub = intrinsics::constant_index(groups);
+        auto indexType = IndexType::get(rewriter.getContext());
+        ValueHandle iv(indexType);
+        ValueHandle lb = std_constant_index(0);
+        ValueHandle ub = std_constant_index(groups);

         auto imagesType = images.getType().cast<MemRefType>();
         auto filtersType = filters.getType().cast<MemRefType>();
@@ -1043,13 +1038,13 @@ namespace
         NGRAPH_CHECK(groupsInFilters || filtersShape[0] % groups == 0,
                      "Filters dim is not divisible by number of groups");

-        auto channelGroupSize = intrinsics::constant_index(imagesShape[1] / groups);
-        auto filtersGroupSize = intrinsics::constant_index(
-            groupsInFilters ? filtersShape[1] : filtersShape[0] / groups);
+        auto channelGroupSize = std_constant_index(imagesShape[1] / groups);
+        auto filtersGroupSize =
+            std_constant_index(groupsInFilters ? filtersShape[1] : filtersShape[0] / groups);

         NGRAPH_CHECK(!groupsInFilters || groups == filtersShape[0]);

-        LoopBuilder::makeAffine(&iv, lb, ub, 1)([&] {
+        makeAffineLoopBuilder(&iv, lb, ub, 1)([&] {
             // lower/upper bounds on image channel dim and kernels dim
             auto cLb = iv * channelGroupSize;
             auto cUb = cLb + channelGroupSize;
@@ -1152,7 +1147,7 @@ namespace
         castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);

         FuncOp callBackFunc =
-            pass.getCallDecl("__mlir_callback_2_inputs",
+            pass.getCallDecl("callback_2_inputs",
                             {unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
                             {},
                             rewriter);
@@ -1245,7 +1240,7 @@ namespace
         auto int64Ty = rewriter.getIntegerType(64);
         auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
         auto callBackFunc =
-            pass.getCallDecl("__mlir_callback_2_inputs",
+            pass.getCallDecl("callback_2_inputs",
                             {unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
                             {},
                             rewriter);
@@ -1297,7 +1292,7 @@ namespace
                          elemTy == biasTy.getElementType(),
                      "Types mismatch in GemmOp");

-        MemRefView vRes(result), vLhs(lhs), vRhs(rhs), vBias(bias);
+        MemRefBoundsCapture vRes(result), vLhs(lhs), vRhs(rhs), vBias(bias);

         NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2 && vBias.rank() <= 2,
                      "Gemm operation is only supported for 2D tensors");
@@ -1361,7 +1356,7 @@ namespace
         auto int64Ty = rewriter.getIntegerType(64);
         auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
-        auto callBackFunc = pass.getCallDecl("__mlir_callback_3_inputs",
+        auto callBackFunc = pass.getCallDecl("callback_3_inputs",
                                              {unrankedMemrefTy,
                                               unrankedMemrefTy,
                                               unrankedMemrefTy,
@@ -1425,7 +1420,7 @@ namespace
             rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::SOFTMAX), 64);

         FuncOp callBackFunc =
-            pass.getCallDecl("__mlir_callback_1_input",
+            pass.getCallDecl("callback_1_input",
                             {unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
                             {},
                             rewriter);
@@ -1511,11 +1506,12 @@ namespace
         auto padBelow = padBelowAttr.getValue();
         auto padAbove = padBelowAttr.getValue();
         Type elemTy = images.getType().cast<MemRefType>().getElementType();
+        auto indexType = IndexType::get(rewriter.getContext());

         // Create views
-        MemRefView vRes(result), vImages(images), vFilters(filters);
+        MemRefBoundsCapture vRes(result), vImages(images), vFilters(filters);
         // Create indexed Values
-        IndexedValue iRes(result), iImages(images), iFilters(filters);
+        AffineIndexedValue iRes(result), iImages(images), iFilters(filters);

         // Bounds on batch size N
         ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0);
         // Bounds on spatial dimensions
@@ -1526,9 +1522,8 @@ namespace
         unsigned spatialRank = vImages.rank() - 2;

         // Result spatial indices and bounds
-        auto resSpatialIndices = makeIndexHandles(spatialRank);
-        auto resSpatialIndicesPtrs =
-            makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
+        auto resSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
+        auto resSpatialIndicesPtrs = makeHandlePointers(resSpatialIndices);
         SmallVector<int64_t, 4> resSteps, filtersSteps;
         SmallVector<int, 4> padBelowIntValues;
         bool withPadding = false;
@@ -1610,9 +1605,8 @@ namespace
                      "Results spatial dims mismatches input");

         // Filters spatial indices and bounds
-        auto filtersSpatialIndices = makeIndexHandles(spatialRank);
-        auto filtersSpatialIndicesPtrs =
-            makeHandlePointers(MutableArrayRef<IndexHandle>(filtersSpatialIndices));
+        auto filtersSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
+        auto filtersSpatialIndicesPtrs = makeHandlePointers(filtersSpatialIndices);

         for (auto i = 0; i < spatialRank; i++)
         {
@@ -1658,23 +1652,22 @@ namespace
         // Initialize output to zero
         {
-            IndexHandle n, k, c;
-            auto resSpatialIndices = makeIndexHandles(spatialRank);
-            auto resSpatialIndicesPtrs =
-                makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
+            ValueHandle n(indexType), k(indexType), c(indexType);
+            auto resSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
+            auto resSpatialIndicesPtrs = makeHandlePointers(resSpatialIndices);

-            LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
-                LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
+            makeAffineLoopBuilder(&n, batchLb, batchUb, 1)([&] {
+                makeAffineLoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
                     AffineLoopNestBuilder(
                         resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
-                        SmallVector<IndexHandle, 4> resIndices;
+                        SmallVector<ValueHandle, 4> resIndices;
                         // Result indices
                         resIndices.push_back(n);
                         if (groupConvolution && groupsInFilters)
                         {
                             // compute global C_OUT from gID and k
                             // gId * C_OUT (num of filters) + k
-                            resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k));
+                            resIndices.push_back(ValueHandle(gId) * numFiltersUb + k);
                         }
                         else
                         {
@@ -1689,31 +1682,31 @@ namespace
                     });
                 });
             });
         }

-        IndexHandle n, k, c;
+        ValueHandle n(indexType), k(indexType), c(indexType);
         // Convolution loop
-        LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
+        makeAffineLoopBuilder(&n, batchLb, batchUb, 1)([&] {
             // Number of filters loop
-            LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
+            makeAffineLoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
                 // Channels loop
-                LoopBuilder::makeAffine(&c, numChannelsLb, numChannelsUb, 1)([&] {
+                makeAffineLoopBuilder(&c, numChannelsLb, numChannelsUb, 1)([&] {
                     // Results loop
                     AffineLoopNestBuilder(
                         resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
                         // Compute image start indices
-                        SmallVector<IndexHandle, 4> imgStartIndices;
+                        SmallVector<ValueHandle, 4> imgStartIndices;
                         for (auto i = 0; i < spatialRank; i++)
                         {
                             IntegerAttr iAttr = strides[i].cast<IntegerAttr>();
-                            auto stride = intrinsics::constant_index(iAttr.getInt());
-                            imgStartIndices.push_back(IndexHandle(resSpatialIndices[i] * stride));
+                            auto stride = std_constant_index(iAttr.getInt());
+                            imgStartIndices.push_back(resSpatialIndices[i] * stride);
                         }
-                        SmallVector<IndexHandle, 4> resIndices;
+                        SmallVector<ValueHandle, 4> resIndices;
                         // Result indices
                         resIndices.push_back(n);
                         if (groupConvolution && groupsInFilters)
                         {
                             // gId * C_OUT (num of filters) + k
-                            resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k));
+                            resIndices.push_back(ValueHandle(gId) * numFiltersUb + k);
                         }
                         else
                         {
@@ -1727,15 +1720,14 @@ namespace
                             filtersSpatialLbs,
                             filtersSpatialUbs,
                             filtersSteps)([&] {
-                            SmallVector<IndexHandle, 4> imgIndices, filtersIndices;
+                            SmallVector<ValueHandle, 4> imgIndices, filtersIndices;
                             // Image indices
                             // Here we compute the virtual start index into the padded image.
                             imgIndices.push_back(n);
                             imgIndices.push_back(c);
                             for (auto i = 0; i < spatialRank; i++)
                             {
-                                imgIndices.push_back(
-                                    IndexHandle(imgStartIndices[i] + filtersSpatialIndices[i]));
+                                imgIndices.push_back(imgStartIndices[i] + filtersSpatialIndices[i]);
                             }
                             // Filter indices
@@ -1744,14 +1736,14 @@ namespace
                             // index
                             if (groupConvolution && groupsInFilters)
                             {
-                                filtersIndices.push_back(IndexHandle(gId));
+                                filtersIndices.push_back(ValueHandle(gId));
                             }

                             filtersIndices.push_back(k);
                             // subtract lower bound of channel
                             // if we are doing group convolution this bound will advance based
                             // on the group id. For the filters, it should always start from 0
-                            filtersIndices.push_back(IndexHandle(c - numChannelsLb));
+                            filtersIndices.push_back(c - numChannelsLb);
                             filtersIndices.insert(filtersIndices.end(),
                                                   filtersSpatialIndices.begin(),
                                                   filtersSpatialIndices.end());
@@ -1759,7 +1751,7 @@ namespace
                             if (withPadding)
                             {
                                 // if args : img dims, img lbs, img ubs
-                                SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin();
+                                SmallVector<ValueHandle, 4>::iterator it = imgIndices.begin();
                                 std::advance(it, 2);
                                 SmallVector<Value, 4> affineIfArgs(it, imgIndices.end());
                                 affineIfArgs.insert(
@@ -1777,14 +1769,14 @@ namespace
                                     ScopedContext scope(rewriter, loc);
                                     // We must subtract pad below before img load, since the
                                     // physical image is not padded
-                                    SmallVector<IndexHandle, 4> adjustedImgIndices;
+                                    SmallVector<ValueHandle, 4> adjustedImgIndices;
                                     adjustedImgIndices.push_back(n);
                                     adjustedImgIndices.push_back(c);
                                     for (auto i = 0; i < spatialRank; i++)
                                     {
-                                        adjustedImgIndices.push_back(IndexHandle(
-                                            imgIndices[2 + i] -
-                                            intrinsics::constant_index(padBelowIntValues[i])));
+                                        adjustedImgIndices.push_back(
+                                            imgIndices[2 + i] -
+                                            std_constant_index(padBelowIntValues[i]));
                                     }
                                     iRes(resIndices) =
                                         iRes(resIndices) +
@@ -1821,15 +1813,15 @@ namespace
         ScopedContext scope(rewriter, loc);
         // Views
-        MemRefView vRes(result), vLHS(lhs);
+        MemRefBoundsCapture vRes(result), vLHS(lhs);
         // Index Values
-        IndexedValue iRes(result), iLHS(lhs);
+        AffineIndexedValue iRes(result), iLHS(lhs);
         // Bounds Index Handles
         auto lbs = vLHS.getLbs();
         auto ubs = vLHS.getUbs();
         // Loop induction vars
-        auto ivs = makeIndexHandles(vLHS.rank());
-        auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+        auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
+        auto pivs = makeHandlePointers(ivs);
         // Steps
         auto steps = vLHS.getSteps();
@@ -1867,15 +1859,15 @@ namespace
         ScopedContext scope(rewriter, loc);
         // Views
-        MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
+        MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
         // Index Values
-        IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
+        AffineIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
         // Bounds Index Handles
         auto lbs = vLHS.getLbs();
         auto ubs = vLHS.getUbs();
         // Loop induction vars
-        auto ivs = makeIndexHandles(vLHS.rank());
-        auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+        auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
+        auto pivs = makeHandlePointers(ivs);
         // Steps
         auto steps = vLHS.getSteps();
         // element type of the operand
@@ -1900,63 +1892,55 @@ namespace
             iRes(ivs) = iLHS(ivs) / iRHS(ivs);
         }
         // TODO(pthoreho) For all comparision operators, use
-        // edsc::intrinsics::zero_extendi(ValueHandle(iLHS(ivs)) !=
+        // zero_extendi(ValueHandle(iLHS(ivs)) !=
         // ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext()));
-        // instead of edsc::intrinsics::select once `zero_extendi` is
+        // instead of std_select once `zero_extendi` is
         // made available in the edsc::intrinsics namescope in MLIR repo.
         else if (isa<NGGreaterOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
-                                                 createOneConstant(elemTy),
-                                                 createZeroConstant(elemTy));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
+                                   createOneConstant(elemTy),
+                                   createZeroConstant(elemTy));
         }
         else if (isa<NGLessOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
-                                                 createOneConstant(elemTy),
-                                                 createZeroConstant(elemTy));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
+                                   createOneConstant(elemTy),
+                                   createZeroConstant(elemTy));
         }
         else if (isa<NGGreaterEqOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)),
-                                                 createOneConstant(elemTy),
-                                                 createZeroConstant(elemTy));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)),
+                                   createOneConstant(elemTy),
+                                   createZeroConstant(elemTy));
         }
         else if (isa<NGLessEqOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)),
-                                                 createOneConstant(elemTy),
-                                                 createZeroConstant(elemTy));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)),
+                                   createOneConstant(elemTy),
+                                   createZeroConstant(elemTy));
         }
         else if (isa<NGEqOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)),
-                                                 createOneConstant(elemTy),
-                                                 createZeroConstant(elemTy));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)),
+                                   createOneConstant(elemTy),
+                                   createZeroConstant(elemTy));
         }
         else if (isa<NGNotEqOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)),
-                                                 createOneConstant(elemTy),
-                                                 createZeroConstant(elemTy));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)),
+                                   createOneConstant(elemTy),
+                                   createZeroConstant(elemTy));
         }
         else if (isa<NGMaxOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
-                                                 ValueHandle(iLHS(ivs)),
-                                                 ValueHandle(iRHS(ivs)));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
+                                   ValueHandle(iLHS(ivs)),
+                                   ValueHandle(iRHS(ivs)));
         }
         else if (isa<NGMinOp>(op))
         {
-            iRes(ivs) = edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
-                                                 ValueHandle(iLHS(ivs)),
-                                                 ValueHandle(iRHS(ivs)));
+            iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
+                                   ValueHandle(iLHS(ivs)),
+                                   ValueHandle(iRHS(ivs)));
         }
@@ -1995,10 +1979,10 @@ namespace
         Value result = pass.buildOutputDefs(op, rewriter)[0];

         // Views
-        MemRefView vRes(result), vArg(arg);
+        MemRefBoundsCapture vRes(result), vArg(arg);
         // Index Values
         StdIndexedValue iRes(result), stdArg(arg);
-        IndexedValue affineArg(arg);
+        AffineIndexedValue affineArg(arg);
         // Bounds Index Handles
         auto resLbs = vRes.getLbs();
         auto resUbs = vRes.getUbs();
@@ -2008,8 +1992,8 @@ namespace
         Type resTy = result.getType().cast<MemRefType>().getElementType();
         // Generate loop nest that initializes result to lower bound of the axis to be reduced.
         {
-            auto ivs = makeIndexHandles(vRes.rank());
-            auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
+            auto ivs = ValueHandle::makeIndexHandles(vRes.rank());
+            auto pivs = makeHandlePointers(ivs);
             auto steps = vRes.getSteps();
             auto initVal = vArg.lb(axis);
             AffineLoopNestBuilder(pivs, resLbs, resUbs, steps)(
@@ -2018,10 +2002,10 @@ namespace
         // Generate loop nest that computes the actual index reduction.
         {
-            auto allIVs = makeIndexHandles(vArg.rank());
-            auto pAllIVs = makeHandlePointers(MutableArrayRef<IndexHandle>(allIVs));
+            auto allIVs = ValueHandle::makeIndexHandles(vArg.rank());
+            auto pAllIVs = makeHandlePointers(allIVs);
             auto steps = vArg.getSteps();
-            SmallVector<IndexHandle, 8> nonRedIVs;
+            SmallVector<ValueHandle, 8> nonRedIVs;

             Type resTy = result.getType().cast<MemRefType>().getElementType();
             NGRAPH_CHECK(resTy.isa<IntegerType>(),
@@ -2049,10 +2033,8 @@ namespace
                 // Select the min/max value and cast it back to integer type before storing it.
                 ValueHandle newRedIdx =
                     std::is_same<RedOp, NGArgMinRedOp>()
-                        ? edsc::intrinsics::select(
-                              affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
-                        : edsc::intrinsics::select(
-                              stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
+                        ? std_select(affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
+                        : std_select(stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);

                 iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
             });
@@ -2123,7 +2105,7 @@ namespace
         castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);

         FuncOp callBackFunc =
-            pass.getCallDecl("__mlir_callback_1_input",
+            pass.getCallDecl("callback_1_input",
                             {unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
                             {},
                             rewriter);
@@ -2168,11 +2150,11 @@ namespace
     {
         if (floatTy.isF32())
         {
-            return intrinsics::constant_float(llvm::APFloat(0.0f), floatTy);
+            return std_constant_float(llvm::APFloat(0.0f), floatTy);
         }
         else if (floatTy.isF64())
         {
-            return intrinsics::constant_float(llvm::APFloat(0.0), floatTy);
+            return std_constant_float(llvm::APFloat(0.0), floatTy);
         }
         else
         {
@@ -2181,7 +2163,7 @@ namespace
     }
     else if (auto intTy = type.dyn_cast<IntegerType>())
     {
-        return intrinsics::constant_int(0, intTy.getWidth());
+        return std_constant_int(0, intTy.getWidth());
     }
     NGRAPH_UNREACHABLE("Unsupported type");
 }
@@ -2192,11 +2174,11 @@ namespace
     {
         if (floatTy.isF32())
         {
-            return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy);
+            return std_constant_float(llvm::APFloat(1.0f), floatTy);
         }
         else if (floatTy.isF64())
         {
-            return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy);
+            return std_constant_float(llvm::APFloat(1.0f), floatTy);
         }
         else
         {
@@ -2205,7 +2187,7 @@ namespace
     }
     else if (auto intTy = type.dyn_cast<IntegerType>())
     {
-        return intrinsics::constant_int(1, intTy.getWidth());
+        return std_constant_int(1, intTy.getWidth());
     }
     NGRAPH_UNREACHABLE("Unsupported type");
 }
src/contrib/mlir/core/pass/ng_dialect_fused_ops.cpp
@@ -23,9 +23,7 @@
 #include "contrib/mlir/core/ngraph_dialect/type.hpp"
 #include <llvm/IR/Module.h>
-#include <mlir/EDSC/Builders.h>
-#include <mlir/EDSC/Helpers.h>
-#include <mlir/EDSC/Intrinsics.h>
+#include <mlir/Dialect/AffineOps/EDSC/Builders.h>
 #include <mlir/IR/IntegerSet.h>
 #include <mlir/IR/MLIRContext.h>
 #include <mlir/IR/StandardTypes.h>
src/contrib/mlir/runtime/cpu/cpu_callbacks.cpp
@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
     }
 }

-extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type)
+extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t index, OpType type)
 {
     auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input);
     auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output);
@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index,
     }
 }

-extern "C" void
-    __mlir_callback_2_inputs(void* input0, void* input1, void* output, size_t index, OpType type)
+extern "C" void
+    _mlir_ciface_callback_2_inputs(void* input0, void* input1, void* output, size_t index, OpType type)
 {
     auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
     auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1);
@@ -780,7 +780,7 @@ extern "C" void
     }
 }

-extern "C" void __mlir_callback_3_inputs(
+extern "C" void _mlir_ciface_callback_3_inputs(
     void* input0, void* input1, void* input2, void* output, size_t index, OpType type)
 {
     auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
src/contrib/mlir/runtime/cpu/cpu_runtime.cpp
@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
 {
     NGRAPH_CHECK(m_module, "MLIR module is not ready.");

-    auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("main");
+    auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main");
     NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");

     // Set external arguments
@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute()
     // uniformity reasons, it takes a list of type-erased pointers to arguments.
     // Please, note that 'invoke' method is overloaded with a parameter pack version.
     // Make sure the MutableArrayRef version is invoked.
-    auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs));
+    auto invocationResult =
+        m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));

     if (clDumpObjectFile)
     {
         m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
                                                             : clObjectFilename.getValue());
     }
-    NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
+    NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n");
 }

 void MLIRCPURuntime::cleanup()
src/contrib/mlir/tools/ngraph-opt/CMakeLists.txt
@@ -16,7 +16,7 @@
 set(LIBS
     mlir_backend
-    MLIROptMain
+    MLIROptLib
     MLIRPass
     MLIRParser
     LLVMSupport
src/contrib/mlir/utils.cpp
@@ -21,10 +21,21 @@
 #include "contrib/mlir/core/ngraph_dialect/dialect.hpp"

 #include <llvm/Support/CommandLine.h>
 #include <llvm/Support/Debug.h>
+#include <mlir/Dialect/AffineOps/AffineOps.h>
+#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
+#include <mlir/Dialect/LoopOps/LoopOps.h>
+#include <mlir/Dialect/StandardOps/Ops.h>
+#include <mlir/Dialect/VectorOps/VectorOps.h>
+#include <mlir/IR/Dialect.h>
 #include <mlir/IR/MLIRContext.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/LocationSnapshot.h>
+#include <mlir/Transforms/Passes.h>
+
+using namespace mlir;

 static llvm::cl::opt<bool> clPrintIRAfterAll(
     "ngraph-print-ir-after-all",
@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
 void ngraph::runtime::ngmlir::initializeNGraphMLIR()
 {
-    // Initialize a dialect only once.
-    // We currently have no way to query if a dialect is previously
-    // registered. So using a global flag instead.
-    static bool init = false;
-    if (!init)
-    {
-        mlir::registerDialect<mlir::NGraphOpsDialect>();
-        init = true;
-    }
+    // Initialize MLIR dialects and passes only once.
+    static bool init_once = []() {
+        // In-tree Dialects.
+        registerDialect<AffineOpsDialect>();
+        registerDialect<LLVM::LLVMDialect>();
+        registerDialect<loop::LoopOpsDialect>();
+        registerDialect<StandardOpsDialect>();
+        registerDialect<vector::VectorOpsDialect>();
+
+        // nGraph dialects.
+        registerDialect<mlir::NGraphOpsDialect>();
+
+        // In-tree passes.
+        // No-op to avoid DCE on the following pass initializations.
+        if (std::getenv("bar") != (char*)-1)
+            return false;
+
+        createCanonicalizerPass();
+        createCSEPass();
+        createVectorizePass({});
+        createLoopUnrollPass();
+        createLoopUnrollAndJamPass();
+        createSimplifyAffineStructuresPass();
+        createLoopFusionPass();
+        createLoopInvariantCodeMotionPass();
+        createAffineLoopInvariantCodeMotionPass();
+        createPipelineDataTransferPass();
+        createLowerAffinePass();
+        createLoopTilingPass(0);
+        createLoopCoalescingPass();
+        createAffineDataCopyGenerationPass(0, 0);
+        createMemRefDataFlowOptPass();
+        createStripDebugInfoPass();
+        createPrintOpStatsPass();
+        createInlinerPass();
+        createSymbolDCEPass();
+        createLocationSnapshotPass({});
+        return true;
+    }();
+    (void)init_once;
 }

 void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module)
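The new initializeNGraphMLIR body leans on a slightly unusual idiom that is worth spelling out (my reading, not documented in the commit): the immediately invoked lambda assigned to a function-local static runs exactly once, and because std::getenv is opaque to the optimizer the comparison against (char*)-1 cannot be folded away, yet at run time it is effectively always true, so the create*Pass() calls after it never execute. They exist only to keep the pass constructors referenced so their static registration is not dead-code-eliminated. Distilled to its shape, under that assumption:

    // Minimal sketch of the register-once / defeat-DCE idiom used above.
    static bool init_once = []() {
        registerDialect<mlir::NGraphOpsDialect>(); // real one-time work
        if (std::getenv("bar") != (char*)-1)       // opaque check, always taken at run time
            return false;
        createCanonicalizerPass();                 // never reached; reference kept alive
        return true;
    }();
    (void)init_once;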
src/ngraph/frontend/onnx_import/CMakeLists.txt
@@ -171,6 +171,8 @@ add_library(onnx_import STATIC
     op/reshape.hpp
     op/reverse_sequence.cpp
     op/reverse_sequence.hpp
+    op/round.cpp
+    op/round.hpp
     op/scatter_nd.cpp
     op/scatter_nd.hpp
     op/selu.cpp
src/ngraph/frontend/onnx_import/op/average_pool.cpp
@@ -16,7 +16,6 @@
 #include "average_pool.hpp"
 #include "ngraph/node.hpp"
-#include "ngraph/op/avg_pool.hpp"
 #include "utils/pooling_factory.hpp"

 namespace ngraph
@@ -29,7 +28,7 @@ namespace ngraph
                 NodeVector average_pool(const Node& node)
                 {
-                    return pooling::PoolingFactory(node).make_avg_pool();
+                    return pooling::LocalPoolingFactory(node).make_avg_pool();
                 }

             } // namespace set_1
src/ngraph/frontend/onnx_import/op/max_pool.cpp
@@ -31,7 +31,7 @@ namespace ngraph
                 NodeVector max_pool(const Node& node)
                 {
-                    auto max_pool = pooling::PoolingFactory(node).make_max_pool();
+                    auto max_pool = pooling::LocalPoolingFactory(node).make_max_pool();
                     max_pool.emplace_back(std::make_shared<NullNode>()); // Indices (optional)
                     return max_pool;
                 }
src/ngraph/frontend/onnx_import/op/onehot.cpp
@@ -42,9 +42,9 @@ namespace ngraph
                     auto off_on_values =
                         std::make_shared<default_opset::Split>(values, split_axis, 2);
                     auto off_value =
-                        reshape::interpret_as_scalar(get_output_element(off_on_values, 0ul));
+                        reshape::interpret_as_scalar(get_output_element(off_on_values, size_t{0}));
                     auto on_value =
-                        reshape::interpret_as_scalar(get_output_element(off_on_values, 1ul));
+                        reshape::interpret_as_scalar(get_output_element(off_on_values, size_t{1}));

                     auto axis = node.get_attribute_value<std::int64_t>("axis", -1);
src/ngraph/frontend/onnx_import/op/pad.cpp
@@ -65,14 +65,19 @@ namespace ngraph
                 NodeVector pad(const Node& node)
                 {
                     auto data = node.get_ng_inputs().at(0);
-                    const Shape& data_shape = data->get_shape();
+                    const auto data_rank =
+                        node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank();
+                    CHECK_VALID_NODE(
+                        node, data_rank.is_static(), "Data rank must be static for pad op");
+                    const auto data_rank_value = static_cast<size_t>(data_rank);

                     double value = node.get_attribute_value<double>("value", 0);
                     const std::string mode =
                         node.get_attribute_value<std::string>("mode", "constant");
                     ngraph::op::PadMode pad_mode = get_pad_mode(mode);

-                    auto paddings = convpool::get_pads(node, data_shape);
+                    const auto paddings = convpool::get_pads(node, data_rank_value);
                     ngraph::CoordinateDiff padding_below = paddings.first;
                     ngraph::CoordinateDiff padding_above = paddings.second;
src/ngraph/frontend/onnx_import/op/round.cpp (new file, 0 → 100644)
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include <memory>
+
+#include "ngraph/opsets/opset0.hpp"
+#include "round.hpp"
+
+namespace ngraph
+{
+    namespace onnx_import
+    {
+        namespace op
+        {
+            namespace set_1
+            {
+                NodeVector round(const Node& node)
+                {
+                    const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
+                    return {std::make_shared<ngraph::opset0::Round>(data)};
+                }
+            } // namespace set_1
+        }     // namespace op
+    }         // namespace onnx_import
+} // namespace ngraph
src/ngraph/frontend/onnx_import/op/round.hpp (new file, 0 → 100644)
+//*****************************************************************************
+// Copyright 2017-2020 Intel Corporation
+// (Apache 2.0 license header, identical to round.cpp above)
+//*****************************************************************************
+
+#pragma once
+
+#include "core/node.hpp"
+#include "ngraph/node.hpp"
+
+namespace ngraph
+{
+    namespace onnx_import
+    {
+        namespace op
+        {
+            namespace set_1
+            {
+                NodeVector round(const Node& node);
+            } // namespace set_1
+        }     // namespace op
+    }         // namespace onnx_import
+} // namespace ngraph
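For orientation, the Round support above follows the importer's usual three-step registration pattern; the concrete pieces are the files touched in this commit (summary mine):

    // 1. op/round.{cpp,hpp}: set_1::round() builds the nGraph subgraph (opset0::Round).
    // 2. onnx_import CMakeLists.txt: op/round.cpp and op/round.hpp join the target sources.
    // 3. ops_bridge.cpp: include "op/round.hpp" and register the operator for opset 1:
    REGISTER_OPERATOR("Round", 1, round);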
src/ngraph/frontend/onnx_import/ops_bridge.cpp
@@ -101,6 +101,7 @@
 #include "op/relu.hpp"
 #include "op/reshape.hpp"
 #include "op/reverse_sequence.hpp"
+#include "op/round.hpp"
 #include "op/scatter_nd.hpp"
 #include "op/selu.hpp"
 #include "op/shape.hpp"
@@ -334,6 +335,7 @@ namespace ngraph
             REGISTER_OPERATOR("Relu", 1, relu);
             REGISTER_OPERATOR("Reshape", 1, reshape);
             REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence);
+            REGISTER_OPERATOR("Round", 1, round);
             REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
             REGISTER_OPERATOR("Selu", 1, selu);
             REGISTER_OPERATOR("Shape", 1, shape);
src/ngraph/frontend/onnx_import/utils/convpool.cpp
@@ -38,28 +38,41 @@ namespace ngraph
             namespace detail
             {
-                Strides get_strides_helper(const Node& node,
-                                           const std::string& name,
-                                           const Shape& kernel_shape)
-                {
-                    return node.get_attribute_value<std::vector<std::size_t>>(
-                        name, std::vector<std::size_t>(kernel_shape.size(), 1UL));
-                }
+                /// \brief      Helper method used to read vector attribute
+                /// \note       Default value is vector of size spatial dims filled with ones
+                ///
+                /// \param      node         Node from which attribute is read
+                /// \param      attr_name    Attribute name (such as `strides`, `dilations`)
+                ///
+                /// \return     Read vector attribute if available or default value
+                std::vector<std::size_t> get_attribute_value(const Node& node,
+                                                             const std::string& attr_name)
+                {
+                    if (node.has_attribute(attr_name))
+                    {
+                        return node.get_attribute_value<std::vector<std::size_t>>(attr_name);
+                    }
+                    const auto data_rank =
+                        node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank();
+                    CHECK_VALID_NODE(node,
+                                     data_rank.is_static(),
+                                     "If '",
+                                     attr_name,
+                                     "' is not provided data rank must be static");
+                    const auto data_spatial_dims = static_cast<size_t>(data_rank) - 2;
+                    return std::vector<std::size_t>(data_spatial_dims, 1UL);
+                }
             } // namespace detail

-            Strides get_strides(const Node& node, const Shape& kernel_shape)
-            {
-                return detail::get_strides_helper(node, "strides", kernel_shape);
-            }
-
             Strides get_strides(const Node& node)
             {
-                return get_strides(node, get_kernel_shape(node));
+                return detail::get_attribute_value(node, "strides");
             }

             Strides get_dilations(const Node& node)
             {
-                return detail::get_strides_helper(node, "dilations", get_kernel_shape(node));
+                return detail::get_attribute_value(node, "dilations");
             }

             ngraph::op::PadType get_auto_pad(const Node& node)
@@ -90,16 +103,16 @@ namespace ngraph
             }

             std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node,
-                                                               const Shape& kernel_shape)
+                                                               const size_t kernel_rank)
             {
-                CoordinateDiff pads(kernel_shape.size(), 0);
+                CoordinateDiff pads(kernel_rank, 0);
                 if (node.has_attribute("pads"))
                 {
                     auto pads_int64 = node.get_attribute_value<std::vector<int64_t>>("pads");
                     pads = CoordinateDiff{std::begin(pads_int64), std::end(pads_int64)};
                 }

-                if (pads.size() == kernel_shape.size() * 2)
+                if (pads.size() == kernel_rank * 2)
                 {
                     return {{std::begin(pads), std::begin(pads) + pads.size() / 2},
                             {std::begin(pads) + pads.size() / 2, std::end(pads)}};
@@ -112,6 +125,18 @@ namespace ngraph
                 }
             }

+            std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node)
+            {
+                const auto data_rank =
+                    node.get_ng_inputs().at(0)->get_output_partial_shape(0).rank();
+                CHECK_VALID_NODE(node,
+                                 data_rank.is_static(),
+                                 "The rank of node must be static in order to calculate pads");
+                const auto data_spatial_dims = static_cast<size_t>(data_rank) - 2;
+
+                return get_pads(node, data_spatial_dims);
+            }
+
             void calculate_auto_pads(const Shape& data_shape,
                                      const Shape& filter_shape,
                                      const Strides& strides,
src/ngraph/frontend/onnx_import/utils/convpool.hpp
@@ -33,13 +33,6 @@ namespace ngraph
             /// \return The kernel Shape object representing its dimensions (height, width, depth).
             Shape get_kernel_shape(const Node& node);

-            /// \brief Get number of pixels to stride operation by in each direction.
-            ///
-            /// \param node The Node ptr representing Conv or Pool operation.
-            /// \param kernel_shape The shape of the kernel which we retrieve strides for.
-            /// \return The kernel Shape object representing its dimensions (height, width, depth).
-            Strides get_strides(const Node& node, const Shape& kernel_shape);
-
             /// \brief Get number of pixels to stride operation by in each direction.
             ///
             /// \param node The Node ptr representing Conv or Pool operation.
@@ -59,12 +52,12 @@ namespace ngraph
             ///         `pads` value should follow [x1_begin, x2_begin..., x1_end, x2_end,...].
             ///
             /// \param node The Node ptr representing ONNX operation.
-            /// \param kernel_shape The shape of the kernel which we retrieve pads for.
+            /// \param kernel_rank The rank of the kernel which we retrieve pads for.
             ///
             /// \return A pair of (padding_above, padding_below), which elements contains number of
             ///         pixels to pad in respective dimensions (height, width, depth).
             std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node,
-                                                               const Shape& kernel_shape);
+                                                               const size_t kernel_rank);

             /// \brief Get padding values for the operation described by an ONNX node.
             /// \details Values are taken from the `pads` attribute.
@@ -75,11 +68,7 @@ namespace ngraph
             ///
             /// \return A pair of (padding_above, padding_below), which elements contains number of
             ///         pixels to pad in respective dimensions (height, width, depth).
-            inline std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node)
-            {
-                return get_pads(node, get_kernel_shape(node));
-            }
+            std::pair<CoordinateDiff, CoordinateDiff> get_pads(const Node& node);

             ///
             /// \brief Calculate paddings with respect to auto_pad value.
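Taken together with pad.cpp above, the convpool changes swap the kernel-Shape-based helpers for rank-based ones, so a caller only needs the input's rank (not its full shape) to be static; that is what the new dynamic-shape ONNX tests in this commit exercise. A hedged usage sketch, with the call forms taken from this diff and the surrounding context assumed:

    // Inside an importer function that has an onnx_import::Node `node` in scope:
    const auto paddings = convpool::get_pads(node);           // needs only a static rank on input 0
    const Strides strides = convpool::get_strides(node);      // defaults to 1 per spatial dim
    const Strides dilations = convpool::get_dilations(node);  // same default behaviour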
src/ngraph/frontend/onnx_import/utils/pooling_factory.cpp
@@ -17,6 +17,7 @@
 #include <iterator>

 #include "default_opset.hpp"
 #include "exceptions.hpp"
+#include "ngraph/coordinate_diff.hpp"
 #include "utils/convpool.hpp"
 #include "utils/pooling_factory.hpp"
@@ -30,12 +31,11 @@ namespace ngraph
             PoolingFactory::PoolingFactory(const Node& node)
                 : m_onnx_node{node}
                 , m_inputs{node.get_ng_inputs()}
-                , m_kernel_shape{convpool::get_kernel_shape(node)}
                 , m_strides{convpool::get_strides(node)}
                 , m_dilations{convpool::get_dilations(node)}
                 , m_auto_pad{convpool::get_auto_pad(node)}
             {
-                auto paddings = convpool::get_pads(node);
+                const auto paddings = convpool::get_pads(node);
                 const CoordinateDiff& padding_above{paddings.second};
                 const CoordinateDiff& padding_below{paddings.first};
                 m_padding_below = Shape{std::begin(padding_below), std::end(padding_below)};
@@ -44,7 +44,7 @@ namespace ngraph
             NodeVector PoolingFactory::make_avg_pool() const
             {
-                bool count_include_pad =
+                const bool count_include_pad =
                     m_onnx_node.get_attribute_value<std::int64_t>("count_include_pad", 0);
                 return {std::make_shared<default_opset::AvgPool>(m_inputs.at(0),
                                                                  m_strides,
@@ -67,13 +67,31 @@ namespace ngraph
                                                                  m_auto_pad)};
             }

+            LocalPoolingFactory::LocalPoolingFactory(const Node& node)
+                : PoolingFactory(node)
+            {
+                // Kernel shape is required
+                m_kernel_shape =
+                    node.get_attribute_value<std::vector<std::size_t>>("kernel_shape");
+            }
+
             GlobalPoolingFactory::GlobalPoolingFactory(const Node& node)
                 : PoolingFactory(node)
             {
-                // Correct the kernel shape.
-                const Shape& data_shape{m_inputs.at(0)->get_shape()};
-                // Set shape to all but {N,C} axes.
-                m_kernel_shape = Shape{std::next(std::begin(data_shape), 2), std::end(data_shape)};
+                const auto data_shape = node.get_ng_inputs().at(0)->get_output_partial_shape(0);
+                const auto data_rank = data_shape.rank();
+                CHECK_VALID_NODE(
+                    node, data_rank.is_static(), "Data rank must be static for global pooling ops");
+                Shape kernel_shape;
+                for (auto i = 2; i < static_cast<size_t>(data_rank); ++i)
+                {
+                    CHECK_VALID_NODE(node,
+                                     data_shape[i].is_static(),
+                                     "All spatial dimensions must be known for global pooling ops");
+                    kernel_shape.emplace_back(static_cast<size_t>(data_shape[i]));
+                }
+                m_kernel_shape = kernel_shape;
             }
         } // namespace pooling
     }     // namespace onnx_import
src/ngraph/frontend/onnx_import/utils/pooling_factory.hpp
View file @
e0135089
...
...
@@ -48,7 +48,6 @@ namespace ngraph
            class PoolingFactory
            {
            public:
-               explicit PoolingFactory(const Node& node);
                virtual ~PoolingFactory() = default;

                ///
...
...
@@ -64,6 +63,8 @@ namespace ngraph
                NodeVector make_max_pool() const;

            protected:
+               explicit PoolingFactory(const Node& node);
+
                Node m_onnx_node;
                const NodeVector m_inputs;
                Shape m_kernel_shape;
...
...
@@ -75,9 +76,20 @@ namespace ngraph
            };

            ///
-           /// \brief Factory class which generates sub-graphs for ONNX 'global' pooling
+           /// \brief Factory class which generates sub-graphs for ONNX 'local' pooling
            ///        operators.
+           /// \note  Kernel shape attribute is required
+           class LocalPoolingFactory : public PoolingFactory
+           {
+           public:
+               explicit LocalPoolingFactory(const Node& node);
+               virtual ~LocalPoolingFactory() = default;
+           };
+
+           ///
+           /// \brief Factory class which generates sub-graphs for ONNX 'global' pooling
+           ///        operators.
+           /// \note  Kernel shape is calculated based on spatial dims
            class GlobalPoolingFactory : public PoolingFactory
            {
            public:
...
...
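To see how the two derived factories are intended to be used (a hedged sketch, not code from this commit; the actual handler changes live in average_pool.cpp and max_pool.cpp from the file summary): ONNX pooling ops that carry an explicit kernel_shape attribute would go through LocalPoolingFactory, while the Global* ops derive the kernel from the input's spatial dimensions via GlobalPoolingFactory.

    // Hypothetical op handlers, for illustration only.
    NodeVector average_pool(const Node& node)
    {
        return pooling::LocalPoolingFactory(node).make_avg_pool();
    }

    NodeVector global_max_pool(const Node& node)
    {
        return pooling::GlobalPoolingFactory(node).make_max_pool();
    }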
src/ngraph/pattern/matcher.cpp
View file @
e0135089
...
...
@@ -39,11 +39,19 @@ namespace ngraph
        MatcherState::~MatcherState()
        {
            if (m_restore)
            {
                if (!m_matcher->m_matched_list.empty())
                {
                    m_matcher->m_matched_list.erase(m_matcher->m_matched_list.begin() + m_watermark,
                                                    m_matcher->m_matched_list.end());
-                   m_matcher->m_pattern_value_maps.erase(m_pattern_value_maps.begin() + m_capture_size,
-                                                         m_pattern_value_maps.end());
                }
+
+               if (!m_pattern_value_maps.empty())
+               {
+                   m_matcher->m_pattern_value_maps.erase(m_pattern_value_maps.begin() + m_capture_size,
+                                                         m_pattern_value_maps.end());
+               }
+
                m_matcher->m_pattern_map = m_pattern_value_map;
            }
        }
...
...
src/ngraph/runtime/cpu/cpu_runtime_context.hpp
View file @
e0135089
...
...
@@ -36,7 +36,7 @@
namespace mkldnn
{
-   class primitive;
+   struct primitive;
}

namespace ngraph
...
...
src/ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.hpp
View file @
e0135089
...
...
@@ -35,7 +35,7 @@
namespace mkldnn
{
-   class primitive;
+   struct primitive;
}

namespace ngraph
...
...
src/ngraph/runtime/gpu/unit_test.manifest
View file @
e0135089
...
...
@@ -453,6 +453,7 @@ model_gatherND_int32
model_gatherND_float
model_pad_constant
model_reciprocal
model_round
tile_3d_small_data_rank
tile_3d_few_repeats
select_v1
...
...
src/ngraph/runtime/plaidml/unit_test.manifest
View file @
e0135089
...
...
@@ -282,6 +282,7 @@ model_argmax_int32
model_argmin_int32
model_lp_norm_default
model_instance_normalization
model_round
# passing locally, fails closeness checks in CI which may be too strict
elu
...
...
test/mlir/affine_conversion/callback_ops.mlir
View file @
e0135089
...
...
@@ -10,7 +10,7 @@
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
-// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
...
...
@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
-// CHECK: call @__mlir_callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
...
...
@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
-// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
...
...
@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
-// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
...
...
@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
-// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
...
...
@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
-// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
...
...
@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
-// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
+// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
...
...
test/models/onnx/dynamic_shapes/average_pool_2d_dyn.prototxt
0 → 100644
View file @
e0135089
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "AveragePool"
attribute {
name: "kernel_shape"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/dynamic_shapes/global_average_pool_dyn.prototxt
0 → 100644
View file @
e0135089
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "GlobalAveragePool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_value: 5
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/dynamic_shapes/global_max_pool_dyn.prototxt
0 → 100644
View file @
e0135089
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "GlobalMaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_value: 5
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/dynamic_shapes/max_pool_2d_dyn.prototxt
0 → 100644
View file @
e0135089
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "MaxPool"
attribute {
name: "kernel_shape"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 1
ints: 1
ints: 1
ints: 1
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/round.prototxt
0 → 100644
View file @
e0135089
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
op_type: "Round"
}
name: "test_round"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 15
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 15
}
}
}
}
}
}
opset_import {
version: 11
}
test/onnx/onnx_import.in.cpp
View file @
e0135089
...
...
@@ -1963,3 +1963,30 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reciprocal)
    test_case.run();
}

NGRAPH_TEST(onnx_${BACKEND_NAME}, model_round)
{
    const auto round_fn = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/round.prototxt"));

    auto test_case = ngraph::test::NgraphTestCase(round_fn, "${BACKEND_NAME}");
    test_case.add_input<float>({0.1f, 0.5f, 0.9f, 1.2f, 1.5f, 1.8f, 2.3f, 2.5f, 2.7f,
                                -1.1f, -1.5f, -1.9f, -2.2f, -2.5f, -2.8f});
    test_case.add_expected_output<float>(
        {0.f, 0.f, 1.f, 1.f, 2.f, 2.f, 2.f, 2.f, 3.f, -1.f, -2.f, -2.f, -2.f, -2.f, -3.f});

    test_case.run();
}
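Worth noting (editorial addition, not part of the commit): the expected outputs above follow round-half-to-even, so ties such as 0.5, 2.5 and -1.5 go to the nearest even integer, which is what the ONNX Round operator specifies; a plain std::round would give 1 for 0.5 and 3 for 2.5. A minimal sketch of that rounding rule in standard C++:

    #include <cfenv>
    #include <cmath>

    // Round-half-to-even, matching the expected values in the test above.
    float round_half_to_even(float x)
    {
        std::fesetround(FE_TONEAREST); // ties-to-even rounding mode
        return std::nearbyintf(x);     // honours the current rounding mode
    }
    // round_half_to_even(0.5f) == 0.f, round_half_to_even(2.5f) == 2.f, round_half_to_even(-1.5f) == -2.f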
test/onnx/onnx_import_dyn_shapes.in.cpp
View file @
e0135089
...
...
@@ -282,3 +282,83 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, model_conv_with_dynamic_batch)
    test_case.run();
}

NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, avg_pool_dyn_shape)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/average_pool_2d_dyn.prototxt"));

    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);

    const Shape shape{1, 1, 4, 4};
    const auto elems_in_tensor = shape_size(shape);
    std::vector<float> input_values(elems_in_tensor);
    std::iota(input_values.begin(), input_values.end(), 0.f);

    test_case.add_input<float>(shape, input_values);

    std::vector<float> expected_values{2.5f, 4.5f, 10.5f, 12.5f};
    test_case.add_expected_output<float>(Shape{1, 1, 2, 2}, expected_values);

    test_case.run();
}

NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, max_pool_dyn_shape)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/max_pool_2d_dyn.prototxt"));

    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);

    const Shape shape{1, 1, 4, 4};
    const auto elems_in_tensor = shape_size(shape);
    std::vector<float> input_values(elems_in_tensor);
    std::iota(input_values.begin(), input_values.end(), 0.f);

    test_case.add_input<float>(shape, input_values);

    std::vector<float> expected_values{0.f, 2.f, 3.f, 8.f, 10.f, 11.f, 12.f, 14.f, 15.f};
    test_case.add_expected_output<float>(Shape{1, 1, 3, 3}, expected_values);

    test_case.run();
}

NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, global_avg_pool_dyn_shape)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/global_average_pool_dyn.prototxt"));

    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);

    const Shape shape{1, 3, 5, 5};
    const auto elems_in_tensor = shape_size(shape);
    std::vector<float> input_values(elems_in_tensor);
    std::iota(input_values.begin(), input_values.end(), 0.f);

    test_case.add_input<float>(shape, input_values);

    std::vector<float> expected_values{12.f, 37.f, 62.f};
    test_case.add_expected_output<float>(Shape{1, 3, 1, 1}, expected_values);

    test_case.run();
}

NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, global_max_pool_dyn_shape)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/dynamic_shapes/global_max_pool_dyn.prototxt"));

    auto test_case = NgraphTestCase(function, "${BACKEND_NAME}", BackendMode::DYNAMIC);

    const Shape shape{1, 3, 5, 5};
    const auto elems_in_tensor = shape_size(shape);
    std::vector<float> input_values(elems_in_tensor);
    std::iota(input_values.begin(), input_values.end(), 0.f);

    test_case.add_input<float>(shape, input_values);

    std::vector<float> expected_values{24.f, 49.f, 74.f};
    test_case.add_expected_output<float>(Shape{1, 3, 1, 1}, expected_values);

    test_case.run();
}
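A quick sanity check of the global-pooling expectations above (editorial illustration, not part of the test file): with an iota-filled {1, 3, 5, 5} input, each channel holds 25 consecutive values, so the per-channel global average is the middle value (12, 37, 62) and the global max is the last value (24, 49, 74).

    #include <numeric>
    #include <vector>

    int main()
    {
        std::vector<float> channel(25);
        for (int c = 0; c < 3; ++c)
        {
            // channel c holds the values 25*c .. 25*c + 24
            std::iota(channel.begin(), channel.end(), 25.f * c);
            const float avg = std::accumulate(channel.begin(), channel.end(), 0.f) / 25.f; // 12, 37, 62
            const float max = channel.back();                                              // 24, 49, 74
            (void)avg;
            (void)max;
        }
        return 0;
    }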