Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
310fcf07
Commit
310fcf07
authored
Feb 21, 2020
by
mbencer
Browse files
Options
Browse Files
Download
Plain Diff
Merge remote-tracking branch 'origin/master' into mbencer/BuilderSplitV1
parents
b4b352d5
6c0cf85a
Hide whitespace changes
Inline
Side-by-side
Showing
36 changed files
with
826 additions
and
291 deletions
+826
-291
external_mkldnn_v1.cmake
cmake/external_mkldnn_v1.cmake
+3
-3
external_mlir.cmake
cmake/external_mlir.cmake
+1
-1
memory_analysis.cpp
src/contrib/mlir/backend/analysis/memory_analysis.cpp
+0
-1
cpu_backend.cpp
src/contrib/mlir/backend/cpu/cpu_backend.cpp
+2
-1
affine_lowerer.cpp
src/contrib/mlir/backend/pass/affine_lowerer.cpp
+187
-205
ng_dialect_fused_ops.cpp
src/contrib/mlir/core/pass/ng_dialect_fused_ops.cpp
+1
-3
cpu_callbacks.cpp
src/contrib/mlir/runtime/cpu/cpu_callbacks.cpp
+4
-4
cpu_runtime.cpp
src/contrib/mlir/runtime/cpu/cpu_runtime.cpp
+4
-3
CMakeLists.txt
src/contrib/mlir/tools/ngraph-opt/CMakeLists.txt
+1
-1
utils.cpp
src/contrib/mlir/utils.cpp
+54
-11
CMakeLists.txt
src/ngraph/frontend/onnx_import/CMakeLists.txt
+2
-0
average_pool.cpp
src/ngraph/frontend/onnx_import/op/average_pool.cpp
+1
-2
max_pool.cpp
src/ngraph/frontend/onnx_import/op/max_pool.cpp
+1
-1
onehot.cpp
src/ngraph/frontend/onnx_import/op/onehot.cpp
+2
-2
pad.cpp
src/ngraph/frontend/onnx_import/op/pad.cpp
+7
-2
round.cpp
src/ngraph/frontend/onnx_import/op/round.cpp
+41
-0
round.hpp
src/ngraph/frontend/onnx_import/op/round.hpp
+38
-0
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+2
-0
convpool.cpp
src/ngraph/frontend/onnx_import/utils/convpool.cpp
+40
-15
convpool.hpp
src/ngraph/frontend/onnx_import/utils/convpool.hpp
+3
-14
pooling_factory.cpp
src/ngraph/frontend/onnx_import/utils/pooling_factory.cpp
+24
-6
pooling_factory.hpp
src/ngraph/frontend/onnx_import/utils/pooling_factory.hpp
+14
-2
gather.cpp
src/ngraph/op/gather.cpp
+1
-1
matcher.cpp
src/ngraph/pattern/matcher.cpp
+12
-4
cpu_runtime_context.hpp
src/ngraph/runtime/cpu/cpu_runtime_context.hpp
+1
-1
cpu_mkldnn_primitive_build.hpp
src/ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.hpp
+1
-1
unit_test.manifest
src/ngraph/runtime/gpu/unit_test.manifest
+1
-0
unit_test.manifest
src/ngraph/runtime/plaidml/unit_test.manifest
+1
-0
callback_ops.mlir
test/mlir/affine_conversion/callback_ops.mlir
+7
-7
average_pool_2d_dyn.prototxt
test/models/onnx/dynamic_shapes/average_pool_2d_dyn.prototxt
+57
-0
global_average_pool_dyn.prototxt
...dels/onnx/dynamic_shapes/global_average_pool_dyn.prototxt
+51
-0
global_max_pool_dyn.prototxt
test/models/onnx/dynamic_shapes/global_max_pool_dyn.prototxt
+51
-0
max_pool_2d_dyn.prototxt
test/models/onnx/dynamic_shapes/max_pool_2d_dyn.prototxt
+65
-0
round.prototxt
test/models/onnx/round.prototxt
+39
-0
onnx_import.in.cpp
test/onnx/onnx_import.in.cpp
+27
-0
onnx_import_dyn_shapes.in.cpp
test/onnx/onnx_import_dyn_shapes.in.cpp
+80
-0
No files found.
cmake/external_mkldnn_v1.cmake
View file @
310fcf07
...
@@ -18,12 +18,12 @@ include(ExternalProject)
...
@@ -18,12 +18,12 @@ include(ExternalProject)
# Includes blas 3.8.0 in mkldnn
# Includes blas 3.8.0 in mkldnn
set
(
NGRAPH_MKLDNN_SHORT_VERSION 1
)
set
(
NGRAPH_MKLDNN_SHORT_VERSION 1
)
set
(
NGRAPH_MKLDNN_FULL_VERSION 1.
1.1
.0
)
set
(
NGRAPH_MKLDNN_FULL_VERSION 1.
2.0
.0
)
set
(
NGRAPH_MKLDNN_MKLML_ASSET_VERSION
"v0.21"
)
set
(
NGRAPH_MKLDNN_MKLML_ASSET_VERSION
"v0.21"
)
set
(
NGRAPH_MKLDNN_VERSION
"v1.
1.1
"
)
set
(
NGRAPH_MKLDNN_VERSION
"v1.
2
"
)
set
(
NGRAPH_MKLDNN_MKLML_VERSION
"2019.0.5.20190502"
)
set
(
NGRAPH_MKLDNN_MKLML_VERSION
"2019.0.5.20190502"
)
set
(
NGRAPH_MKLDNN_MKLML_WIN32_VERSION
"2020.0.20190813"
)
set
(
NGRAPH_MKLDNN_MKLML_WIN32_VERSION
"2020.0.20190813"
)
set
(
NGRAPH_MKLDNN_GIT_TAG
"v1.
1.1
"
)
set
(
NGRAPH_MKLDNN_GIT_TAG
"v1.
2
"
)
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
# Fetch and install MKL-DNN
# Fetch and install MKL-DNN
...
...
cmake/external_mlir.cmake
View file @
310fcf07
...
@@ -19,7 +19,7 @@ include(ExternalProject)
...
@@ -19,7 +19,7 @@ include(ExternalProject)
set
(
MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git
)
set
(
MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git
)
# Change these commit IDs to move to latest stable versions
# Change these commit IDs to move to latest stable versions
set
(
MLIR_LLVM_COMMIT_ID
96400ae
)
set
(
MLIR_LLVM_COMMIT_ID
376c6853
)
# MLIR environment variables. Some of them are used by LIT tool.
# MLIR environment variables. Some of them are used by LIT tool.
...
...
src/contrib/mlir/backend/analysis/memory_analysis.cpp
View file @
310fcf07
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/DenseSet.h>
#include <map>
#include <map>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/IntegerSet.h>
...
...
src/contrib/mlir/backend/cpu/cpu_backend.cpp
View file @
310fcf07
...
@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect()
...
@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect()
void
MLIRCPUBackend
::
lowerStandardDialect
()
void
MLIRCPUBackend
::
lowerStandardDialect
()
{
{
mlir
::
PassManager
pm
(
&
m_context
);
mlir
::
PassManager
pm
(
&
m_context
);
pm
.
addPass
(
mlir
::
createLowerToLLVMPass
());
pm
.
addPass
(
mlir
::
createLowerToLLVMPass
(
/*useAlloca=*/
false
,
/*useBarePtrCallConv=*/
false
,
/*emitCWrappers=*/
true
));
// Apply any generic pass manager command line options.
// Apply any generic pass manager command line options.
mlir
::
applyPassManagerCLOptions
(
pm
);
mlir
::
applyPassManagerCLOptions
(
pm
);
...
...
src/contrib/mlir/backend/pass/affine_lowerer.cpp
View file @
310fcf07
...
@@ -28,9 +28,9 @@
...
@@ -28,9 +28,9 @@
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/DenseSet.h>
#include <llvm/Support/Debug.h>
#include <llvm/Support/Debug.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/
Dialect/AffineOps/
EDSC/Builders.h>
#include <mlir/
EDSC/Helper
s.h>
#include <mlir/
Dialect/AffineOps/EDSC/Intrinsic
s.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/
Dialect/StandardOps/
EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/IntegerSet.h>
...
@@ -51,11 +51,10 @@ namespace
...
@@ -51,11 +51,10 @@ namespace
{
{
using
namespace
mlir
;
using
namespace
mlir
;
using
namespace
mlir
::
edsc
;
using
namespace
mlir
::
edsc
;
using
namespace
mlir
::
edsc
::
intrinsics
;
using
namespace
mlir
::
edsc
::
op
;
using
namespace
mlir
::
edsc
::
op
;
using
namespace
ngraph
::
runtime
;
using
namespace
ngraph
::
runtime
;
using
namespace
ngraph
::
runtime
::
ngmlir
;
using
namespace
ngraph
::
runtime
::
ngmlir
;
// Index notation to generate standard (i.e., non-affine) loads and stores.
using
StdIndexedValue
=
TemplatedIndexedValue
<
intrinsics
::
std_load
,
intrinsics
::
std_store
>
;
class
DialectLoweringPass
;
class
DialectLoweringPass
;
...
@@ -215,9 +214,37 @@ namespace
...
@@ -215,9 +214,37 @@ namespace
NGraphTypeConverter
()
NGraphTypeConverter
()
:
TypeConverter
()
:
TypeConverter
()
{
{
}
// TODO(dcaballe): split this into independent conversion patterns when there is a
// way to check if a type is valid in Std dialect.
addConversion
([
this
](
Type
type
)
->
Type
{
if
(
auto
tensorType
=
type
.
dyn_cast
<
NGTensorType
>
())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std
// TensorType. This may change in the future.
return
MemRefType
::
get
(
tensorType
.
getShape
(),
convertType
(
tensorType
.
getElementType
()),
{
/* no map used */
},
0
);
}
if
(
auto
floatType
=
type
.
dyn_cast
<
NGFloatType
>
())
{
// Float types are already std type.
return
floatType
;
}
if
(
auto
intType
=
type
.
dyn_cast
<
NGIntegerType
>
())
{
return
mlir
::
IntegerType
::
get
(
intType
.
getWidth
(),
intType
.
getContext
());
}
if
(
auto
boolType
=
type
.
dyn_cast
<
NGBoolType
>
())
{
return
mlir
::
IntegerType
::
get
(
1
/* width */
,
boolType
.
getContext
());
}
Type
convertType
(
Type
t
)
override
;
// Do not assert/NGRAPH_CHECK here. Type convertion infra expects `convertType` to
// return the input type if the type is not supported.
return
type
;
});
}
};
};
/// Dialect Lowering Pass to affine ops
/// Dialect Lowering Pass to affine ops
...
@@ -317,7 +344,8 @@ namespace
...
@@ -317,7 +344,8 @@ namespace
// TODO: Encode no alias attribute as part of the function signature conversion or as a
// TODO: Encode no alias attribute as part of the function signature conversion or as a
// separate rewrite pattern. Retrieve new function after signature conversion.
// separate rewrite pattern. Retrieve new function after signature conversion.
insertNoAliasArgAttrs
();
// TODO: To be enabled in follow-up commit.
// insertNoAliasArgAttrs();
}
}
opAttrsVec
=
m_attrsVec
;
opAttrsVec
=
m_attrsVec
;
...
@@ -492,22 +520,22 @@ namespace
...
@@ -492,22 +520,22 @@ namespace
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics.
/// by nGraph op semantics.
void
DialectLoweringPass
::
insertNoAliasArgAttrs
()
//
void DialectLoweringPass::insertNoAliasArgAttrs()
{
//
{
FuncOp
func
=
getModule
().
lookupSymbol
<
mlir
::
FuncOp
>
(
funcName
);
//
FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK
(
func
,
"FuncOp '"
+
funcName
.
str
()
+
"' not found"
);
//
NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
unsigned
int
argIdx
=
0
;
//
unsigned int argIdx = 0;
for
(
auto
arg
:
func
.
getArguments
())
//
for (auto arg : func.getArguments())
{
//
{
if
(
arg
.
getType
().
isa
<
MemRefType
>
())
//
if (arg.getType().isa<MemRefType>())
{
//
{
func
.
setArgAttr
(
argIdx
,
"llvm.noalias"
,
BoolAttr
::
get
(
true
,
&
getContext
()));
//
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
}
//
}
++
argIdx
;
//
++argIdx;
}
//
}
}
//
}
void
DialectLoweringPass
::
insertDeallocs
(
PatternRewriter
&
rewriter
)
void
DialectLoweringPass
::
insertDeallocs
(
PatternRewriter
&
rewriter
)
{
{
...
@@ -543,40 +571,6 @@ namespace
...
@@ -543,40 +571,6 @@ namespace
return
m_attrsVec
.
size
()
-
1
;
return
m_attrsVec
.
size
()
-
1
;
}
}
// NGDialect converters
Type
NGraphTypeConverter
::
convertType
(
Type
type
)
{
// We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since NGraphTypeConverter is private.
if
(
auto
tensorType
=
type
.
dyn_cast
<
NGTensorType
>
())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
return
MemRefType
::
get
(
tensorType
.
getShape
(),
convertType
(
tensorType
.
getElementType
()),
{
/* no map used */
},
0
);
}
if
(
auto
floatType
=
type
.
dyn_cast
<
NGFloatType
>
())
{
// Float types are already std type.
return
floatType
;
}
if
(
auto
intType
=
type
.
dyn_cast
<
NGIntegerType
>
())
{
return
mlir
::
IntegerType
::
get
(
intType
.
getWidth
(),
intType
.
getContext
());
}
if
(
auto
boolType
=
type
.
dyn_cast
<
NGBoolType
>
())
{
return
mlir
::
IntegerType
::
get
(
1
/* width */
,
boolType
.
getContext
());
}
// Do not assert/NGRAPH_CHECK here. Type convertion infra expects `convertType` to return
// the input type if the type is not supported.
return
type
;
}
#define REWRITER(OP) \
#define REWRITER(OP) \
PatternMatchResult OP##Conversion::matchAndRewrite( \
PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
...
@@ -680,15 +674,15 @@ namespace
...
@@ -680,15 +674,15 @@ namespace
ScopedContext
scope
(
rewriter
,
loc
);
ScopedContext
scope
(
rewriter
,
loc
);
// Views
// Views
MemRef
View
vRes
(
result
),
vLHS
(
lhs
);
MemRef
BoundsCapture
vRes
(
result
),
vLHS
(
lhs
);
// Index Values
// Index Values
IndexedValue
iRes
(
result
),
iLHS
(
lhs
);
Affine
IndexedValue
iRes
(
result
),
iLHS
(
lhs
);
// Bounds Index Handles
// Bounds Index Handles
auto
lbs
=
vLHS
.
getLbs
();
auto
lbs
=
vLHS
.
getLbs
();
auto
ubs
=
vLHS
.
getUbs
();
auto
ubs
=
vLHS
.
getUbs
();
// Loop induction vars
// Loop induction vars
auto
ivs
=
makeIndexHandles
(
vLHS
.
rank
());
auto
ivs
=
ValueHandle
::
makeIndexHandles
(
vLHS
.
rank
());
auto
pivs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
ivs
)
);
auto
pivs
=
makeHandlePointers
(
ivs
);
// Steps
// Steps
auto
steps
=
vLHS
.
getSteps
();
auto
steps
=
vLHS
.
getSteps
();
...
@@ -698,7 +692,7 @@ namespace
...
@@ -698,7 +692,7 @@ namespace
AffineLoopNestBuilder
(
pivs
,
lbs
,
ubs
,
steps
)([
&
]
{
AffineLoopNestBuilder
(
pivs
,
lbs
,
ubs
,
steps
)([
&
]
{
ValueHandle
val
=
iLHS
(
ivs
);
ValueHandle
val
=
iLHS
(
ivs
);
ValueHandle
zero
=
createZeroConstant
(
elemTy
);
ValueHandle
zero
=
createZeroConstant
(
elemTy
);
iRes
(
ivs
)
=
intrinsics
::
select
(
val
>
zero
,
val
,
zero
);
iRes
(
ivs
)
=
std_
select
(
val
>
zero
,
val
,
zero
);
});
});
rewriter
.
replaceOp
(
op
,
{
result
});
rewriter
.
replaceOp
(
op
,
{
result
});
...
@@ -742,36 +736,37 @@ namespace
...
@@ -742,36 +736,37 @@ namespace
// res[n, k] += lhs[n, m] * rhs[m, k]
// res[n, k] += lhs[n, m] * rhs[m, k]
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
MemRef
View
vRes
(
result
),
vLhs
(
lhs
),
vRhs
(
rhs
);
MemRef
BoundsCapture
vRes
(
result
),
vLhs
(
lhs
),
vRhs
(
rhs
);
NGRAPH_CHECK
(
vLhs
.
rank
()
==
2
&&
vRhs
.
rank
()
==
2
&&
vRes
.
rank
()
==
2
,
NGRAPH_CHECK
(
vLhs
.
rank
()
==
2
&&
vRhs
.
rank
()
==
2
&&
vRes
.
rank
()
==
2
,
"Dot operation is only supported for 2D tensors"
);
"Dot operation is only supported for 2D tensors"
);
// Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// It's important to note that MemRefView priovides lb/ub/step info is "reverse order",
// It's important to note that MemRefBoundsCapture priovides lb/ub/step info is "reverse
// i.e., fastest varying dimension is the last one, slowest varying dimention is the first
// order", i.e., fastest varying dimension is the last one, slowest varying dimention is the
// one.
// first one.
IndexHandle
n
,
m
,
k
;
auto
indexType
=
IndexType
::
get
(
rewriter
.
getContext
());
ValueHandle
n
(
indexType
),
m
(
indexType
),
k
(
indexType
);
unsigned
nDim
=
vLhs
.
fastestVarying
()
-
1
;
unsigned
nDim
=
vLhs
.
fastestVarying
()
-
1
;
unsigned
mDim
=
vRhs
.
fastestVarying
();
unsigned
mDim
=
vRhs
.
fastestVarying
();
unsigned
kDim
=
vRhs
.
fastestVarying
();
unsigned
kDim
=
vRhs
.
fastestVarying
();
Index
Handle
nLb
(
vLhs
.
lb
(
nDim
)),
mLb
(
vLhs
.
lb
(
mDim
)),
kLb
(
vRhs
.
lb
(
kDim
));
Value
Handle
nLb
(
vLhs
.
lb
(
nDim
)),
mLb
(
vLhs
.
lb
(
mDim
)),
kLb
(
vRhs
.
lb
(
kDim
));
Index
Handle
nUb
(
vLhs
.
ub
(
nDim
)),
mUb
(
vLhs
.
ub
(
mDim
)),
kUb
(
vRhs
.
ub
(
kDim
));
Value
Handle
nUb
(
vLhs
.
ub
(
nDim
)),
mUb
(
vLhs
.
ub
(
mDim
)),
kUb
(
vRhs
.
ub
(
kDim
));
int64_t
nStep
=
vLhs
.
step
(
nDim
),
mStep
=
vLhs
.
step
(
mDim
),
kStep
=
vRhs
.
step
(
kDim
);
int64_t
nStep
=
vLhs
.
step
(
nDim
),
mStep
=
vLhs
.
step
(
mDim
),
kStep
=
vRhs
.
step
(
kDim
);
// Constants and indexed values to be used inside the loop nest.
// Constants and indexed values to be used inside the loop nest.
IndexedValue
iRes
(
result
),
iLhs
(
lhs
),
iRhs
(
rhs
);
Affine
IndexedValue
iRes
(
result
),
iLhs
(
lhs
),
iRhs
(
rhs
);
ValueHandle
zeroInit
(
rewriter
.
create
<
ConstantOp
>
(
loc
,
rewriter
.
getZeroAttr
(
elemTy
)));
ValueHandle
zeroInit
(
rewriter
.
create
<
ConstantOp
>
(
loc
,
rewriter
.
getZeroAttr
(
elemTy
)));
{
{
IndexHandle
n
,
k
;
ValueHandle
n
(
indexType
),
k
(
indexType
)
;
LoopBuilder
::
makeAffine
(
&
n
,
nLb
,
nUb
,
nStep
)([
&
]
{
makeAffineLoopBuilder
(
&
n
,
nLb
,
nUb
,
nStep
)([
&
]
{
LoopBuilder
::
makeAffine
(
&
k
,
kLb
,
kUb
,
kStep
)([
&
]
{
iRes
(
n
,
k
)
=
zeroInit
;
});
makeAffineLoopBuilder
(
&
k
,
kLb
,
kUb
,
kStep
)([
&
]
{
iRes
(
n
,
k
)
=
zeroInit
;
});
});
});
}
}
LoopBuilder
::
makeAffine
(
&
n
,
nLb
,
nUb
,
nStep
)([
&
]
{
makeAffineLoopBuilder
(
&
n
,
nLb
,
nUb
,
nStep
)([
&
]
{
LoopBuilder
::
makeAffine
(
&
m
,
mLb
,
mUb
,
mStep
)([
&
]
{
makeAffineLoopBuilder
(
&
m
,
mLb
,
mUb
,
mStep
)([
&
]
{
LoopBuilder
::
makeAffine
(
&
k
,
kLb
,
kUb
,
kStep
)(
makeAffineLoopBuilder
(
&
k
,
kLb
,
kUb
,
kStep
)(
[
&
]
{
iRes
(
n
,
k
)
+=
iLhs
(
n
,
m
)
*
iRhs
(
m
,
k
);
});
[
&
]
{
iRes
(
n
,
k
)
+=
iLhs
(
n
,
m
)
*
iRhs
(
m
,
k
);
});
});
});
});
});
...
@@ -792,13 +787,13 @@ namespace
...
@@ -792,13 +787,13 @@ namespace
NGRAPH_CHECK
(
result
,
"Unexpected null result in ConcatOp"
);
NGRAPH_CHECK
(
result
,
"Unexpected null result in ConcatOp"
);
// Create view to write into result.
// Create view to write into result.
MemRef
View
vRes
(
result
);
MemRef
BoundsCapture
vRes
(
result
);
auto
rank
=
vRes
.
rank
();
auto
rank
=
vRes
.
rank
();
// For each operand, generate a separate loop to copy into the target slice of "result".
// For each operand, generate a separate loop to copy into the target slice of "result".
// We'll keep track of the slice offsets via concatenation_axis_pos.
// We'll keep track of the slice offsets via concatenation_axis_pos.
auto
concatenationAxis
=
concat
.
concatenation_axis
().
getSExtValue
();
auto
concatenationAxis
=
concat
.
concatenation_axis
().
getSExtValue
();
IndexHandle
concatenationAxisPos
(
index_type
(
0
));
Value
concatenationAxisPos
(
std_constant_index
(
0
));
for
(
auto
&
operand
:
operands
)
for
(
auto
&
operand
:
operands
)
{
{
...
@@ -817,7 +812,7 @@ namespace
...
@@ -817,7 +812,7 @@ namespace
// [i_(r-2)][i_(r-1)]
// [i_(r-2)][i_(r-1)]
// :=
// :=
// operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
// operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
MemRef
View
vOperand
(
operand
);
MemRef
BoundsCapture
vOperand
(
operand
);
NGRAPH_CHECK
(
vOperand
.
rank
()
==
rank
,
"Unexpected rank mismatch"
);
NGRAPH_CHECK
(
vOperand
.
rank
()
==
rank
,
"Unexpected rank mismatch"
);
llvm
::
SmallVector
<
ValueHandle
,
5
>
indexVars
;
llvm
::
SmallVector
<
ValueHandle
,
5
>
indexVars
;
...
@@ -825,9 +820,10 @@ namespace
...
@@ -825,9 +820,10 @@ namespace
llvm
::
SmallVector
<
ValueHandle
,
5
>
indexVarLbs
;
llvm
::
SmallVector
<
ValueHandle
,
5
>
indexVarLbs
;
llvm
::
SmallVector
<
ValueHandle
,
5
>
indexVarUbs
;
llvm
::
SmallVector
<
ValueHandle
,
5
>
indexVarUbs
;
llvm
::
SmallVector
<
int64_t
,
5
>
indexVarSteps
;
llvm
::
SmallVector
<
int64_t
,
5
>
indexVarSteps
;
auto
indexType
=
IndexType
::
get
(
rewriter
.
getContext
());
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
{
indexVars
.
push_back
(
IndexHandle
(
));
indexVars
.
push_back
(
ValueHandle
(
indexType
));
indexVarPtrs
.
push_back
(
&
(
indexVars
.
back
()));
indexVarPtrs
.
push_back
(
&
(
indexVars
.
back
()));
indexVarLbs
.
push_back
(
vOperand
.
lb
(
i
));
indexVarLbs
.
push_back
(
vOperand
.
lb
(
i
));
indexVarUbs
.
push_back
(
vOperand
.
ub
(
i
));
indexVarUbs
.
push_back
(
vOperand
.
ub
(
i
));
...
@@ -835,15 +831,15 @@ namespace
...
@@ -835,15 +831,15 @@ namespace
}
}
AffineLoopNestBuilder
(
indexVarPtrs
,
indexVarLbs
,
indexVarUbs
,
indexVarSteps
)([
&
]
{
AffineLoopNestBuilder
(
indexVarPtrs
,
indexVarLbs
,
indexVarUbs
,
indexVarSteps
)([
&
]
{
IndexedValue
ivRes
(
result
);
Affine
IndexedValue
ivRes
(
result
);
IndexedValue
ivOperand
(
operand
);
Affine
IndexedValue
ivOperand
(
operand
);
// On the LHS of the assignment, adjust the index for the concatenation axis.
// On the LHS of the assignment, adjust the index for the concatenation axis.
llvm
::
SmallVector
<
ValueHandle
,
5
>
resIndexHandles
;
llvm
::
SmallVector
<
ValueHandle
,
5
>
resIndexHandles
;
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
{
resIndexHandles
.
push_back
(
i
==
concatenationAxis
resIndexHandles
.
push_back
(
i
==
concatenationAxis
?
indexVars
[
i
]
+
concatenationAxisPos
?
indexVars
[
i
]
+
ValueHandle
(
concatenationAxisPos
)
:
indexVars
[
i
]);
:
indexVars
[
i
]);
}
}
...
@@ -851,11 +847,11 @@ namespace
...
@@ -851,11 +847,11 @@ namespace
});
});
// Move up concatenation_axis_pos for the next operand.
// Move up concatenation_axis_pos for the next operand.
concatenationAxisPos
=
concatenationAxisPos
+
vOperand
.
ub
(
concatenationAxis
);
concatenationAxisPos
=
ValueHandle
(
concatenationAxisPos
)
+
vOperand
.
ub
(
concatenationAxis
);
}
}
rewriter
.
replaceOp
(
op
,
{
result
});
rewriter
.
replaceOp
(
op
,
{
result
});
return
matchSuccess
();
return
matchSuccess
();
}
}
...
@@ -874,14 +870,13 @@ namespace
...
@@ -874,14 +870,13 @@ namespace
auto
axis
=
gatherOp
.
axis
().
getSExtValue
();
auto
axis
=
gatherOp
.
axis
().
getSExtValue
();
// Create view to write into result.
// Create view to write into result.
MemRef
View
vRes
(
result
),
vParams
(
params
),
vIndices
(
indices
);
MemRef
BoundsCapture
vRes
(
result
),
vParams
(
params
),
vIndices
(
indices
);
// Indexed Values
// Indexed Values
IndexedValue
iRes
(
result
),
iIndices
(
indices
);
Affine
IndexedValue
iRes
(
result
),
iIndices
(
indices
);
StdIndexedValue
iParams
(
params
);
StdIndexedValue
iParams
(
params
);
// Construct outer loop for params dims. Exclude the axis dim.
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector
<
ValueHandle
,
4
>
paramsLbs
,
paramsUbs
;
SmallVector
<
ValueHandle
,
4
>
paramsLbs
,
paramsUbs
,
paramsIVs
;
SmallVector
<
IndexHandle
,
4
>
paramsIVs
;
SmallVector
<
int64_t
,
4
>
paramsSteps
;
SmallVector
<
int64_t
,
4
>
paramsSteps
;
SmallVector
<
ValueHandle
*
,
4
>
paramsIVPtrs
;
SmallVector
<
ValueHandle
*
,
4
>
paramsIVPtrs
;
for
(
auto
i
=
0
;
i
<
vParams
.
rank
();
i
++
)
for
(
auto
i
=
0
;
i
<
vParams
.
rank
();
i
++
)
...
@@ -889,8 +884,8 @@ namespace
...
@@ -889,8 +884,8 @@ namespace
// skip gather axis
// skip gather axis
if
(
i
==
axis
)
if
(
i
==
axis
)
continue
;
continue
;
paramsLbs
.
push_back
(
IndexHandle
(
vParams
.
lb
(
i
)
));
paramsLbs
.
push_back
(
vParams
.
lb
(
i
));
paramsUbs
.
push_back
(
IndexHandle
(
vParams
.
ub
(
i
)
));
paramsUbs
.
push_back
(
vParams
.
ub
(
i
));
paramsSteps
.
push_back
(
vParams
.
step
(
i
));
paramsSteps
.
push_back
(
vParams
.
step
(
i
));
}
}
NGRAPH_CHECK
(
paramsLbs
.
size
()
==
vParams
.
rank
()
-
1
&&
NGRAPH_CHECK
(
paramsLbs
.
size
()
==
vParams
.
rank
()
-
1
&&
...
@@ -898,17 +893,17 @@ namespace
...
@@ -898,17 +893,17 @@ namespace
paramsSteps
.
size
()
==
paramsLbs
.
size
(),
paramsSteps
.
size
()
==
paramsLbs
.
size
(),
"Incorrect loop nest bounds size for gather params"
);
"Incorrect loop nest bounds size for gather params"
);
paramsIVs
=
makeIndexHandles
(
vParams
.
rank
()
-
1
);
paramsIVs
=
ValueHandle
::
makeIndexHandles
(
vParams
.
rank
()
-
1
);
paramsIVPtrs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
paramsIVs
)
);
paramsIVPtrs
=
makeHandlePointers
(
paramsIVs
);
auto
indicesLbs
=
vIndices
.
getLbs
();
auto
indicesLbs
=
vIndices
.
getLbs
();
auto
indicesUbs
=
vIndices
.
getUbs
();
auto
indicesUbs
=
vIndices
.
getUbs
();
auto
indicesSteps
=
vIndices
.
getSteps
();
auto
indicesSteps
=
vIndices
.
getSteps
();
auto
indicesIVs
=
makeIndexHandles
(
vIndices
.
rank
());
auto
indicesIVs
=
ValueHandle
::
makeIndexHandles
(
vIndices
.
rank
());
auto
indicesIVPtrs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
indicesIVs
)
);
auto
indicesIVPtrs
=
makeHandlePointers
(
indicesIVs
);
SmallVector
<
Index
Handle
,
8
>
paramsIndices
,
resIndices
;
SmallVector
<
Value
Handle
,
8
>
paramsIndices
,
resIndices
;
// Make sure we are going to create loops
// Make sure we are going to create loops
NGRAPH_CHECK
(
vParams
.
rank
()
>
0
,
"Invalid size for indices steps"
);
NGRAPH_CHECK
(
vParams
.
rank
()
>
0
,
"Invalid size for indices steps"
);
...
@@ -946,7 +941,7 @@ namespace
...
@@ -946,7 +941,7 @@ namespace
{
{
if
(
i
==
axis
)
if
(
i
==
axis
)
{
{
paramsIndices
.
push_back
(
IndexHandle
(
axisIdx
)
);
paramsIndices
.
push_back
(
axisIdx
);
}
}
else
else
{
{
...
@@ -1022,10 +1017,10 @@ namespace
...
@@ -1022,10 +1017,10 @@ namespace
NGRAPH_CHECK
(
groups
>
0
,
"Invalid number of groups"
);
NGRAPH_CHECK
(
groups
>
0
,
"Invalid number of groups"
);
// create outer group convolution loop
// create outer group convolution loop
// for group = 0 to groups
// for group = 0 to groups
IndexHandle
iv
;
auto
indexType
=
IndexType
::
get
(
rewriter
.
getContext
())
;
ValueHandle
iv
(
indexType
);
ValueHandle
lb
=
intrinsics
::
constant_index
(
0
);
ValueHandle
lb
=
std_
constant_index
(
0
);
ValueHandle
ub
=
intrinsics
::
constant_index
(
groups
);
ValueHandle
ub
=
std_
constant_index
(
groups
);
auto
imagesType
=
images
.
getType
().
cast
<
MemRefType
>
();
auto
imagesType
=
images
.
getType
().
cast
<
MemRefType
>
();
auto
filtersType
=
filters
.
getType
().
cast
<
MemRefType
>
();
auto
filtersType
=
filters
.
getType
().
cast
<
MemRefType
>
();
...
@@ -1043,13 +1038,13 @@ namespace
...
@@ -1043,13 +1038,13 @@ namespace
NGRAPH_CHECK
(
groupsInFilters
||
filtersShape
[
0
]
%
groups
==
0
,
NGRAPH_CHECK
(
groupsInFilters
||
filtersShape
[
0
]
%
groups
==
0
,
"Filters dim is not divisible by number of groups"
);
"Filters dim is not divisible by number of groups"
);
auto
channelGroupSize
=
intrinsics
::
constant_index
(
imagesShape
[
1
]
/
groups
);
auto
channelGroupSize
=
std_
constant_index
(
imagesShape
[
1
]
/
groups
);
auto
filtersGroupSize
=
intrinsics
::
constant_index
(
auto
filtersGroupSize
=
groupsInFilters
?
filtersShape
[
1
]
:
filtersShape
[
0
]
/
groups
);
std_constant_index
(
groupsInFilters
?
filtersShape
[
1
]
:
filtersShape
[
0
]
/
groups
);
NGRAPH_CHECK
(
!
groupsInFilters
||
groups
==
filtersShape
[
0
]);
NGRAPH_CHECK
(
!
groupsInFilters
||
groups
==
filtersShape
[
0
]);
LoopBuilder
::
makeAffine
(
&
iv
,
lb
,
ub
,
1
)([
&
]
{
makeAffineLoopBuilder
(
&
iv
,
lb
,
ub
,
1
)([
&
]
{
// lower/upper bounds on image channel dim and kernels dim
// lower/upper bounds on image channel dim and kernels dim
auto
cLb
=
iv
*
channelGroupSize
;
auto
cLb
=
iv
*
channelGroupSize
;
auto
cUb
=
cLb
+
channelGroupSize
;
auto
cUb
=
cLb
+
channelGroupSize
;
...
@@ -1152,7 +1147,7 @@ namespace
...
@@ -1152,7 +1147,7 @@ namespace
castMemRef
(
inputs
,
outputs
,
rewriter
,
unrankedMemrefTy
);
castMemRef
(
inputs
,
outputs
,
rewriter
,
unrankedMemrefTy
);
FuncOp
callBackFunc
=
pass
.
getCallDecl
(
FuncOp
callBackFunc
=
pass
.
getCallDecl
(
"
__mlir_
callback_2_inputs"
,
"callback_2_inputs"
,
{
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{},
{},
rewriter
);
rewriter
);
...
@@ -1245,7 +1240,7 @@ namespace
...
@@ -1245,7 +1240,7 @@ namespace
auto
int64Ty
=
rewriter
.
getIntegerType
(
64
);
auto
int64Ty
=
rewriter
.
getIntegerType
(
64
);
auto
unrankedMemrefTy
=
UnrankedMemRefType
::
get
(
elemTy
,
0
);
auto
unrankedMemrefTy
=
UnrankedMemRefType
::
get
(
elemTy
,
0
);
auto
callBackFunc
=
pass
.
getCallDecl
(
auto
callBackFunc
=
pass
.
getCallDecl
(
"
__mlir_
callback_2_inputs"
,
"callback_2_inputs"
,
{
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{},
{},
rewriter
);
rewriter
);
...
@@ -1297,7 +1292,7 @@ namespace
...
@@ -1297,7 +1292,7 @@ namespace
elemTy
==
biasTy
.
getElementType
(),
elemTy
==
biasTy
.
getElementType
(),
"Types mismatch in GemmOp"
);
"Types mismatch in GemmOp"
);
MemRef
View
vRes
(
result
),
vLhs
(
lhs
),
vRhs
(
rhs
),
vBias
(
bias
);
MemRef
BoundsCapture
vRes
(
result
),
vLhs
(
lhs
),
vRhs
(
rhs
),
vBias
(
bias
);
NGRAPH_CHECK
(
vLhs
.
rank
()
==
2
&&
vRhs
.
rank
()
==
2
&&
vRes
.
rank
()
==
2
&&
vBias
.
rank
()
<=
2
,
NGRAPH_CHECK
(
vLhs
.
rank
()
==
2
&&
vRhs
.
rank
()
==
2
&&
vRes
.
rank
()
==
2
&&
vBias
.
rank
()
<=
2
,
"Gemm operation is only supported for 2D tensors"
);
"Gemm operation is only supported for 2D tensors"
);
...
@@ -1361,7 +1356,7 @@ namespace
...
@@ -1361,7 +1356,7 @@ namespace
auto
int64Ty
=
rewriter
.
getIntegerType
(
64
);
auto
int64Ty
=
rewriter
.
getIntegerType
(
64
);
auto
unrankedMemrefTy
=
UnrankedMemRefType
::
get
(
elemTy
,
0
);
auto
unrankedMemrefTy
=
UnrankedMemRefType
::
get
(
elemTy
,
0
);
auto
callBackFunc
=
pass
.
getCallDecl
(
"
__mlir_
callback_3_inputs"
,
auto
callBackFunc
=
pass
.
getCallDecl
(
"callback_3_inputs"
,
{
unrankedMemrefTy
,
{
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
unrankedMemrefTy
,
...
@@ -1425,7 +1420,7 @@ namespace
...
@@ -1425,7 +1420,7 @@ namespace
rewriter
.
getUnknownLoc
(),
static_cast
<
int64_t
>
(
OpType
::
SOFTMAX
),
64
);
rewriter
.
getUnknownLoc
(),
static_cast
<
int64_t
>
(
OpType
::
SOFTMAX
),
64
);
FuncOp
callBackFunc
=
FuncOp
callBackFunc
=
pass
.
getCallDecl
(
"
__mlir_
callback_1_input"
,
pass
.
getCallDecl
(
"callback_1_input"
,
{
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{},
{},
rewriter
);
rewriter
);
...
@@ -1511,11 +1506,12 @@ namespace
...
@@ -1511,11 +1506,12 @@ namespace
auto
padBelow
=
padBelowAttr
.
getValue
();
auto
padBelow
=
padBelowAttr
.
getValue
();
auto
padAbove
=
padBelowAttr
.
getValue
();
auto
padAbove
=
padBelowAttr
.
getValue
();
Type
elemTy
=
images
.
getType
().
cast
<
MemRefType
>
().
getElementType
();
Type
elemTy
=
images
.
getType
().
cast
<
MemRefType
>
().
getElementType
();
auto
indexType
=
IndexType
::
get
(
rewriter
.
getContext
());
// Create views
// Create views
MemRef
View
vRes
(
result
),
vImages
(
images
),
vFilters
(
filters
);
MemRef
BoundsCapture
vRes
(
result
),
vImages
(
images
),
vFilters
(
filters
);
// Create indexed Values
// Create indexed Values
IndexedValue
iRes
(
result
),
iImages
(
images
),
iFilters
(
filters
);
Affine
IndexedValue
iRes
(
result
),
iImages
(
images
),
iFilters
(
filters
);
// Bounds on batch size N
// Bounds on batch size N
ValueHandle
batchLb
=
vImages
.
lb
(
0
),
batchUb
=
vImages
.
ub
(
0
);
ValueHandle
batchLb
=
vImages
.
lb
(
0
),
batchUb
=
vImages
.
ub
(
0
);
// Bounds on spatial dimensions
// Bounds on spatial dimensions
...
@@ -1526,9 +1522,8 @@ namespace
...
@@ -1526,9 +1522,8 @@ namespace
unsigned
spatialRank
=
vImages
.
rank
()
-
2
;
unsigned
spatialRank
=
vImages
.
rank
()
-
2
;
// Result spatial indices and bounds
// Result spatial indices and bounds
auto
resSpatialIndices
=
makeIndexHandles
(
spatialRank
);
auto
resSpatialIndices
=
ValueHandle
::
makeIndexHandles
(
spatialRank
);
auto
resSpatialIndicesPtrs
=
auto
resSpatialIndicesPtrs
=
makeHandlePointers
(
resSpatialIndices
);
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
resSpatialIndices
));
SmallVector
<
int64_t
,
4
>
resSteps
,
filtersSteps
;
SmallVector
<
int64_t
,
4
>
resSteps
,
filtersSteps
;
SmallVector
<
int
,
4
>
padBelowIntValues
;
SmallVector
<
int
,
4
>
padBelowIntValues
;
bool
withPadding
=
false
;
bool
withPadding
=
false
;
...
@@ -1610,9 +1605,8 @@ namespace
...
@@ -1610,9 +1605,8 @@ namespace
"Results spatial dims mismatches input"
);
"Results spatial dims mismatches input"
);
// Filters spatial indices and bounds
// Filters spatial indices and bounds
auto
filtersSpatialIndices
=
makeIndexHandles
(
spatialRank
);
auto
filtersSpatialIndices
=
ValueHandle
::
makeIndexHandles
(
spatialRank
);
auto
filtersSpatialIndicesPtrs
=
auto
filtersSpatialIndicesPtrs
=
makeHandlePointers
(
filtersSpatialIndices
);
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
filtersSpatialIndices
));
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
{
{
...
@@ -1658,23 +1652,22 @@ namespace
...
@@ -1658,23 +1652,22 @@ namespace
// Initialize output to zero
// Initialize output to zero
{
{
IndexHandle
n
,
k
,
c
;
ValueHandle
n
(
indexType
),
k
(
indexType
),
c
(
indexType
);
auto
resSpatialIndices
=
makeIndexHandles
(
spatialRank
);
auto
resSpatialIndices
=
ValueHandle
::
makeIndexHandles
(
spatialRank
);
auto
resSpatialIndicesPtrs
=
auto
resSpatialIndicesPtrs
=
makeHandlePointers
(
resSpatialIndices
);
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
resSpatialIndices
));
LoopBuilder
::
makeAffine
(
&
n
,
batchLb
,
batchUb
,
1
)([
&
]
{
makeAffineLoopBuilder
(
&
n
,
batchLb
,
batchUb
,
1
)([
&
]
{
LoopBuilder
::
makeAffine
(
&
k
,
numFiltersLb
,
numFiltersUb
,
1
)([
&
]
{
makeAffineLoopBuilder
(
&
k
,
numFiltersLb
,
numFiltersUb
,
1
)([
&
]
{
AffineLoopNestBuilder
(
AffineLoopNestBuilder
(
resSpatialIndicesPtrs
,
resSpatialLbs
,
resSpatialUbs
,
resSteps
)([
&
]
{
resSpatialIndicesPtrs
,
resSpatialLbs
,
resSpatialUbs
,
resSteps
)([
&
]
{
SmallVector
<
Index
Handle
,
4
>
resIndices
;
SmallVector
<
Value
Handle
,
4
>
resIndices
;
// Result indices
// Result indices
resIndices
.
push_back
(
n
);
resIndices
.
push_back
(
n
);
if
(
groupConvolution
&&
groupsInFilters
)
if
(
groupConvolution
&&
groupsInFilters
)
{
{
// compute global C_OUT from gID and k
// compute global C_OUT from gID and k
// gId * C_OUT (num of filters) + k
// gId * C_OUT (num of filters) + k
resIndices
.
push_back
(
IndexHandle
(
ValueHandle
(
gId
)
*
numFiltersUb
+
k
)
);
resIndices
.
push_back
(
ValueHandle
(
gId
)
*
numFiltersUb
+
k
);
}
}
else
else
{
{
...
@@ -1689,31 +1682,31 @@ namespace
...
@@ -1689,31 +1682,31 @@ namespace
});
});
}
}
IndexHandle
n
,
k
,
c
;
ValueHandle
n
(
indexType
),
k
(
indexType
),
c
(
indexType
)
;
// Convolution loop
// Convolution loop
LoopBuilder
::
makeAffine
(
&
n
,
batchLb
,
batchUb
,
1
)([
&
]
{
makeAffineLoopBuilder
(
&
n
,
batchLb
,
batchUb
,
1
)([
&
]
{
// Number of filters loop
// Number of filters loop
LoopBuilder
::
makeAffine
(
&
k
,
numFiltersLb
,
numFiltersUb
,
1
)([
&
]
{
makeAffineLoopBuilder
(
&
k
,
numFiltersLb
,
numFiltersUb
,
1
)([
&
]
{
// Channels loop
// Channels loop
LoopBuilder
::
makeAffine
(
&
c
,
numChannelsLb
,
numChannelsUb
,
1
)([
&
]
{
makeAffineLoopBuilder
(
&
c
,
numChannelsLb
,
numChannelsUb
,
1
)([
&
]
{
// Results loop
// Results loop
AffineLoopNestBuilder
(
AffineLoopNestBuilder
(
resSpatialIndicesPtrs
,
resSpatialLbs
,
resSpatialUbs
,
resSteps
)([
&
]
{
resSpatialIndicesPtrs
,
resSpatialLbs
,
resSpatialUbs
,
resSteps
)([
&
]
{
// Compute image start indices
// Compute image start indices
SmallVector
<
Index
Handle
,
4
>
imgStartIndices
;
SmallVector
<
Value
Handle
,
4
>
imgStartIndices
;
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
{
{
IntegerAttr
iAttr
=
strides
[
i
].
cast
<
IntegerAttr
>
();
IntegerAttr
iAttr
=
strides
[
i
].
cast
<
IntegerAttr
>
();
auto
stride
=
intrinsics
::
constant_index
(
iAttr
.
getInt
());
auto
stride
=
std_
constant_index
(
iAttr
.
getInt
());
imgStartIndices
.
push_back
(
IndexHandle
(
resSpatialIndices
[
i
]
*
stride
)
);
imgStartIndices
.
push_back
(
resSpatialIndices
[
i
]
*
stride
);
}
}
SmallVector
<
Index
Handle
,
4
>
resIndices
;
SmallVector
<
Value
Handle
,
4
>
resIndices
;
// Result indices
// Result indices
resIndices
.
push_back
(
n
);
resIndices
.
push_back
(
n
);
if
(
groupConvolution
&&
groupsInFilters
)
if
(
groupConvolution
&&
groupsInFilters
)
{
{
// gId * C_OUT (num of filters) + k
// gId * C_OUT (num of filters) + k
resIndices
.
push_back
(
IndexHandle
(
ValueHandle
(
gId
)
*
numFiltersUb
+
k
)
);
resIndices
.
push_back
(
ValueHandle
(
gId
)
*
numFiltersUb
+
k
);
}
}
else
else
{
{
...
@@ -1727,15 +1720,14 @@ namespace
...
@@ -1727,15 +1720,14 @@ namespace
filtersSpatialLbs
,
filtersSpatialLbs
,
filtersSpatialUbs
,
filtersSpatialUbs
,
filtersSteps
)([
&
]
{
filtersSteps
)([
&
]
{
SmallVector
<
Index
Handle
,
4
>
imgIndices
,
filtersIndices
;
SmallVector
<
Value
Handle
,
4
>
imgIndices
,
filtersIndices
;
// Image indices
// Image indices
// Here we compute the virtual start index into the padded image.
// Here we compute the virtual start index into the padded image.
imgIndices
.
push_back
(
n
);
imgIndices
.
push_back
(
n
);
imgIndices
.
push_back
(
c
);
imgIndices
.
push_back
(
c
);
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
{
{
imgIndices
.
push_back
(
imgIndices
.
push_back
(
imgStartIndices
[
i
]
+
filtersSpatialIndices
[
i
]);
IndexHandle
(
imgStartIndices
[
i
]
+
filtersSpatialIndices
[
i
]));
}
}
// Filter indices
// Filter indices
...
@@ -1744,14 +1736,14 @@ namespace
...
@@ -1744,14 +1736,14 @@ namespace
// index
// index
if
(
groupConvolution
&&
groupsInFilters
)
if
(
groupConvolution
&&
groupsInFilters
)
{
{
filtersIndices
.
push_back
(
Index
Handle
(
gId
));
filtersIndices
.
push_back
(
Value
Handle
(
gId
));
}
}
filtersIndices
.
push_back
(
k
);
filtersIndices
.
push_back
(
k
);
// subtract lower bound of channel
// subtract lower bound of channel
// if we are doing group convolution this bound will advance based
// if we are doing group convolution this bound will advance based
// on the group id. For the filters, it should always start from 0
// on the group id. For the filters, it should always start from 0
filtersIndices
.
push_back
(
IndexHandle
(
c
-
numChannelsLb
)
);
filtersIndices
.
push_back
(
c
-
numChannelsLb
);
filtersIndices
.
insert
(
filtersIndices
.
end
(),
filtersIndices
.
insert
(
filtersIndices
.
end
(),
filtersSpatialIndices
.
begin
(),
filtersSpatialIndices
.
begin
(),
filtersSpatialIndices
.
end
());
filtersSpatialIndices
.
end
());
...
@@ -1759,7 +1751,7 @@ namespace
...
@@ -1759,7 +1751,7 @@ namespace
if
(
withPadding
)
if
(
withPadding
)
{
{
// if args : img dims, img lbs, img ubs
// if args : img dims, img lbs, img ubs
SmallVector
<
Index
Handle
,
4
>::
iterator
it
=
imgIndices
.
begin
();
SmallVector
<
Value
Handle
,
4
>::
iterator
it
=
imgIndices
.
begin
();
std
::
advance
(
it
,
2
);
std
::
advance
(
it
,
2
);
SmallVector
<
Value
,
4
>
affineIfArgs
(
it
,
imgIndices
.
end
());
SmallVector
<
Value
,
4
>
affineIfArgs
(
it
,
imgIndices
.
end
());
affineIfArgs
.
insert
(
affineIfArgs
.
insert
(
...
@@ -1777,14 +1769,14 @@ namespace
...
@@ -1777,14 +1769,14 @@ namespace
ScopedContext
scope
(
rewriter
,
loc
);
ScopedContext
scope
(
rewriter
,
loc
);
// We must subtract pad below before img load, since the
// We must subtract pad below before img load, since the
// physical image is not padded
// physical image is not padded
SmallVector
<
Index
Handle
,
4
>
adjustedImgIndices
;
SmallVector
<
Value
Handle
,
4
>
adjustedImgIndices
;
adjustedImgIndices
.
push_back
(
n
);
adjustedImgIndices
.
push_back
(
n
);
adjustedImgIndices
.
push_back
(
c
);
adjustedImgIndices
.
push_back
(
c
);
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
for
(
auto
i
=
0
;
i
<
spatialRank
;
i
++
)
{
{
adjustedImgIndices
.
push_back
(
IndexHandle
(
adjustedImgIndices
.
push_back
(
imgIndices
[
2
+
i
]
-
imgIndices
[
2
+
i
]
-
intrinsics
::
constant_index
(
padBelowIntValues
[
i
])
));
std_constant_index
(
padBelowIntValues
[
i
]
));
}
}
iRes
(
resIndices
)
=
iRes
(
resIndices
)
=
iRes
(
resIndices
)
+
iRes
(
resIndices
)
+
...
@@ -1821,15 +1813,15 @@ namespace
...
@@ -1821,15 +1813,15 @@ namespace
ScopedContext
scope
(
rewriter
,
loc
);
ScopedContext
scope
(
rewriter
,
loc
);
// Views
// Views
MemRef
View
vRes
(
result
),
vLHS
(
lhs
);
MemRef
BoundsCapture
vRes
(
result
),
vLHS
(
lhs
);
// Index Values
// Index Values
IndexedValue
iRes
(
result
),
iLHS
(
lhs
);
Affine
IndexedValue
iRes
(
result
),
iLHS
(
lhs
);
// Bounds Index Handles
// Bounds Index Handles
auto
lbs
=
vLHS
.
getLbs
();
auto
lbs
=
vLHS
.
getLbs
();
auto
ubs
=
vLHS
.
getUbs
();
auto
ubs
=
vLHS
.
getUbs
();
// Loop induction vars
// Loop induction vars
auto
ivs
=
makeIndexHandles
(
vLHS
.
rank
());
auto
ivs
=
ValueHandle
::
makeIndexHandles
(
vLHS
.
rank
());
auto
pivs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
ivs
)
);
auto
pivs
=
makeHandlePointers
(
ivs
);
// Steps
// Steps
auto
steps
=
vLHS
.
getSteps
();
auto
steps
=
vLHS
.
getSteps
();
...
@@ -1867,15 +1859,15 @@ namespace
...
@@ -1867,15 +1859,15 @@ namespace
ScopedContext
scope
(
rewriter
,
loc
);
ScopedContext
scope
(
rewriter
,
loc
);
// Views
// Views
MemRef
View
vRes
(
result
),
vLHS
(
lhs
),
vRHS
(
rhs
);
MemRef
BoundsCapture
vRes
(
result
),
vLHS
(
lhs
),
vRHS
(
rhs
);
// Index Values
// Index Values
IndexedValue
iRes
(
result
),
iLHS
(
lhs
),
iRHS
(
rhs
);
Affine
IndexedValue
iRes
(
result
),
iLHS
(
lhs
),
iRHS
(
rhs
);
// Bounds Index Handles
// Bounds Index Handles
auto
lbs
=
vLHS
.
getLbs
();
auto
lbs
=
vLHS
.
getLbs
();
auto
ubs
=
vLHS
.
getUbs
();
auto
ubs
=
vLHS
.
getUbs
();
// Loop induction vars
// Loop induction vars
auto
ivs
=
makeIndexHandles
(
vLHS
.
rank
());
auto
ivs
=
ValueHandle
::
makeIndexHandles
(
vLHS
.
rank
());
auto
pivs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
ivs
)
);
auto
pivs
=
makeHandlePointers
(
ivs
);
// Steps
// Steps
auto
steps
=
vLHS
.
getSteps
();
auto
steps
=
vLHS
.
getSteps
();
// element type of the operand
// element type of the operand
...
@@ -1900,65 +1892,57 @@ namespace
...
@@ -1900,65 +1892,57 @@ namespace
iRes
(
ivs
)
=
iLHS
(
ivs
)
/
iRHS
(
ivs
);
iRes
(
ivs
)
=
iLHS
(
ivs
)
/
iRHS
(
ivs
);
}
}
// TODO(pthoreho) For all comparision operators, use
// TODO(pthoreho) For all comparision operators, use
//
edsc::intrinsics::
zero_extendi(ValueHandle(iLHS(ivs)) !=
// zero_extendi(ValueHandle(iLHS(ivs)) !=
// ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext()));
// ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext()));
// instead of
edsc::intrinsics::
select once `zero_extendi` is
// instead of
std_
select once `zero_extendi` is
// made available in the edsc::intrinsics namescope in MLIR repo.
// made available in the edsc::intrinsics namescope in MLIR repo.
else
if
(
isa
<
NGGreaterOp
>
(
op
))
else
if
(
isa
<
NGGreaterOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
>
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
>
ValueHandle
(
iRHS
(
ivs
)),
createOneConstant
(
elemTy
),
createOneConstant
(
elemTy
),
createZeroConstant
(
elemTy
));
createZeroConstant
(
elemTy
));
}
}
else
if
(
isa
<
NGLessOp
>
(
op
))
else
if
(
isa
<
NGLessOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
<
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
<
ValueHandle
(
iRHS
(
ivs
)),
createOneConstant
(
elemTy
),
createOneConstant
(
elemTy
),
createZeroConstant
(
elemTy
));
createZeroConstant
(
elemTy
));
}
}
else
if
(
isa
<
NGGreaterEqOp
>
(
op
))
else
if
(
isa
<
NGGreaterEqOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
>=
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
>=
ValueHandle
(
iRHS
(
ivs
)),
createOneConstant
(
elemTy
),
createOneConstant
(
elemTy
),
createZeroConstant
(
elemTy
));
createZeroConstant
(
elemTy
));
}
}
else
if
(
isa
<
NGLessEqOp
>
(
op
))
else
if
(
isa
<
NGLessEqOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
<=
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
<=
ValueHandle
(
iRHS
(
ivs
)),
createOneConstant
(
elemTy
),
createOneConstant
(
elemTy
),
createZeroConstant
(
elemTy
));
createZeroConstant
(
elemTy
));
}
}
else
if
(
isa
<
NGEqOp
>
(
op
))
else
if
(
isa
<
NGEqOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
==
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
==
ValueHandle
(
iRHS
(
ivs
)),
createOneConstant
(
elemTy
),
createOneConstant
(
elemTy
),
createZeroConstant
(
elemTy
));
createZeroConstant
(
elemTy
));
}
}
else
if
(
isa
<
NGNotEqOp
>
(
op
))
else
if
(
isa
<
NGNotEqOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
!=
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
!=
ValueHandle
(
iRHS
(
ivs
)),
createOneConstant
(
elemTy
),
createOneConstant
(
elemTy
),
createZeroConstant
(
elemTy
));
createZeroConstant
(
elemTy
));
}
}
else
if
(
isa
<
NGMaxOp
>
(
op
))
else
if
(
isa
<
NGMaxOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
>
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
>
ValueHandle
(
iRHS
(
ivs
)),
ValueHandle
(
iLHS
(
ivs
)),
ValueHandle
(
iLHS
(
ivs
)),
ValueHandle
(
iRHS
(
ivs
)));
ValueHandle
(
iRHS
(
ivs
)));
}
}
else
if
(
isa
<
NGMinOp
>
(
op
))
else
if
(
isa
<
NGMinOp
>
(
op
))
{
{
iRes
(
ivs
)
=
iRes
(
ivs
)
=
std_select
(
ValueHandle
(
iLHS
(
ivs
))
<
ValueHandle
(
iRHS
(
ivs
)),
edsc
::
intrinsics
::
select
(
ValueHandle
(
iLHS
(
ivs
))
<
ValueHandle
(
iRHS
(
ivs
)),
ValueHandle
(
iLHS
(
ivs
)),
ValueHandle
(
iLHS
(
ivs
)),
ValueHandle
(
iRHS
(
ivs
)));
ValueHandle
(
iRHS
(
ivs
)));
}
}
else
else
{
{
...
@@ -1995,10 +1979,10 @@ namespace
...
@@ -1995,10 +1979,10 @@ namespace
Value
result
=
pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
Value
result
=
pass
.
buildOutputDefs
(
op
,
rewriter
)[
0
];
// Views
// Views
MemRef
View
vRes
(
result
),
vArg
(
arg
);
MemRef
BoundsCapture
vRes
(
result
),
vArg
(
arg
);
// Index Values
// Index Values
StdIndexedValue
iRes
(
result
),
stdArg
(
arg
);
StdIndexedValue
iRes
(
result
),
stdArg
(
arg
);
IndexedValue
affineArg
(
arg
);
Affine
IndexedValue
affineArg
(
arg
);
// Bounds Index Handles
// Bounds Index Handles
auto
resLbs
=
vRes
.
getLbs
();
auto
resLbs
=
vRes
.
getLbs
();
auto
resUbs
=
vRes
.
getUbs
();
auto
resUbs
=
vRes
.
getUbs
();
...
@@ -2008,8 +1992,8 @@ namespace
...
@@ -2008,8 +1992,8 @@ namespace
Type
resTy
=
result
.
getType
().
cast
<
MemRefType
>
().
getElementType
();
Type
resTy
=
result
.
getType
().
cast
<
MemRefType
>
().
getElementType
();
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
{
{
auto
ivs
=
makeIndexHandles
(
vRes
.
rank
());
auto
ivs
=
ValueHandle
::
makeIndexHandles
(
vRes
.
rank
());
auto
pivs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
ivs
)
);
auto
pivs
=
makeHandlePointers
(
ivs
);
auto
steps
=
vRes
.
getSteps
();
auto
steps
=
vRes
.
getSteps
();
auto
initVal
=
vArg
.
lb
(
axis
);
auto
initVal
=
vArg
.
lb
(
axis
);
AffineLoopNestBuilder
(
pivs
,
resLbs
,
resUbs
,
steps
)(
AffineLoopNestBuilder
(
pivs
,
resLbs
,
resUbs
,
steps
)(
...
@@ -2018,10 +2002,10 @@ namespace
...
@@ -2018,10 +2002,10 @@ namespace
// Generate loop nest that computes the actual index reduction.
// Generate loop nest that computes the actual index reduction.
{
{
auto
allIVs
=
makeIndexHandles
(
vArg
.
rank
());
auto
allIVs
=
ValueHandle
::
makeIndexHandles
(
vArg
.
rank
());
auto
pAllIVs
=
makeHandlePointers
(
MutableArrayRef
<
IndexHandle
>
(
allIVs
)
);
auto
pAllIVs
=
makeHandlePointers
(
allIVs
);
auto
steps
=
vArg
.
getSteps
();
auto
steps
=
vArg
.
getSteps
();
SmallVector
<
Index
Handle
,
8
>
nonRedIVs
;
SmallVector
<
Value
Handle
,
8
>
nonRedIVs
;
Type
resTy
=
result
.
getType
().
cast
<
MemRefType
>
().
getElementType
();
Type
resTy
=
result
.
getType
().
cast
<
MemRefType
>
().
getElementType
();
NGRAPH_CHECK
(
resTy
.
isa
<
IntegerType
>
(),
NGRAPH_CHECK
(
resTy
.
isa
<
IntegerType
>
(),
...
@@ -2049,10 +2033,8 @@ namespace
...
@@ -2049,10 +2033,8 @@ namespace
// Select the min/max value and cast it back to integer type before storing it.
// Select the min/max value and cast it back to integer type before storing it.
ValueHandle
newRedIdx
=
ValueHandle
newRedIdx
=
std
::
is_same
<
RedOp
,
NGArgMinRedOp
>
()
std
::
is_same
<
RedOp
,
NGArgMinRedOp
>
()
?
edsc
::
intrinsics
::
select
(
?
std_select
(
affineArg
(
allIVs
)
<
stdArg
(
tempIVs
),
allIVs
[
axis
],
currRedIdx
)
affineArg
(
allIVs
)
<
stdArg
(
tempIVs
),
allIVs
[
axis
],
currRedIdx
)
:
std_select
(
stdArg
(
tempIVs
)
<
affineArg
(
allIVs
),
allIVs
[
axis
],
currRedIdx
);
:
edsc
::
intrinsics
::
select
(
stdArg
(
tempIVs
)
<
affineArg
(
allIVs
),
allIVs
[
axis
],
currRedIdx
);
iRes
(
nonRedIVs
)
=
ValueHandle
::
create
<
IndexCastOp
>
(
newRedIdx
,
resTy
);
iRes
(
nonRedIVs
)
=
ValueHandle
::
create
<
IndexCastOp
>
(
newRedIdx
,
resTy
);
});
});
...
@@ -2123,7 +2105,7 @@ namespace
...
@@ -2123,7 +2105,7 @@ namespace
castMemRef
(
inputs
,
outputs
,
rewriter
,
unrankedMemrefTy
);
castMemRef
(
inputs
,
outputs
,
rewriter
,
unrankedMemrefTy
);
FuncOp
callBackFunc
=
FuncOp
callBackFunc
=
pass
.
getCallDecl
(
"
__mlir_
callback_1_input"
,
pass
.
getCallDecl
(
"callback_1_input"
,
{
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{
unrankedMemrefTy
,
unrankedMemrefTy
,
int64Ty
,
int64Ty
},
{},
{},
rewriter
);
rewriter
);
...
@@ -2168,11 +2150,11 @@ namespace
...
@@ -2168,11 +2150,11 @@ namespace
{
{
if
(
floatTy
.
isF32
())
if
(
floatTy
.
isF32
())
{
{
return
intrinsics
::
constant_float
(
llvm
::
APFloat
(
0.0
f
),
floatTy
);
return
std_
constant_float
(
llvm
::
APFloat
(
0.0
f
),
floatTy
);
}
}
else
if
(
floatTy
.
isF64
())
else
if
(
floatTy
.
isF64
())
{
{
return
intrinsics
::
constant_float
(
llvm
::
APFloat
(
0.0
),
floatTy
);
return
std_
constant_float
(
llvm
::
APFloat
(
0.0
),
floatTy
);
}
}
else
else
{
{
...
@@ -2181,7 +2163,7 @@ namespace
...
@@ -2181,7 +2163,7 @@ namespace
}
}
else
if
(
auto
intTy
=
type
.
dyn_cast
<
IntegerType
>
())
else
if
(
auto
intTy
=
type
.
dyn_cast
<
IntegerType
>
())
{
{
return
intrinsics
::
constant_int
(
0
,
intTy
.
getWidth
());
return
std_
constant_int
(
0
,
intTy
.
getWidth
());
}
}
NGRAPH_UNREACHABLE
(
"Unsupported type"
);
NGRAPH_UNREACHABLE
(
"Unsupported type"
);
}
}
...
@@ -2192,11 +2174,11 @@ namespace
...
@@ -2192,11 +2174,11 @@ namespace
{
{
if
(
floatTy
.
isF32
())
if
(
floatTy
.
isF32
())
{
{
return
intrinsics
::
constant_float
(
llvm
::
APFloat
(
1.0
f
),
floatTy
);
return
std_
constant_float
(
llvm
::
APFloat
(
1.0
f
),
floatTy
);
}
}
else
if
(
floatTy
.
isF64
())
else
if
(
floatTy
.
isF64
())
{
{
return
intrinsics
::
constant_float
(
llvm
::
APFloat
(
1.0
f
),
floatTy
);
return
std_
constant_float
(
llvm
::
APFloat
(
1.0
f
),
floatTy
);
}
}
else
else
{
{
...
@@ -2205,7 +2187,7 @@ namespace
...
@@ -2205,7 +2187,7 @@ namespace
}
}
else
if
(
auto
intTy
=
type
.
dyn_cast
<
IntegerType
>
())
else
if
(
auto
intTy
=
type
.
dyn_cast
<
IntegerType
>
())
{
{
return
intrinsics
::
constant_int
(
1
,
intTy
.
getWidth
());
return
std_
constant_int
(
1
,
intTy
.
getWidth
());
}
}
NGRAPH_UNREACHABLE
(
"Unsupported type"
);
NGRAPH_UNREACHABLE
(
"Unsupported type"
);
}
}
...
...
src/contrib/mlir/core/pass/ng_dialect_fused_ops.cpp
View file @
310fcf07
...
@@ -23,9 +23,7 @@
...
@@ -23,9 +23,7 @@
#include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include <llvm/IR/Module.h>
#include <llvm/IR/Module.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/StandardTypes.h>
...
...
src/contrib/mlir/runtime/cpu/cpu_callbacks.cpp
View file @
310fcf07
...
@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
...
@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
}
}
}
}
extern
"C"
void
_
_mlir
_callback_1_input
(
void
*
input
,
void
*
output
,
size_t
index
,
OpType
type
)
extern
"C"
void
_
mlir_ciface
_callback_1_input
(
void
*
input
,
void
*
output
,
size_t
index
,
OpType
type
)
{
{
auto
unrankedMemRefInput
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input
);
auto
unrankedMemRefInput
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input
);
auto
unrankedMemRefOutput
=
reinterpret_cast
<
UnrankedMemRef
*>
(
output
);
auto
unrankedMemRefOutput
=
reinterpret_cast
<
UnrankedMemRef
*>
(
output
);
...
@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index,
...
@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index,
}
}
}
}
extern
"C"
void
extern
"C"
void
_mlir_ciface_callback_2_inputs
(
__mlir_callback_2_inputs
(
void
*
input0
,
void
*
input1
,
void
*
output
,
size_t
index
,
OpType
type
)
void
*
input0
,
void
*
input1
,
void
*
output
,
size_t
index
,
OpType
type
)
{
{
auto
unrankedMemRefInput0
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input0
);
auto
unrankedMemRefInput0
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input0
);
auto
unrankedMemRefInput1
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input1
);
auto
unrankedMemRefInput1
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input1
);
...
@@ -780,7 +780,7 @@ extern "C" void
...
@@ -780,7 +780,7 @@ extern "C" void
}
}
}
}
extern
"C"
void
_
_mlir
_callback_3_inputs
(
extern
"C"
void
_
mlir_ciface
_callback_3_inputs
(
void
*
input0
,
void
*
input1
,
void
*
input2
,
void
*
output
,
size_t
index
,
OpType
type
)
void
*
input0
,
void
*
input1
,
void
*
input2
,
void
*
output
,
size_t
index
,
OpType
type
)
{
{
auto
unrankedMemRefInput0
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input0
);
auto
unrankedMemRefInput0
=
reinterpret_cast
<
UnrankedMemRef
*>
(
input0
);
...
...
src/contrib/mlir/runtime/cpu/cpu_runtime.cpp
View file @
310fcf07
...
@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
...
@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{
{
NGRAPH_CHECK
(
m_module
,
"MLIR module is not ready."
);
NGRAPH_CHECK
(
m_module
,
"MLIR module is not ready."
);
auto
func
=
m_module
->
lookupSymbol
<
mlir
::
LLVM
::
LLVMFuncOp
>
(
"main"
);
auto
func
=
m_module
->
lookupSymbol
<
mlir
::
LLVM
::
LLVMFuncOp
>
(
"
_mlir_ciface_
main"
);
NGRAPH_CHECK
(
func
&&
!
func
.
getBlocks
().
empty
(),
"Function not found"
);
NGRAPH_CHECK
(
func
&&
!
func
.
getBlocks
().
empty
(),
"Function not found"
);
// Set external arguments
// Set external arguments
...
@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute()
...
@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute()
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked.
// Make sure the MutableArrayRef version is invoked.
auto
invocationResult
=
m_engine
->
invoke
(
"main"
,
llvm
::
MutableArrayRef
<
void
*>
(
m_invokeArgs
));
auto
invocationResult
=
m_engine
->
invoke
(
"_mlir_ciface_main"
,
llvm
::
MutableArrayRef
<
void
*>
(
m_invokeArgs
));
if
(
clDumpObjectFile
)
if
(
clDumpObjectFile
)
{
{
m_engine
->
dumpToObjectFile
(
clObjectFilename
.
empty
()
?
"jitted_mlir.o"
m_engine
->
dumpToObjectFile
(
clObjectFilename
.
empty
()
?
"jitted_mlir.o"
:
clObjectFilename
.
getValue
());
:
clObjectFilename
.
getValue
());
}
}
NGRAPH_CHECK
(
!
invocationResult
,
"JIT invocation of 'main' failed
\n
"
);
NGRAPH_CHECK
(
!
invocationResult
,
"JIT invocation of '
_mlir_ciface_
main' failed
\n
"
);
}
}
void
MLIRCPURuntime
::
cleanup
()
void
MLIRCPURuntime
::
cleanup
()
...
...
src/contrib/mlir/tools/ngraph-opt/CMakeLists.txt
View file @
310fcf07
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
set
(
LIBS
set
(
LIBS
mlir_backend
mlir_backend
MLIROpt
Main
MLIROpt
Lib
MLIRPass
MLIRPass
MLIRParser
MLIRParser
LLVMSupport
LLVMSupport
...
...
src/contrib/mlir/utils.cpp
View file @
310fcf07
...
@@ -21,10 +21,21 @@
...
@@ -21,10 +21,21 @@
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/AffineOps/AffineOps.h>
#include <llvm/Support/Debug.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/LoopOps/LoopOps.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include <mlir/Dialect/VectorOps/VectorOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/LocationSnapshot.h>
#include <mlir/Transforms/Passes.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Debug.h>
using
namespace
mlir
;
static
llvm
::
cl
::
opt
<
bool
>
clPrintIRAfterAll
(
static
llvm
::
cl
::
opt
<
bool
>
clPrintIRAfterAll
(
"ngraph-print-ir-after-all"
,
"ngraph-print-ir-after-all"
,
...
@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
...
@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
void
ngraph
::
runtime
::
ngmlir
::
initializeNGraphMLIR
()
void
ngraph
::
runtime
::
ngmlir
::
initializeNGraphMLIR
()
{
{
// Initialize a dialect only once.
// Initialize MLIR dialects and passes only once.
// We currently have no way to query if a dialect is previously
static
bool
init_once
=
[]()
{
// registered. So using a global flag instead.
// In-tree Dialects.
static
bool
init
=
false
;
registerDialect
<
AffineOpsDialect
>
();
if
(
!
init
)
registerDialect
<
LLVM
::
LLVMDialect
>
();
{
registerDialect
<
loop
::
LoopOpsDialect
>
();
mlir
::
registerDialect
<
mlir
::
NGraphOpsDialect
>
();
registerDialect
<
StandardOpsDialect
>
();
init
=
true
;
registerDialect
<
vector
::
VectorOpsDialect
>
();
}
// nGraph dialects.
registerDialect
<
mlir
::
NGraphOpsDialect
>
();
// In-tree passes.
// No-op to avoid DCE on the following pass initializations.
if
(
std
::
getenv
(
"bar"
)
!=
(
char
*
)
-
1
)
return
false
;
createCanonicalizerPass
();
createCSEPass
();
createVectorizePass
({});
createLoopUnrollPass
();
createLoopUnrollAndJamPass
();
createSimplifyAffineStructuresPass
();
createLoopFusionPass
();
createLoopInvariantCodeMotionPass
();
createAffineLoopInvariantCodeMotionPass
();
createPipelineDataTransferPass
();
createLowerAffinePass
();
createLoopTilingPass
(
0
);
createLoopCoalescingPass
();
createAffineDataCopyGenerationPass
(
0
,
0
);
createMemRefDataFlowOptPass
();
createStripDebugInfoPass
();
createPrintOpStatsPass
();
createInlinerPass
();
createSymbolDCEPass
();
createLocationSnapshotPass
({});
return
true
;
}();
(
void
)
init_once
;
}
}
void
ngraph
::
runtime
::
ngmlir
::
dumpMlirModule
(
const
std
::
string
msg
,
mlir
::
ModuleOp
module
)
void
ngraph
::
runtime
::
ngmlir
::
dumpMlirModule
(
const
std
::
string
msg
,
mlir
::
ModuleOp
module
)
...
...
src/ngraph/frontend/onnx_import/CMakeLists.txt
View file @
310fcf07
...
@@ -171,6 +171,8 @@ add_library(onnx_import STATIC
...
@@ -171,6 +171,8 @@ add_library(onnx_import STATIC
op/reshape.hpp
op/reshape.hpp
op/reverse_sequence.cpp
op/reverse_sequence.cpp
op/reverse_sequence.hpp
op/reverse_sequence.hpp
op/round.cpp
op/round.hpp
op/scatter_nd.cpp
op/scatter_nd.cpp
op/scatter_nd.hpp
op/scatter_nd.hpp
op/selu.cpp
op/selu.cpp
...
...
src/ngraph/frontend/onnx_import/op/average_pool.cpp
View file @
310fcf07
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include "average_pool.hpp"
#include "average_pool.hpp"
#include "ngraph/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "utils/pooling_factory.hpp"
#include "utils/pooling_factory.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -29,7 +28,7 @@ namespace ngraph
...
@@ -29,7 +28,7 @@ namespace ngraph
{
{
NodeVector
average_pool
(
const
Node
&
node
)
NodeVector
average_pool
(
const
Node
&
node
)
{
{
return
pooling
::
PoolingFactory
(
node
).
make_avg_pool
();
return
pooling
::
Local
PoolingFactory
(
node
).
make_avg_pool
();
}
}
}
// namespace set_1
}
// namespace set_1
...
...
src/ngraph/frontend/onnx_import/op/max_pool.cpp
View file @
310fcf07
...
@@ -31,7 +31,7 @@ namespace ngraph
...
@@ -31,7 +31,7 @@ namespace ngraph
{
{
NodeVector
max_pool
(
const
Node
&
node
)
NodeVector
max_pool
(
const
Node
&
node
)
{
{
auto
max_pool
=
pooling
::
PoolingFactory
(
node
).
make_max_pool
();
auto
max_pool
=
pooling
::
Local
PoolingFactory
(
node
).
make_max_pool
();
max_pool
.
emplace_back
(
std
::
make_shared
<
NullNode
>
());
// Indices (optional)
max_pool
.
emplace_back
(
std
::
make_shared
<
NullNode
>
());
// Indices (optional)
return
max_pool
;
return
max_pool
;
}
}
...
...
src/ngraph/frontend/onnx_import/op/onehot.cpp
View file @
310fcf07
...
@@ -42,9 +42,9 @@ namespace ngraph
...
@@ -42,9 +42,9 @@ namespace ngraph
auto
off_on_values
=
auto
off_on_values
=
std
::
make_shared
<
default_opset
::
Split
>
(
values
,
split_axis
,
2
);
std
::
make_shared
<
default_opset
::
Split
>
(
values
,
split_axis
,
2
);
auto
off_value
=
auto
off_value
=
reshape
::
interpret_as_scalar
(
get_output_element
(
off_on_values
,
0ul
));
reshape
::
interpret_as_scalar
(
get_output_element
(
off_on_values
,
size_t
{
0
}
));
auto
on_value
=
auto
on_value
=
reshape
::
interpret_as_scalar
(
get_output_element
(
off_on_values
,
1ul
));
reshape
::
interpret_as_scalar
(
get_output_element
(
off_on_values
,
size_t
{
1
}
));
auto
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
);
auto
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
);
...
...
src/ngraph/frontend/onnx_import/op/pad.cpp
View file @
310fcf07
...
@@ -65,14 +65,19 @@ namespace ngraph
...
@@ -65,14 +65,19 @@ namespace ngraph
NodeVector
pad
(
const
Node
&
node
)
NodeVector
pad
(
const
Node
&
node
)
{
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
const
Shape
&
data_shape
=
data
->
get_shape
();
const
auto
data_rank
=
node
.
get_ng_inputs
().
at
(
0
)
->
get_output_partial_shape
(
0
).
rank
();
CHECK_VALID_NODE
(
node
,
data_rank
.
is_static
(),
"Data rank must be static for pad op"
);
const
auto
data_rank_value
=
static_cast
<
size_t
>
(
data_rank
);
double
value
=
node
.
get_attribute_value
<
double
>
(
"value"
,
0
);
double
value
=
node
.
get_attribute_value
<
double
>
(
"value"
,
0
);
const
std
::
string
mode
=
const
std
::
string
mode
=
node
.
get_attribute_value
<
std
::
string
>
(
"mode"
,
"constant"
);
node
.
get_attribute_value
<
std
::
string
>
(
"mode"
,
"constant"
);
ngraph
::
op
::
PadMode
pad_mode
=
get_pad_mode
(
mode
);
ngraph
::
op
::
PadMode
pad_mode
=
get_pad_mode
(
mode
);
auto
paddings
=
convpool
::
get_pads
(
node
,
data_shap
e
);
const
auto
paddings
=
convpool
::
get_pads
(
node
,
data_rank_valu
e
);
ngraph
::
CoordinateDiff
padding_below
=
paddings
.
first
;
ngraph
::
CoordinateDiff
padding_below
=
paddings
.
first
;
ngraph
::
CoordinateDiff
padding_above
=
paddings
.
second
;
ngraph
::
CoordinateDiff
padding_above
=
paddings
.
second
;
...
...
src/ngraph/frontend/onnx_import/op/round.cpp
0 → 100644
View file @
310fcf07
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/opsets/opset0.hpp"
#include "round.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
round
(
const
Node
&
node
)
{
const
std
::
shared_ptr
<
ngraph
::
Node
>
data
{
node
.
get_ng_inputs
().
at
(
0
)};
return
{
std
::
make_shared
<
ngraph
::
opset0
::
Round
>
(
data
)};
}
}
// namespace set_1
}
// namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/op/round.hpp
0 → 100644
View file @
310fcf07
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
round
(
const
Node
&
node
);
}
// namespace set_1
}
// namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
310fcf07
...
@@ -101,6 +101,7 @@
...
@@ -101,6 +101,7 @@
#include "op/relu.hpp"
#include "op/relu.hpp"
#include "op/reshape.hpp"
#include "op/reshape.hpp"
#include "op/reverse_sequence.hpp"
#include "op/reverse_sequence.hpp"
#include "op/round.hpp"
#include "op/scatter_nd.hpp"
#include "op/scatter_nd.hpp"
#include "op/selu.hpp"
#include "op/selu.hpp"
#include "op/shape.hpp"
#include "op/shape.hpp"
...
@@ -334,6 +335,7 @@ namespace ngraph
...
@@ -334,6 +335,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Relu"
,
1
,
relu
);
REGISTER_OPERATOR
(
"Relu"
,
1
,
relu
);
REGISTER_OPERATOR
(
"Reshape"
,
1
,
reshape
);
REGISTER_OPERATOR
(
"Reshape"
,
1
,
reshape
);
REGISTER_OPERATOR
(
"ReverseSequence"
,
1
,
reverse_sequence
);
REGISTER_OPERATOR
(
"ReverseSequence"
,
1
,
reverse_sequence
);
REGISTER_OPERATOR
(
"Round"
,
1
,
round
);
REGISTER_OPERATOR
(
"ScatterND"
,
1
,
scatter_nd
);
REGISTER_OPERATOR
(
"ScatterND"
,
1
,
scatter_nd
);
REGISTER_OPERATOR
(
"Selu"
,
1
,
selu
);
REGISTER_OPERATOR
(
"Selu"
,
1
,
selu
);
REGISTER_OPERATOR
(
"Shape"
,
1
,
shape
);
REGISTER_OPERATOR
(
"Shape"
,
1
,
shape
);
...
...
src/ngraph/frontend/onnx_import/utils/convpool.cpp
View file @
310fcf07
...
@@ -38,28 +38,41 @@ namespace ngraph
...
@@ -38,28 +38,41 @@ namespace ngraph
namespace
detail
namespace
detail
{
{
Strides
get_strides_helper
(
const
Node
&
node
,
/// \brief Helper method used to read vector attribute
const
std
::
string
&
name
,
/// \note Default value is vector of size spatial dims filled with
const
Shape
&
kernel_shape
)
/// ones
///
/// \param node Node from which attribute is read
/// \param attr_name Attribute name (such as `strides`, `dilations`)
///
/// \return Read vector attribute if available or default value
std
::
vector
<
std
::
size_t
>
get_attribute_value
(
const
Node
&
node
,
const
std
::
string
&
attr_name
)
{
{
return
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
if
(
node
.
has_attribute
(
attr_name
))
name
,
std
::
vector
<
std
::
size_t
>
(
kernel_shape
.
size
(),
1UL
));
{
return
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
attr_name
);
}
const
auto
data_rank
=
node
.
get_ng_inputs
().
at
(
0
)
->
get_output_partial_shape
(
0
).
rank
();
CHECK_VALID_NODE
(
node
,
data_rank
.
is_static
(),
"If '"
,
attr_name
,
"' is not provided data rank must be static"
);
const
auto
data_spatial_dims
=
static_cast
<
size_t
>
(
data_rank
)
-
2
;
return
std
::
vector
<
std
::
size_t
>
(
data_spatial_dims
,
1UL
);
}
}
}
// namespace detail
}
// namespace detail
Strides
get_strides
(
const
Node
&
node
,
const
Shape
&
kernel_shape
)
{
return
detail
::
get_strides_helper
(
node
,
"strides"
,
kernel_shape
);
}
Strides
get_strides
(
const
Node
&
node
)
Strides
get_strides
(
const
Node
&
node
)
{
{
return
get_strides
(
node
,
get_kernel_shape
(
node
)
);
return
detail
::
get_attribute_value
(
node
,
"strides"
);
}
}
Strides
get_dilations
(
const
Node
&
node
)
Strides
get_dilations
(
const
Node
&
node
)
{
{
return
detail
::
get_
strides_helper
(
node
,
"dilations"
,
get_kernel_shape
(
node
)
);
return
detail
::
get_
attribute_value
(
node
,
"dilations"
);
}
}
ngraph
::
op
::
PadType
get_auto_pad
(
const
Node
&
node
)
ngraph
::
op
::
PadType
get_auto_pad
(
const
Node
&
node
)
...
@@ -90,16 +103,16 @@ namespace ngraph
...
@@ -90,16 +103,16 @@ namespace ngraph
}
}
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
,
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
,
const
Shape
&
kernel_shape
)
const
size_t
kernel_rank
)
{
{
CoordinateDiff
pads
(
kernel_
shape
.
size
()
,
0
);
CoordinateDiff
pads
(
kernel_
rank
,
0
);
if
(
node
.
has_attribute
(
"pads"
))
if
(
node
.
has_attribute
(
"pads"
))
{
{
auto
pads_int64
=
node
.
get_attribute_value
<
std
::
vector
<
int64_t
>>
(
"pads"
);
auto
pads_int64
=
node
.
get_attribute_value
<
std
::
vector
<
int64_t
>>
(
"pads"
);
pads
=
CoordinateDiff
{
std
::
begin
(
pads_int64
),
std
::
end
(
pads_int64
)};
pads
=
CoordinateDiff
{
std
::
begin
(
pads_int64
),
std
::
end
(
pads_int64
)};
}
}
if
(
pads
.
size
()
==
kernel_
shape
.
size
()
*
2
)
if
(
pads
.
size
()
==
kernel_
rank
*
2
)
{
{
return
{{
std
::
begin
(
pads
),
std
::
begin
(
pads
)
+
pads
.
size
()
/
2
},
return
{{
std
::
begin
(
pads
),
std
::
begin
(
pads
)
+
pads
.
size
()
/
2
},
{
std
::
begin
(
pads
)
+
pads
.
size
()
/
2
,
std
::
end
(
pads
)}};
{
std
::
begin
(
pads
)
+
pads
.
size
()
/
2
,
std
::
end
(
pads
)}};
...
@@ -112,6 +125,18 @@ namespace ngraph
...
@@ -112,6 +125,18 @@ namespace ngraph
}
}
}
}
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
)
{
const
auto
data_rank
=
node
.
get_ng_inputs
().
at
(
0
)
->
get_output_partial_shape
(
0
).
rank
();
CHECK_VALID_NODE
(
node
,
data_rank
.
is_static
(),
"The rank of node must be static in order to calculate pads"
);
const
auto
data_spatial_dims
=
static_cast
<
size_t
>
(
data_rank
)
-
2
;
return
get_pads
(
node
,
data_spatial_dims
);
}
void
calculate_auto_pads
(
const
Shape
&
data_shape
,
void
calculate_auto_pads
(
const
Shape
&
data_shape
,
const
Shape
&
filter_shape
,
const
Shape
&
filter_shape
,
const
Strides
&
strides
,
const
Strides
&
strides
,
...
...
src/ngraph/frontend/onnx_import/utils/convpool.hpp
View file @
310fcf07
...
@@ -33,13 +33,6 @@ namespace ngraph
...
@@ -33,13 +33,6 @@ namespace ngraph
/// \return The kernel Shape object representing its dimensions (height, width, depth).
/// \return The kernel Shape object representing its dimensions (height, width, depth).
Shape
get_kernel_shape
(
const
Node
&
node
);
Shape
get_kernel_shape
(
const
Node
&
node
);
/// \brief Get number of pixels to stride operation by in each direction.
///
/// \param node The Node ptr representing Conv or Pool operation.
/// \param kernel_shape The shape of the kernel which we retrieve strides for.
/// \return The kernel Shape object representing its dimensions (height, width, depth).
Strides
get_strides
(
const
Node
&
node
,
const
Shape
&
kernel_shape
);
/// \brief Get number of pixels to stride operation by in each direction.
/// \brief Get number of pixels to stride operation by in each direction.
///
///
/// \param node The Node ptr representing Conv or Pool operation.
/// \param node The Node ptr representing Conv or Pool operation.
...
@@ -59,12 +52,12 @@ namespace ngraph
...
@@ -59,12 +52,12 @@ namespace ngraph
/// `pads` value should follow [x1_begin, x2_begin..., x1_end, x2_end,...].
/// `pads` value should follow [x1_begin, x2_begin..., x1_end, x2_end,...].
///
///
/// \param node The Node ptr representing ONNX operation.
/// \param node The Node ptr representing ONNX operation.
/// \param kernel_
shape The shape
of the kernel which we retrieve pads for.
/// \param kernel_
rank The rank
of the kernel which we retrieve pads for.
///
///
/// \return A pair of (padding_above, padding_below), which elements contains number of
/// \return A pair of (padding_above, padding_below), which elements contains number of
/// pixels to pad in respective dimensions (height, width, depth).
/// pixels to pad in respective dimensions (height, width, depth).
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
,
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
,
const
Shape
&
kernel_shape
);
const
size_t
kernel_rank
);
/// \brief Get padding values for the operation described by an ONNX node.
/// \brief Get padding values for the operation described by an ONNX node.
/// \details Values are taken from the `pads` attribute.
/// \details Values are taken from the `pads` attribute.
...
@@ -75,11 +68,7 @@ namespace ngraph
...
@@ -75,11 +68,7 @@ namespace ngraph
///
///
/// \return A pair of (padding_above, padding_below), which elements contains number of
/// \return A pair of (padding_above, padding_below), which elements contains number of
/// pixels to pad in respective dimensions (height, width, depth).
/// pixels to pad in respective dimensions (height, width, depth).
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
);
inline
std
::
pair
<
CoordinateDiff
,
CoordinateDiff
>
get_pads
(
const
Node
&
node
)
{
return
get_pads
(
node
,
get_kernel_shape
(
node
));
}
///
///
/// \brief Calculate paddings with respect to auto_pad value.
/// \brief Calculate paddings with respect to auto_pad value.
...
...
src/ngraph/frontend/onnx_import/utils/pooling_factory.cpp
View file @
310fcf07
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <iterator>
#include <iterator>
#include "default_opset.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "utils/convpool.hpp"
#include "utils/convpool.hpp"
#include "utils/pooling_factory.hpp"
#include "utils/pooling_factory.hpp"
...
@@ -30,12 +31,11 @@ namespace ngraph
...
@@ -30,12 +31,11 @@ namespace ngraph
PoolingFactory
::
PoolingFactory
(
const
Node
&
node
)
PoolingFactory
::
PoolingFactory
(
const
Node
&
node
)
:
m_onnx_node
{
node
}
:
m_onnx_node
{
node
}
,
m_inputs
{
node
.
get_ng_inputs
()}
,
m_inputs
{
node
.
get_ng_inputs
()}
,
m_kernel_shape
{
convpool
::
get_kernel_shape
(
node
)}
,
m_strides
{
convpool
::
get_strides
(
node
)}
,
m_strides
{
convpool
::
get_strides
(
node
)}
,
m_dilations
{
convpool
::
get_dilations
(
node
)}
,
m_dilations
{
convpool
::
get_dilations
(
node
)}
,
m_auto_pad
{
convpool
::
get_auto_pad
(
node
)}
,
m_auto_pad
{
convpool
::
get_auto_pad
(
node
)}
{
{
auto
paddings
=
convpool
::
get_pads
(
node
);
const
auto
paddings
=
convpool
::
get_pads
(
node
);
const
CoordinateDiff
&
padding_above
{
paddings
.
second
};
const
CoordinateDiff
&
padding_above
{
paddings
.
second
};
const
CoordinateDiff
&
padding_below
{
paddings
.
first
};
const
CoordinateDiff
&
padding_below
{
paddings
.
first
};
m_padding_below
=
Shape
{
std
::
begin
(
padding_below
),
std
::
end
(
padding_below
)};
m_padding_below
=
Shape
{
std
::
begin
(
padding_below
),
std
::
end
(
padding_below
)};
...
@@ -44,7 +44,7 @@ namespace ngraph
...
@@ -44,7 +44,7 @@ namespace ngraph
NodeVector
PoolingFactory
::
make_avg_pool
()
const
NodeVector
PoolingFactory
::
make_avg_pool
()
const
{
{
bool
count_include_pad
=
const
bool
count_include_pad
=
m_onnx_node
.
get_attribute_value
<
std
::
int64_t
>
(
"count_include_pad"
,
0
);
m_onnx_node
.
get_attribute_value
<
std
::
int64_t
>
(
"count_include_pad"
,
0
);
return
{
std
::
make_shared
<
default_opset
::
AvgPool
>
(
m_inputs
.
at
(
0
),
return
{
std
::
make_shared
<
default_opset
::
AvgPool
>
(
m_inputs
.
at
(
0
),
m_strides
,
m_strides
,
...
@@ -67,13 +67,31 @@ namespace ngraph
...
@@ -67,13 +67,31 @@ namespace ngraph
m_auto_pad
)};
m_auto_pad
)};
}
}
LocalPoolingFactory
::
LocalPoolingFactory
(
const
Node
&
node
)
:
PoolingFactory
(
node
)
{
// Kernel shape is required
m_kernel_shape
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
"kernel_shape"
);
}
GlobalPoolingFactory
::
GlobalPoolingFactory
(
const
Node
&
node
)
GlobalPoolingFactory
::
GlobalPoolingFactory
(
const
Node
&
node
)
:
PoolingFactory
(
node
)
:
PoolingFactory
(
node
)
{
{
// Correct the kernel shape.
const
auto
data_shape
=
node
.
get_ng_inputs
().
at
(
0
)
->
get_output_partial_shape
(
0
);
const
Shape
&
data_shape
{
m_inputs
.
at
(
0
)
->
get_shape
()};
const
auto
data_rank
=
data_shape
.
rank
();
CHECK_VALID_NODE
(
node
,
data_rank
.
is_static
(),
"Data rank must be static for global pooling ops"
);
Shape
kernel_shape
;
for
(
auto
i
=
2
;
i
<
static_cast
<
size_t
>
(
data_rank
);
++
i
)
{
CHECK_VALID_NODE
(
node
,
data_shape
[
i
].
is_static
(),
"All spatial dimensions must be known for global pooling ops"
);
kernel_shape
.
emplace_back
(
static_cast
<
size_t
>
(
data_shape
[
i
]));
}
// Set shape to all but {N,C} axes.
// Set shape to all but {N,C} axes.
m_kernel_shape
=
Shape
{
std
::
next
(
std
::
begin
(
data_shape
),
2
),
std
::
end
(
data_shape
)}
;
m_kernel_shape
=
kernel_shape
;
}
}
}
// namespace pooling
}
// namespace pooling
}
// namespace onnx_import
}
// namespace onnx_import
...
...
src/ngraph/frontend/onnx_import/utils/pooling_factory.hpp
View file @
310fcf07
...
@@ -48,7 +48,6 @@ namespace ngraph
...
@@ -48,7 +48,6 @@ namespace ngraph
class
PoolingFactory
class
PoolingFactory
{
{
public
:
public
:
explicit
PoolingFactory
(
const
Node
&
node
);
virtual
~
PoolingFactory
()
=
default
;
virtual
~
PoolingFactory
()
=
default
;
///
///
...
@@ -64,6 +63,8 @@ namespace ngraph
...
@@ -64,6 +63,8 @@ namespace ngraph
NodeVector
make_max_pool
()
const
;
NodeVector
make_max_pool
()
const
;
protected
:
protected
:
explicit
PoolingFactory
(
const
Node
&
node
);
Node
m_onnx_node
;
Node
m_onnx_node
;
const
NodeVector
m_inputs
;
const
NodeVector
m_inputs
;
Shape
m_kernel_shape
;
Shape
m_kernel_shape
;
...
@@ -75,9 +76,20 @@ namespace ngraph
...
@@ -75,9 +76,20 @@ namespace ngraph
};
};
///
///
/// \brief Factory class which generates sub-graphs for ONNX '
glob
al' pooling
/// \brief Factory class which generates sub-graphs for ONNX '
loc
al' pooling
/// operators.
/// operators.
/// \note Kernel shape attribute is required
class
LocalPoolingFactory
:
public
PoolingFactory
{
public
:
explicit
LocalPoolingFactory
(
const
Node
&
node
);
virtual
~
LocalPoolingFactory
()
=
default
;
};
///
///
/// \brief Factory class which generates sub-graphs for ONNX 'global' pooling
/// operators.
/// \note Kernel shape is calculated based on spatial dims
class
GlobalPoolingFactory
:
public
PoolingFactory
class
GlobalPoolingFactory
:
public
PoolingFactory
{
{
public
:
public
:
...
...
src/ngraph/op/gather.cpp
View file @
310fcf07
...
@@ -130,7 +130,7 @@ void op::v1::Gather::validate_and_infer_types()
...
@@ -130,7 +130,7 @@ void op::v1::Gather::validate_and_infer_types()
")."
);
")."
);
}
}
auto
axis
=
get_axis
();
int64_t
axis
=
get_axis
();
if
(
input_rank
.
is_static
()
&&
axis
!=
AXIS_NOT_SET_VALUE
)
if
(
input_rank
.
is_static
()
&&
axis
!=
AXIS_NOT_SET_VALUE
)
{
{
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
...
...
src/ngraph/pattern/matcher.cpp
View file @
310fcf07
...
@@ -40,10 +40,18 @@ namespace ngraph
...
@@ -40,10 +40,18 @@ namespace ngraph
{
{
if
(
m_restore
)
if
(
m_restore
)
{
{
m_matcher
->
m_matched_list
.
erase
(
m_matcher
->
m_matched_list
.
begin
()
+
m_watermark
,
if
(
!
m_matcher
->
m_matched_list
.
empty
())
m_matcher
->
m_matched_list
.
end
());
{
m_matcher
->
m_pattern_value_maps
.
erase
(
m_pattern_value_maps
.
begin
()
+
m_capture_size
,
m_matcher
->
m_matched_list
.
erase
(
m_matcher
->
m_matched_list
.
begin
()
+
m_watermark
,
m_pattern_value_maps
.
end
());
m_matcher
->
m_matched_list
.
end
());
}
if
(
!
m_pattern_value_maps
.
empty
())
{
m_matcher
->
m_pattern_value_maps
.
erase
(
m_pattern_value_maps
.
begin
()
+
m_capture_size
,
m_pattern_value_maps
.
end
());
}
m_matcher
->
m_pattern_map
=
m_pattern_value_map
;
m_matcher
->
m_pattern_map
=
m_pattern_value_map
;
}
}
}
}
...
...
src/ngraph/runtime/cpu/cpu_runtime_context.hpp
View file @
310fcf07
...
@@ -36,7 +36,7 @@
...
@@ -36,7 +36,7 @@
namespace
mkldnn
namespace
mkldnn
{
{
class
primitive
;
struct
primitive
;
}
}
namespace
ngraph
namespace
ngraph
...
...
src/ngraph/runtime/cpu/pass/cpu_mkldnn_primitive_build.hpp
View file @
310fcf07
...
@@ -35,7 +35,7 @@
...
@@ -35,7 +35,7 @@
namespace
mkldnn
namespace
mkldnn
{
{
class
primitive
;
struct
primitive
;
}
}
namespace
ngraph
namespace
ngraph
...
...
src/ngraph/runtime/gpu/unit_test.manifest
View file @
310fcf07
...
@@ -453,6 +453,7 @@ model_gatherND_int32
...
@@ -453,6 +453,7 @@ model_gatherND_int32
model_gatherND_float
model_gatherND_float
model_pad_constant
model_pad_constant
model_reciprocal
model_reciprocal
model_round
tile_3d_small_data_rank
tile_3d_small_data_rank
tile_3d_few_repeats
tile_3d_few_repeats
select_v1
select_v1
...
...
src/ngraph/runtime/plaidml/unit_test.manifest
View file @
310fcf07
...
@@ -282,6 +282,7 @@ model_argmax_int32
...
@@ -282,6 +282,7 @@ model_argmax_int32
model_argmin_int32
model_argmin_int32
model_lp_norm_default
model_lp_norm_default
model_instance_normalization
model_instance_normalization
model_round
# passing locally, fails closeness checks in CI which may be too strict
# passing locally, fails closeness checks in CI which may be too strict
elu
elu
...
...
test/mlir/affine_conversion/callback_ops.mlir
View file @
310fcf07
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @
__mlir_
callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
...
@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
...
@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @
__mlir_
callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
...
@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
...
@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @
__mlir_
callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
...
@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
...
@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @
__mlir_
callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
...
@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
...
@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @
__mlir_
callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
...
@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
...
@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @
__mlir_
callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
...
@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
...
@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @
__mlir_
callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
...
...
test/models/onnx/dynamic_shapes/average_pool_2d_dyn.prototxt
0 → 100644
View file @
310fcf07
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "AveragePool"
attribute {
name: "kernel_shape"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/dynamic_shapes/global_average_pool_dyn.prototxt
0 → 100644
View file @
310fcf07
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "GlobalAveragePool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_value: 5
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/dynamic_shapes/global_max_pool_dyn.prototxt
0 → 100644
View file @
310fcf07
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "GlobalMaxPool"
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_value: 5
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/dynamic_shapes/max_pool_2d_dyn.prototxt
0 → 100644
View file @
310fcf07
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "MaxPool"
attribute {
name: "kernel_shape"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "pads"
ints: 1
ints: 1
ints: 1
ints: 1
type: INTS
}
}
name: "compute_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
dim {
dim_param: "batch"
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 7
}
test/models/onnx/round.prototxt
0 → 100644
View file @
310fcf07
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
op_type: "Round"
}
name: "test_round"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 15
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 15
}
}
}
}
}
}
opset_import {
version: 11
}
test/onnx/onnx_import.in.cpp
View file @
310fcf07
...
@@ -1963,3 +1963,30 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reciprocal)
...
@@ -1963,3 +1963,30 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reciprocal)
test_case
.
run
();
test_case
.
run
();
}
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
model_round
)
{
const
auto
round_fn
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/round.prototxt"
));
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
round_fn
,
"${BACKEND_NAME}"
);
test_case
.
add_input
<
float
>
({
0.1
f
,
0.5
f
,
0.9
f
,
1.2
f
,
1.5
f
,
1.8
f
,
2.3
f
,
2.5
f
,
2.7
f
,
-
1.1
f
,
-
1.5
f
,
-
1.9
f
,
-
2.2
f
,
-
2.5
f
,
-
2.8
f
});
test_case
.
add_expected_output
<
float
>
(
{
0.
f
,
0.
f
,
1.
f
,
1.
f
,
2.
f
,
2.
f
,
2.
f
,
2.
f
,
3.
f
,
-
1.
f
,
-
2.
f
,
-
2.
f
,
-
2.
f
,
-
2.
f
,
-
3.
f
});
test_case
.
run
();
}
test/onnx/onnx_import_dyn_shapes.in.cpp
View file @
310fcf07
...
@@ -282,3 +282,83 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, model_conv_with_dynamic_batch)
...
@@ -282,3 +282,83 @@ NGRAPH_TEST(onnx_dyn_shapes_${BACKEND_NAME}, model_conv_with_dynamic_batch)
test_case
.
run
();
test_case
.
run
();
}
}
NGRAPH_TEST
(
onnx_dyn_shapes_
$
{
BACKEND_NAME
},
avg_pool_dyn_shape
)
{
const
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/dynamic_shapes/average_pool_2d_dyn.prototxt"
));
auto
test_case
=
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
,
BackendMode
::
DYNAMIC
);
const
Shape
shape
{
1
,
1
,
4
,
4
};
const
auto
elems_in_tensor
=
shape_size
(
shape
);
std
::
vector
<
float
>
input_values
(
elems_in_tensor
);
std
::
iota
(
input_values
.
begin
(),
input_values
.
end
(),
0.
f
);
test_case
.
add_input
<
float
>
(
shape
,
input_values
);
std
::
vector
<
float
>
expected_values
{
2.5
f
,
4.5
f
,
10.5
f
,
12.5
f
};
test_case
.
add_expected_output
<
float
>
(
Shape
{
1
,
1
,
2
,
2
},
expected_values
);
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_dyn_shapes_
$
{
BACKEND_NAME
},
max_pool_dyn_shape
)
{
const
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/dynamic_shapes/max_pool_2d_dyn.prototxt"
));
auto
test_case
=
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
,
BackendMode
::
DYNAMIC
);
const
Shape
shape
{
1
,
1
,
4
,
4
};
const
auto
elems_in_tensor
=
shape_size
(
shape
);
std
::
vector
<
float
>
input_values
(
elems_in_tensor
);
std
::
iota
(
input_values
.
begin
(),
input_values
.
end
(),
0.
f
);
test_case
.
add_input
<
float
>
(
shape
,
input_values
);
std
::
vector
<
float
>
expected_values
{
0.
f
,
2.
f
,
3.
f
,
8.
f
,
10.
f
,
11.
f
,
12.
f
,
14.
f
,
15.
f
};
test_case
.
add_expected_output
<
float
>
(
Shape
{
1
,
1
,
3
,
3
},
expected_values
);
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_dyn_shapes_
$
{
BACKEND_NAME
},
global_avg_pool_dyn_shape
)
{
const
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/dynamic_shapes/global_average_pool_dyn.prototxt"
));
auto
test_case
=
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
,
BackendMode
::
DYNAMIC
);
const
Shape
shape
{
1
,
3
,
5
,
5
};
const
auto
elems_in_tensor
=
shape_size
(
shape
);
std
::
vector
<
float
>
input_values
(
elems_in_tensor
);
std
::
iota
(
input_values
.
begin
(),
input_values
.
end
(),
0.
f
);
test_case
.
add_input
<
float
>
(
shape
,
input_values
);
std
::
vector
<
float
>
expected_values
{
12.
f
,
37.
f
,
62.
f
};
test_case
.
add_expected_output
<
float
>
(
Shape
{
1
,
3
,
1
,
1
},
expected_values
);
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_dyn_shapes_
$
{
BACKEND_NAME
},
global_max_pool_dyn_shape
)
{
const
auto
function
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/dynamic_shapes/global_max_pool_dyn.prototxt"
));
auto
test_case
=
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
,
BackendMode
::
DYNAMIC
);
const
Shape
shape
{
1
,
3
,
5
,
5
};
const
auto
elems_in_tensor
=
shape_size
(
shape
);
std
::
vector
<
float
>
input_values
(
elems_in_tensor
);
std
::
iota
(
input_values
.
begin
(),
input_values
.
end
(),
0.
f
);
test_case
.
add_input
<
float
>
(
shape
,
input_values
);
std
::
vector
<
float
>
expected_values
{
24.
f
,
49.
f
,
74.
f
};
test_case
.
add_expected_output
<
float
>
(
Shape
{
1
,
3
,
1
,
1
},
expected_values
);
test_case
.
run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment