Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
b534a674
Unverified
Commit
b534a674
authored
Jul 01, 2019
by
Tomasz Dołbniak
Committed by
GitHub
Jul 01, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into tomdol/plaidml
parents
f5ff6806
530fce64
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
810 additions
and
187 deletions
+810
-187
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-0
tensor.cpp
src/ngraph/descriptor/tensor.cpp
+3
-0
dimension.cpp
src/ngraph/dimension.cpp
+13
-13
dimension.hpp
src/ngraph/dimension.hpp
+22
-17
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
dyn_replace_slice.cpp
src/ngraph/op/experimental/dyn_replace_slice.cpp
+159
-0
dyn_replace_slice.hpp
src/ngraph/op/experimental/dyn_replace_slice.hpp
+78
-0
dyn_slice.cpp
src/ngraph/op/experimental/dyn_slice.cpp
+18
-146
op_tbl.hpp
src/ngraph/op/op_tbl.hpp
+1
-0
pad.cpp
src/ngraph/op/pad.cpp
+1
-1
partial_shape.cpp
src/ngraph/partial_shape.cpp
+13
-0
partial_shape.hpp
src/ngraph/partial_shape.hpp
+4
-0
dyn_elimination.cpp
src/ngraph/pass/dyn_elimination.cpp
+91
-2
dyn_elimination.hpp
src/ngraph/pass/dyn_elimination.hpp
+2
-1
gpu_backend.cpp
src/ngraph/runtime/gpu/gpu_backend.cpp
+1
-0
gpu_emitter.cpp
src/ngraph/runtime/gpu/gpu_emitter.cpp
+6
-0
intelgpu_backend.cpp
src/ngraph/runtime/intelgpu/intelgpu_backend.cpp
+1
-0
unit_test.manifest
src/ngraph/runtime/intelgpu/unit_test.manifest
+1
-0
int_executable.hpp
src/ngraph/runtime/interpreter/int_executable.hpp
+2
-1
serializer.cpp
src/ngraph/serializer.cpp
+53
-2
validation_util.cpp
src/ngraph/validation_util.cpp
+258
-4
validation_util.hpp
src/ngraph/validation_util.hpp
+11
-0
CMakeLists.txt
test/CMakeLists.txt
+1
-0
dyn_elimination.cpp
test/dyn_elimination.cpp
+50
-0
dyn_replace_slice_test.in.cpp
test/dyn_replace_slice_test.in.cpp
+0
-0
generate_dyn_replace_slice_ref.py
test/ref_generators/generate_dyn_replace_slice_ref.py
+0
-0
type_prop.cpp
test/type_prop.cpp
+0
-0
update_dyn_replace_slice_reference.sh
test/update_dyn_replace_slice_reference.sh
+18
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
b534a674
...
@@ -142,6 +142,8 @@ set (SRC
...
@@ -142,6 +142,8 @@ set (SRC
op/experimental/dyn_broadcast.hpp
op/experimental/dyn_broadcast.hpp
op/experimental/dyn_pad.cpp
op/experimental/dyn_pad.cpp
op/experimental/dyn_pad.hpp
op/experimental/dyn_pad.hpp
op/experimental/dyn_replace_slice.cpp
op/experimental/dyn_replace_slice.hpp
op/experimental/dyn_reshape.cpp
op/experimental/dyn_reshape.cpp
op/experimental/dyn_reshape.hpp
op/experimental/dyn_reshape.hpp
op/experimental/dyn_slice.cpp
op/experimental/dyn_slice.cpp
...
...
src/ngraph/descriptor/tensor.cpp
View file @
b534a674
...
@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
...
@@ -46,6 +46,9 @@ descriptor::Tensor::Tensor(const element::Type& element_type,
void
descriptor
::
Tensor
::
set_tensor_type
(
const
element
::
Type
&
element_type
,
void
descriptor
::
Tensor
::
set_tensor_type
(
const
element
::
Type
&
element_type
,
const
PartialShape
&
pshape
)
const
PartialShape
&
pshape
)
{
{
NGRAPH_CHECK
(
pshape
.
all_non_negative
(),
"set_tensor_type called on a PartialShape containing negative dimensions: "
,
pshape
);
if
(
pshape
.
is_static
())
if
(
pshape
.
is_static
())
{
{
m_shape
=
pshape
.
to_shape
();
m_shape
=
pshape
.
to_shape
();
...
...
src/ngraph/dimension.cpp
View file @
b534a674
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
using
namespace
ngraph
;
using
namespace
ngraph
;
Dimension
::
Dimension
(
size
_t
dimension
)
Dimension
::
Dimension
(
int64
_t
dimension
)
:
m_dimension
(
dimension
)
:
m_dimension
(
dimension
)
{
{
if
(
dimension
==
s_dynamic_val
)
if
(
dimension
==
s_dynamic_val
)
...
@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
...
@@ -40,7 +40,7 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
{
{
if
(
dimension
.
is_static
())
if
(
dimension
.
is_static
())
{
{
return
(
str
<<
size
_t
(
dimension
));
return
(
str
<<
int64
_t
(
dimension
));
}
}
else
else
{
{
...
@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
...
@@ -50,36 +50,36 @@ std::ostream& ngraph::operator<<(std::ostream& str, const Dimension& dimension)
Dimension
Dimension
::
operator
+
(
const
Dimension
&
dim
)
const
Dimension
Dimension
::
operator
+
(
const
Dimension
&
dim
)
const
{
{
return
(
is_static
()
&&
dim
.
is_static
()
?
m_dimension
+
size
_t
(
dim
)
:
Dimension
::
dynamic
());
return
(
is_static
()
&&
dim
.
is_static
()
?
m_dimension
+
int64
_t
(
dim
)
:
Dimension
::
dynamic
());
}
}
Dimension
Dimension
::
operator
-
(
const
Dimension
&
dim
)
const
Dimension
Dimension
::
operator
-
(
const
Dimension
&
dim
)
const
{
{
return
(
is_static
()
&&
dim
.
is_static
()
?
m_dimension
-
size
_t
(
dim
)
:
Dimension
::
dynamic
());
return
(
is_static
()
&&
dim
.
is_static
()
?
m_dimension
-
int64
_t
(
dim
)
:
Dimension
::
dynamic
());
}
}
Dimension
Dimension
::
operator
*
(
const
Dimension
&
dim
)
const
Dimension
Dimension
::
operator
*
(
const
Dimension
&
dim
)
const
{
{
return
((
is_static
()
&&
dim
.
is_static
())
return
((
is_static
()
&&
dim
.
is_static
())
?
m_dimension
*
size
_t
(
dim
)
?
m_dimension
*
int64
_t
(
dim
)
:
(
is_static
()
&&
m_dimension
==
0
)
:
(
is_static
()
&&
m_dimension
==
0
)
?
0
?
0
:
(
dim
.
is_static
()
&&
size
_t
(
dim
)
==
0
)
?
0
:
Dimension
::
dynamic
());
:
(
dim
.
is_static
()
&&
int64
_t
(
dim
)
==
0
)
?
0
:
Dimension
::
dynamic
());
}
}
bool
Dimension
::
compatible
(
const
Dimension
&
d
)
const
bool
Dimension
::
compatible
(
const
Dimension
&
d
)
const
{
{
return
(
is_dynamic
()
||
d
.
is_dynamic
()
||
m_dimension
==
size
_t
(
d
));
return
(
is_dynamic
()
||
d
.
is_dynamic
()
||
m_dimension
==
int64
_t
(
d
));
}
}
bool
Dimension
::
relaxes
(
const
Dimension
&
d
)
const
bool
Dimension
::
relaxes
(
const
Dimension
&
d
)
const
{
{
return
(
is_dynamic
()
||
(
d
.
is_static
()
&&
size_t
(
*
this
)
==
size
_t
(
d
)));
return
(
is_dynamic
()
||
(
d
.
is_static
()
&&
int64_t
(
*
this
)
==
int64
_t
(
d
)));
}
}
bool
Dimension
::
refines
(
const
Dimension
&
d
)
const
bool
Dimension
::
refines
(
const
Dimension
&
d
)
const
{
{
return
(
d
.
is_dynamic
()
||
(
is_static
()
&&
size_t
(
d
)
==
size
_t
(
*
this
)));
return
(
d
.
is_dynamic
()
||
(
is_static
()
&&
int64_t
(
d
)
==
int64
_t
(
*
this
)));
}
}
bool
Dimension
::
merge
(
Dimension
&
dst
,
const
Dimension
d1
,
const
Dimension
d2
)
bool
Dimension
::
merge
(
Dimension
&
dst
,
const
Dimension
d1
,
const
Dimension
d2
)
...
@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
...
@@ -94,7 +94,7 @@ bool Dimension::merge(Dimension& dst, const Dimension d1, const Dimension d2)
dst
=
d1
;
dst
=
d1
;
return
true
;
return
true
;
}
}
else
if
(
size_t
(
d1
)
!=
size
_t
(
d2
))
else
if
(
int64_t
(
d1
)
!=
int64
_t
(
d2
))
{
{
return
false
;
return
false
;
}
}
...
@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
...
@@ -115,16 +115,16 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
else
if
(
d1
.
is_dynamic
()
||
d2
.
is_dynamic
())
else
if
(
d1
.
is_dynamic
()
||
d2
.
is_dynamic
())
{
{
// One static. Set dst to static size if >1
// One static. Set dst to static size if >1
auto
ds
=
d1
.
is_dynamic
()
?
size_t
(
d2
)
:
size
_t
(
d1
);
auto
ds
=
d1
.
is_dynamic
()
?
int64_t
(
d2
)
:
int64
_t
(
d1
);
dst
=
(
ds
>
1
)
?
ds
:
Dimension
::
dynamic
();
dst
=
(
ds
>
1
)
?
ds
:
Dimension
::
dynamic
();
return
true
;
return
true
;
}
}
else
else
{
{
// Static sizes. Both match or one of them is 1.
// Static sizes. Both match or one of them is 1.
if
(
size_t
(
d1
)
==
size_t
(
d2
)
||
size_t
(
d1
)
==
1
||
size
_t
(
d2
)
==
1
)
if
(
int64_t
(
d1
)
==
int64_t
(
d2
)
||
int64_t
(
d1
)
==
1
||
int64
_t
(
d2
)
==
1
)
{
{
dst
=
std
::
max
(
size_t
(
d1
),
size
_t
(
d2
));
dst
=
std
::
max
(
int64_t
(
d1
),
int64
_t
(
d2
));
return
true
;
return
true
;
}
}
else
else
...
...
src/ngraph/dimension.hpp
View file @
b534a674
...
@@ -25,7 +25,7 @@ namespace ngraph
...
@@ -25,7 +25,7 @@ namespace ngraph
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
/// in a shape or shape-like object.
///
///
/// Static dimensions may be implicitly converted from
size
_t. A dynamic dimension is
/// Static dimensions may be implicitly converted from
int64
_t. A dynamic dimension is
/// constructed with Dimension() or Dimension::dynamic().
/// constructed with Dimension() or Dimension::dynamic().
///
///
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
/// XXX: THIS CLASS IS NOT IN USE YET AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
...
@@ -36,7 +36,7 @@ namespace ngraph
...
@@ -36,7 +36,7 @@ namespace ngraph
/// \param dimension Value of the dimension. Must not be equal to
/// \param dimension Value of the dimension. Must not be equal to
/// Dimension::s_dynamic_val.
/// Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
/// \throws std::invalid_argument If `dimension` == Dimension::s_dynamic_val.
Dimension
(
size
_t
dimension
);
Dimension
(
int64
_t
dimension
);
/// \brief Construct a dynamic dimension.
/// \brief Construct a dynamic dimension.
Dimension
()
{
m_dimension
=
s_dynamic_val
;
}
Dimension
()
{
m_dimension
=
s_dynamic_val
;
}
...
@@ -46,25 +46,30 @@ namespace ngraph
...
@@ -46,25 +46,30 @@ namespace ngraph
/// \brief Check whether this dimension is dynamic.
/// \brief Check whether this dimension is dynamic.
/// \return `false` if the dimension is static, else `true`.
/// \return `false` if the dimension is static, else `true`.
bool
is_dynamic
()
const
{
return
!
is_static
();
}
bool
is_dynamic
()
const
{
return
!
is_static
();
}
/// \brief Convert this dimension to `
size
_t`. This dimension must be static.
/// \brief Convert this dimension to `
int64
_t`. This dimension must be static.
/// \throws std::invalid_argument If this dimension is dynamic.
/// \throws std::invalid_argument If this dimension is dynamic.
explicit
operator
size
_t
()
const
explicit
operator
int64
_t
()
const
{
{
if
(
is_dynamic
())
if
(
is_dynamic
())
{
{
throw
std
::
invalid_argument
(
"Cannot convert dynamic dimension to
size
_t"
);
throw
std
::
invalid_argument
(
"Cannot convert dynamic dimension to
int64
_t"
);
}
}
return
m_dimension
;
return
m_dimension
;
}
}
/// \brief Convert this dimension to `ptrdiff_t`. This dimension must be static.
/// \brief Convert this dimension to `size_t`. This dimension must be static and
/// \throws std::invalid_argument If this dimension is dynamic.
/// non-negative.
explicit
operator
ptrdiff_t
()
const
/// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit
operator
size_t
()
const
{
{
if
(
is_dynamic
())
if
(
is_dynamic
())
{
{
throw
std
::
invalid_argument
(
"Cannot convert dynamic dimension to ptrdiff_t"
);
throw
std
::
invalid_argument
(
"Cannot convert dynamic dimension to size_t"
);
}
if
(
m_dimension
<
0
)
{
throw
std
::
invalid_argument
(
"Cannot convert negative dimension to size_t"
);
}
}
return
static_cast
<
ptrdiff_t
>
(
m_dimension
)
;
return
m_dimension
;
}
}
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// \brief Check whether this dimension represents the same scheme as the argument (both
...
@@ -75,7 +80,7 @@ namespace ngraph
...
@@ -75,7 +80,7 @@ namespace ngraph
bool
same_scheme
(
const
Dimension
&
dim
)
const
bool
same_scheme
(
const
Dimension
&
dim
)
const
{
{
return
(
is_dynamic
()
&&
dim
.
is_dynamic
())
||
return
(
is_dynamic
()
&&
dim
.
is_dynamic
())
||
(
is_static
()
&&
dim
.
is_static
()
&&
m_dimension
==
size
_t
(
dim
));
(
is_static
()
&&
dim
.
is_static
()
&&
m_dimension
==
int64
_t
(
dim
));
}
}
/// \brief Try to merge two Dimension objects together.
/// \brief Try to merge two Dimension objects together.
...
@@ -128,25 +133,25 @@ namespace ngraph
...
@@ -128,25 +133,25 @@ namespace ngraph
/// \return A dynamic dimension.
/// \return A dynamic dimension.
static
Dimension
dynamic
()
{
return
Dimension
();
}
static
Dimension
dynamic
()
{
return
Dimension
();
}
/// \brief Constant for the value used internally to represent a dynamic dimension.
/// \brief Constant for the value used internally to represent a dynamic dimension.
static
const
size_t
s_dynamic_val
{(
std
::
numeric_limits
<
size
_t
>::
max
())};
static
const
int64_t
s_dynamic_val
{(
std
::
numeric_limits
<
int64
_t
>::
max
())};
/// \brief Addition operator for Dimension.
/// \brief Addition operator for Dimension.
/// \param dim Right operand for addition.
/// \param dim Right operand for addition.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `
size_t(*this)+size
_t(dim)`.
/// dimension with value `
int64_t(*this)+in64
_t(dim)`.
Dimension
operator
+
(
const
Dimension
&
dim
)
const
;
Dimension
operator
+
(
const
Dimension
&
dim
)
const
;
/// \brief Subtraction operator for Dimension.
/// \brief Subtraction operator for Dimension.
/// \param dim Right operand for subtraction.
/// \param dim Right operand for subtraction.
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// \return Dimension::dynamic() if either of `*this` or `dim` is dynamic; else, a static
/// dimension with value `
size_t(*this)-size
_t(dim)`.
/// dimension with value `
int64_t(*this)-int64
_t(dim)`.
Dimension
operator
-
(
const
Dimension
&
dim
)
const
;
Dimension
operator
-
(
const
Dimension
&
dim
)
const
;
/// \brief Multiplication operator for Dimension.
/// \brief Multiplication operator for Dimension.
/// \param dim Right operand for multiplicaiton.
/// \param dim Right operand for multiplicaiton.
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
/// \return 0 if either of `*this` or `dim` is static and 0; else, Dimension::dynamic() if
/// either of `*this` or `dim` is dynamic; else, a static dimension with value
/// either of `*this` or `dim` is dynamic; else, a static dimension with value
/// `
size_t(*this)*size
_t(dim)`.
/// `
int64_t(*this)*int64
_t(dim)`.
Dimension
operator
*
(
const
Dimension
&
dim
)
const
;
Dimension
operator
*
(
const
Dimension
&
dim
)
const
;
/// \brief Add-into operator for Dimension.
/// \brief Add-into operator for Dimension.
...
@@ -160,7 +165,7 @@ namespace ngraph
...
@@ -160,7 +165,7 @@ namespace ngraph
private
:
private
:
// The actual numerical value of the dimension. s_dynamic_val is a special case,
// The actual numerical value of the dimension. s_dynamic_val is a special case,
// representing a dynamic dimension.
// representing a dynamic dimension.
size
_t
m_dimension
;
int64
_t
m_dimension
;
};
};
/// \brief Insert a human-readable representation of a dimension into an output stream.
/// \brief Insert a human-readable representation of a dimension into an output stream.
...
@@ -168,6 +173,6 @@ namespace ngraph
...
@@ -168,6 +173,6 @@ namespace ngraph
/// \param dimension The dimension to be inserted into `str`.
/// \param dimension The dimension to be inserted into `str`.
/// \return A reference to `str` after insertion.
/// \return A reference to `str` after insertion.
///
///
/// Inserts the string `?` if `dimension` is dynamic; else inserts `
size
_t(dimension)`.
/// Inserts the string `?` if `dimension` is dynamic; else inserts `
int64
_t(dimension)`.
std
::
ostream
&
operator
<<
(
std
::
ostream
&
str
,
const
Dimension
&
dimension
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
str
,
const
Dimension
&
dimension
);
}
}
src/ngraph/ngraph.hpp
View file @
b534a674
...
@@ -89,6 +89,7 @@
...
@@ -89,6 +89,7 @@
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/range.hpp"
...
...
src/ngraph/op/experimental/dyn_replace_slice.cpp
0 → 100644
View file @
b534a674
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory>
using
namespace
std
;
using
namespace
ngraph
;
op
::
DynReplaceSlice
::
DynReplaceSlice
(
const
shared_ptr
<
Node
>&
arg
,
const
shared_ptr
<
Node
>&
replacement
,
const
shared_ptr
<
Node
>&
lower_bounds
,
const
shared_ptr
<
Node
>&
upper_bounds
,
const
shared_ptr
<
Node
>&
strides
,
const
AxisSet
&
lower_bounds_mask
,
const
AxisSet
&
upper_bounds_mask
,
const
AxisSet
&
new_axis
,
const
AxisSet
&
shrink_axis
,
const
AxisSet
&
ellipsis_mask
)
:
Op
(
"DynReplaceSlice"
,
check_single_output_args
({
arg
,
replacement
,
lower_bounds
,
upper_bounds
,
strides
}))
,
m_lower_bounds_mask
(
lower_bounds_mask
)
,
m_upper_bounds_mask
(
upper_bounds_mask
)
,
m_new_axis
(
new_axis
)
,
m_shrink_axis
(
shrink_axis
)
,
m_ellipsis_mask
(
ellipsis_mask
)
{
constructor_validate_and_infer_types
();
}
void
op
::
DynReplaceSlice
::
validate_and_infer_types
()
{
auto
arg_et
=
get_input_element_type
(
0
);
auto
replacement_et
=
get_input_element_type
(
1
);
auto
lower_bounds_et
=
get_input_element_type
(
2
);
auto
upper_bounds_et
=
get_input_element_type
(
3
);
auto
strides_et
=
get_input_element_type
(
4
);
element
::
Type
result_et
;
// check data types
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
result_et
,
arg_et
,
replacement_et
),
"Argument element type is not compatible with replacement element type"
);
NODE_VALIDATION_CHECK
(
this
,
lower_bounds_et
.
compatible
(
element
::
Type_t
::
i64
),
"Lower bounds must have element type i64."
);
NODE_VALIDATION_CHECK
(
this
,
upper_bounds_et
.
compatible
(
element
::
Type_t
::
i64
),
"Upper bounds must have element type i64."
);
NODE_VALIDATION_CHECK
(
this
,
strides_et
.
compatible
(
element
::
Type_t
::
i64
),
"Strides must have element type i64"
);
// check shapes
auto
arg_shape
=
get_input_partial_shape
(
0
);
auto
replacement_shape
=
get_input_partial_shape
(
1
);
auto
lower_bounds_shape
=
get_input_partial_shape
(
2
);
auto
upper_bounds_shape
=
get_input_partial_shape
(
3
);
auto
strides_shape
=
get_input_partial_shape
(
4
);
NODE_VALIDATION_CHECK
(
this
,
lower_bounds_shape
.
rank
().
compatible
(
1
),
"Lower bounds shape must have rank 1, got "
,
lower_bounds_shape
.
rank
(),
"."
);
NODE_VALIDATION_CHECK
(
this
,
upper_bounds_shape
.
rank
().
compatible
(
1
),
"Upper bounds shape must have rank 1, got "
,
upper_bounds_shape
.
rank
(),
"."
);
NODE_VALIDATION_CHECK
(
this
,
strides_shape
.
rank
().
compatible
(
1
),
"Strides shape must have rank 1, got "
,
strides_shape
.
rank
(),
"."
);
PartialShape
attrs_shape
{
PartialShape
::
dynamic
()};
NODE_VALIDATION_CHECK
(
this
,
(
lower_bounds_shape
.
same_scheme
(
PartialShape
{
0
})
||
PartialShape
::
merge_into
(
attrs_shape
,
lower_bounds_shape
))
&&
(
upper_bounds_shape
.
same_scheme
(
PartialShape
{
0
})
||
PartialShape
::
merge_into
(
attrs_shape
,
upper_bounds_shape
))
&&
(
strides_shape
.
same_scheme
(
PartialShape
{
0
})
||
PartialShape
::
merge_into
(
attrs_shape
,
strides_shape
)),
"Shapes for lower bounds, upper bounds, and strides do not match"
);
set_input_is_relevant_to_shape
(
2
);
set_input_is_relevant_to_shape
(
3
);
set_input_is_relevant_to_shape
(
4
);
auto
lower_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
2
));
auto
upper_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
));
auto
strides
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
4
));
// TODO(amprocte): We can get a bit more information here about the ranks of arg and
// replacement by inspecting the attributes.
auto
slice_shape
=
PartialShape
::
dynamic
();
if
(
lower_bounds
&&
upper_bounds
&&
strides
)
{
slice_shape
=
infer_slice_shape
(
this
,
get_input_partial_shape
(
0
),
lower_bounds
->
get_vector
<
int64_t
>
(),
upper_bounds
->
get_vector
<
int64_t
>
(),
strides
->
get_vector
<
int64_t
>
(),
m_lower_bounds_mask
,
m_upper_bounds_mask
,
m_new_axis
,
m_shrink_axis
,
m_ellipsis_mask
);
}
NODE_VALIDATION_CHECK
(
this
,
slice_shape
.
compatible
(
replacement_shape
),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: "
,
slice_shape
,
")"
);
set_output_type
(
0
,
result_et
,
arg_shape
);
}
shared_ptr
<
Node
>
op
::
DynReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
DynReplaceSlice
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
m_lower_bounds_mask
,
m_upper_bounds_mask
,
m_new_axis
,
m_shrink_axis
,
m_ellipsis_mask
);
}
void
op
::
DynReplaceSlice
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
{
throw
ngraph_error
(
"generate_adjoints not implemented for DynReplaceSlice"
);
}
src/ngraph/op/experimental/dyn_replace_slice.hpp
0 → 100644
View file @
b534a674
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace
ngraph
{
namespace
op
{
/// \brief Takes a slice of an input tensor, i.e., the sub-tensor that resides within a bounding box, optionally with stride.
class
DynReplaceSlice
:
public
Op
{
public
:
/// \brief Constructs a dynamic tensor replace-slice operation.
///
/// \param arg The tensor in which to replace the slice.
/// \param replacement Data to copy to the slice for replacement.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix.
/// \param lower_bounds_mask Ignores lower_bounds for axis with the mask set
/// \param upper_bounds_mask Ignores upper_bounds for axis with the mask set
/// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position
DynReplaceSlice
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>&
replacement
,
const
std
::
shared_ptr
<
Node
>&
lower_bounds
,
const
std
::
shared_ptr
<
Node
>&
upper_bounds
,
const
std
::
shared_ptr
<
Node
>&
strides
,
const
AxisSet
&
lower_bounds_mask
=
AxisSet
{},
const
AxisSet
&
upper_bounds_mask
=
AxisSet
{},
const
AxisSet
&
new_axis
=
AxisSet
{},
const
AxisSet
&
shrink_axis
=
AxisSet
{},
const
AxisSet
&
ellipsis_mask
=
AxisSet
{});
const
AxisSet
&
get_lower_bounds_mask
()
const
{
return
m_lower_bounds_mask
;
}
const
AxisSet
&
get_upper_bounds_mask
()
const
{
return
m_upper_bounds_mask
;
}
const
AxisSet
&
get_new_axis
()
const
{
return
m_new_axis
;
}
const
AxisSet
&
get_shrink_axis
()
const
{
return
m_shrink_axis
;
}
const
AxisSet
&
get_ellipsis_mask
()
const
{
return
m_ellipsis_mask
;
}
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
void
validate_and_infer_types
()
override
;
private
:
/// Helper method to compute output shape
Shape
compute_output_shape
()
const
;
AxisSet
m_lower_bounds_mask
;
AxisSet
m_upper_bounds_mask
;
AxisSet
m_new_axis
;
AxisSet
m_shrink_axis
;
AxisSet
m_ellipsis_mask
;
};
}
}
src/ngraph/op/experimental/dyn_slice.cpp
View file @
b534a674
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory>
#include <memory>
...
@@ -42,142 +43,6 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
...
@@ -42,142 +43,6 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
Shape
op
::
DynSlice
::
compute_output_shape
()
const
{
auto
input_shape
=
get_input_partial_shape
(
0
).
to_shape
();
auto
lower_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
1
));
auto
upper_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
2
));
auto
strides
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
));
if
(
lower_bounds
&&
upper_bounds
&&
strides
)
{
auto
lb
=
lower_bounds
->
get_vector
<
int64_t
>
();
auto
ub
=
upper_bounds
->
get_vector
<
int64_t
>
();
auto
str
=
strides
->
get_vector
<
int64_t
>
();
int
max_dims
=
input_shape
.
size
()
+
m_new_axis
.
size
();
if
(
lb
.
size
()
&&
ub
.
size
())
{
NODE_VALIDATION_CHECK
(
this
,
lb
.
size
()
==
ub
.
size
(),
"Lower bounds and Upper bounds needs to have same number of values"
);
}
if
(
lb
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
this
,
lb
.
size
()
==
str
.
size
(),
"Lower bounds and strides needs to have same number of values"
);
}
if
(
ub
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
this
,
ub
.
size
()
==
str
.
size
(),
"Upper bounds and strides needs to have same number of values"
);
}
int
bounds_size
=
lb
.
size
()
?
lb
.
size
()
:
(
ub
.
size
()
?
ub
.
size
()
:
(
str
.
size
()
?
str
.
size
()
:
0
));
NODE_VALIDATION_CHECK
(
this
,
m_ellipsis_mask
.
size
()
<=
1
,
"Ellipsis mask cannot specify more than one axis"
);
int
ellipsis_pos1
=
m_ellipsis_mask
.
size
()
?
*
m_ellipsis_mask
.
begin
()
:
max_dims
;
int
ellipsis_pos2
=
max_dims
;
bounds_size
-=
ellipsis_pos1
;
if
(
bounds_size
>
0
&&
(
max_dims
-
bounds_size
)
>
ellipsis_pos1
)
{
ellipsis_pos2
=
max_dims
-
bounds_size
;
}
std
::
vector
<
int
>
begin_dms
(
max_dims
,
0
);
std
::
vector
<
int
>
end_dms
(
max_dims
,
-
1
);
std
::
vector
<
int
>
stride_dms
(
max_dims
,
1
);
int
i
,
j
,
k
,
bj
,
ej
,
sj
;
Shape
out_dims
;
for
(
i
=
0
,
j
=
0
,
k
=
0
,
bj
=
0
,
ej
=
0
,
sj
=
0
;
i
<
max_dims
;
i
++
)
{
if
(
i
>=
ellipsis_pos1
&&
i
<
ellipsis_pos2
)
{
if
(
m_new_axis
.
find
(
i
)
==
m_new_axis
.
end
())
{
end_dms
[
i
]
=
end_dms
[
i
]
>=
0
?
end_dms
[
i
]
:
input_shape
[
j
++
]
+
end_dms
[
i
];
}
else
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
out_dims
.
push_back
(
static_cast
<
int
>
(
ceil
(
static_cast
<
float
>
(
abs
(
end_dms
[
i
]
-
begin_dms
[
i
])
+
1
)
/
static_cast
<
float
>
(
abs
(
stride_dms
[
i
])))));
k
=
ellipsis_pos1
;
continue
;
}
stride_dms
[
i
]
=
(
str
.
size
()
>
sj
&&
str
[
sj
]
!=
0
)
?
str
[
sj
++
]
:
1
;
// Use lower_bounds if mask is not set
if
(
m_lower_bounds_mask
.
find
(
j
)
==
m_lower_bounds_mask
.
end
())
{
begin_dms
[
i
]
=
lb
.
size
()
>
bj
?
lb
[
bj
]
:
(
stride_dms
[
i
]
>
0
?
0
:
-
1
);
}
else
{
begin_dms
[
i
]
=
stride_dms
[
i
]
>
0
?
0
:
-
1
;
}
bj
++
;
begin_dms
[
i
]
=
begin_dms
[
i
]
>=
0
?
begin_dms
[
i
]
:
input_shape
[
j
]
+
begin_dms
[
i
];
// Clipping 'begin'
begin_dms
[
i
]
=
(
begin_dms
[
i
]
<
0
)
?
0
:
(
begin_dms
[
i
]
>=
input_shape
[
j
]
?
input_shape
[
j
]
-
1
:
begin_dms
[
i
]);
// Use upper_bounds if mask is not set
if
(
m_upper_bounds_mask
.
find
(
j
)
==
m_upper_bounds_mask
.
end
())
{
int
end_dms_tmp
=
ub
.
size
()
>
ej
?
(
stride_dms
[
i
]
>
0
?
ub
[
ej
]
-
1
:
ub
[
ej
]
+
1
)
:
end_dms
[
i
];
end_dms
[
i
]
=
ub
.
size
()
>
ej
?
end_dms_tmp
:
(
stride_dms
[
i
]
>
0
?
-
1
:
0
);
}
else
{
end_dms
[
i
]
=
stride_dms
[
i
]
>
0
?
-
1
:
0
;
}
ej
++
;
end_dms
[
i
]
=
end_dms
[
i
]
>=
0
?
end_dms
[
i
]
:
input_shape
[
j
]
+
end_dms
[
i
];
// Clipping 'end'
end_dms
[
i
]
=
(
end_dms
[
i
]
<
0
)
?
0
:
(
end_dms
[
i
]
>=
input_shape
[
j
]
?
input_shape
[
j
]
-
1
:
end_dms
[
i
]);
if
(
m_new_axis
.
find
(
i
)
==
m_new_axis
.
end
())
{
j
++
;
}
else
{
end_dms
[
i
]
=
0
;
}
if
(
m_shrink_axis
.
find
(
k
)
!=
m_shrink_axis
.
end
())
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
else
{
out_dims
.
push_back
(
static_cast
<
int
>
(
ceil
(
static_cast
<
float
>
(
abs
(
end_dms
[
i
]
-
begin_dms
[
i
])
+
1
)
/
static_cast
<
float
>
(
abs
(
stride_dms
[
i
])))));
}
k
++
;
}
return
out_dims
;
}
return
Shape
{};
}
void
op
::
DynSlice
::
validate_and_infer_types
()
void
op
::
DynSlice
::
validate_and_infer_types
()
{
{
auto
lower_bounds_et
=
get_input_element_type
(
1
);
auto
lower_bounds_et
=
get_input_element_type
(
1
);
...
@@ -219,17 +84,24 @@ void op::DynSlice::validate_and_infer_types()
...
@@ -219,17 +84,24 @@ void op::DynSlice::validate_and_infer_types()
set_input_is_relevant_to_shape
(
2
);
set_input_is_relevant_to_shape
(
2
);
set_input_is_relevant_to_shape
(
3
);
set_input_is_relevant_to_shape
(
3
);
if
(
get_input_partial_shape
(
0
).
is_static
())
auto
lower_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
1
));
auto
upper_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
2
));
auto
strides
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
));
if
(
lower_bounds
&&
upper_bounds
&&
strides
)
{
{
auto
shape
=
compute_output_shape
();
set_output_type
(
0
,
if
(
shape
!=
Shape
{})
get_input_element_type
(
0
),
{
infer_slice_shape
(
this
,
set_output_type
(
0
,
get_input_element_type
(
0
),
shape
);
get_input_partial_shape
(
0
),
}
lower_bounds
->
get_vector
<
int64_t
>
(),
else
upper_bounds
->
get_vector
<
int64_t
>
(),
{
strides
->
get_vector
<
int64_t
>
(),
set_output_type
(
0
,
get_input_element_type
(
0
),
PartialShape
::
dynamic
(
arg_shape
.
rank
()));
m_lower_bounds_mask
,
}
m_upper_bounds_mask
,
m_new_axis
,
m_shrink_axis
,
m_ellipsis_mask
));
}
}
else
else
{
{
...
...
src/ngraph/op/op_tbl.hpp
View file @
b534a674
...
@@ -84,6 +84,7 @@ NGRAPH_OP(Divide, ngraph::op)
...
@@ -84,6 +84,7 @@ NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
NGRAPH_OP
(
DynBroadcast
,
ngraph
::
op
)
NGRAPH_OP
(
DynBroadcast
,
ngraph
::
op
)
NGRAPH_OP
(
DynPad
,
ngraph
::
op
)
NGRAPH_OP
(
DynPad
,
ngraph
::
op
)
NGRAPH_OP
(
DynReplaceSlice
,
ngraph
::
op
)
NGRAPH_OP
(
DynReshape
,
ngraph
::
op
)
NGRAPH_OP
(
DynReshape
,
ngraph
::
op
)
NGRAPH_OP
(
DynSlice
,
ngraph
::
op
)
NGRAPH_OP
(
DynSlice
,
ngraph
::
op
)
NGRAPH_OP
(
EmbeddingLookup
,
ngraph
::
op
)
NGRAPH_OP
(
EmbeddingLookup
,
ngraph
::
op
)
...
...
src/ngraph/op/pad.cpp
View file @
b534a674
...
@@ -84,7 +84,7 @@ void op::Pad::validate_and_infer_types()
...
@@ -84,7 +84,7 @@ void op::Pad::validate_and_infer_types()
if
(
arg_shape
[
i
].
is_static
())
if
(
arg_shape
[
i
].
is_static
())
{
{
ptrdiff_t
result_dim
=
ptrdiff_t
result_dim
=
m_padding_below
[
i
]
+
static_cast
<
ptrdiff
_t
>
(
arg_shape
[
i
])
+
m_padding_above
[
i
];
m_padding_below
[
i
]
+
static_cast
<
int64
_t
>
(
arg_shape
[
i
])
+
m_padding_above
[
i
];
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
result_dim
>=
0
,
result_dim
>=
0
,
"Inferred result dimension at axis "
,
"Inferred result dimension at axis "
,
...
...
src/ngraph/partial_shape.cpp
View file @
b534a674
...
@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
...
@@ -275,3 +275,16 @@ bool PartialShape::broadcast_merge_into(PartialShape& dst,
return
success
;
return
success
;
}
}
}
}
bool
PartialShape
::
all_non_negative
()
const
{
for
(
auto
&
d
:
m_dimensions
)
{
if
(
d
.
is_static
()
&&
int64_t
(
d
)
<
0
)
{
return
false
;
}
}
return
true
;
}
src/ngraph/partial_shape.hpp
View file @
b534a674
...
@@ -164,6 +164,10 @@ namespace ngraph
...
@@ -164,6 +164,10 @@ namespace ngraph
/// \throws std::invalid_argument If this PartialShape is dynamic.
/// \throws std::invalid_argument If this PartialShape is dynamic.
Shape
to_shape
()
const
;
Shape
to_shape
()
const
;
/// \brief Returns `true` if all static dimensions of the tensor are non-negative, else
/// `false`.
bool
all_non_negative
()
const
;
/// \brief Index operator for PartialShape.
/// \brief Index operator for PartialShape.
/// \param i The index of the dimension being selected.
/// \param i The index of the dimension being selected.
/// \return A reference to the `i`th Dimension of this shape.
/// \return A reference to the `i`th Dimension of this shape.
...
...
src/ngraph/pass/dyn_elimination.cpp
View file @
b534a674
...
@@ -19,10 +19,12 @@
...
@@ -19,10 +19,12 @@
#include "dyn_elimination.hpp"
#include "dyn_elimination.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/range.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/slice.hpp"
...
@@ -36,7 +38,8 @@ pass::DynElimination::DynElimination()
...
@@ -36,7 +38,8 @@ pass::DynElimination::DynElimination()
:
GraphRewrite
()
:
GraphRewrite
()
{
{
construct_transpose
();
construct_transpose
();
construct_broadcast
();
construct_dyn_broadcast
();
construct_dyn_replace_slice
();
construct_dyn_slice
();
construct_dyn_slice
();
construct_dyn_reshape
();
construct_dyn_reshape
();
construct_range
();
construct_range
();
...
@@ -89,7 +92,7 @@ void pass::DynElimination::construct_transpose()
...
@@ -89,7 +92,7 @@ void pass::DynElimination::construct_transpose()
add_matcher
(
transpose_matcher
,
transpose_callback
,
all_pass_property_off
);
add_matcher
(
transpose_matcher
,
transpose_callback
,
all_pass_property_off
);
}
}
void
pass
::
DynElimination
::
construct_broadcast
()
void
pass
::
DynElimination
::
construct_
dyn_
broadcast
()
{
{
auto
data_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
data_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
shape_arg_label
=
auto
shape_arg_label
=
...
@@ -444,6 +447,92 @@ void pass::DynElimination::construct_dyn_slice()
...
@@ -444,6 +447,92 @@ void pass::DynElimination::construct_dyn_slice()
add_matcher
(
dyn_slice_matcher
,
dyn_slice_callback
,
all_pass_property_off
);
add_matcher
(
dyn_slice_matcher
,
dyn_slice_callback
,
all_pass_property_off
);
}
}
void
pass
::
DynElimination
::
construct_dyn_replace_slice
()
{
auto
data_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
replacement_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
begins_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
3
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
ends_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
3
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
strides_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
i64
,
Shape
{
3
},
pattern
::
has_class
<
op
::
Constant
>
());
auto
dyn_replace_slice_pat
=
make_shared
<
op
::
DynReplaceSlice
>
(
data_arg_label
,
replacement_arg_label
,
begins_arg_label
,
ends_arg_label
,
strides_arg_label
,
AxisSet
{},
AxisSet
{},
AxisSet
{},
AxisSet
{},
AxisSet
{});
auto
dyn_replace_slice_callback
=
[
data_arg_label
,
replacement_arg_label
,
begins_arg_label
,
ends_arg_label
,
strides_arg_label
](
pattern
::
Matcher
&
m
)
{
auto
pattern_map
=
m
.
get_pattern_map
();
auto
data_arg
=
pattern_map
[
data_arg_label
];
auto
replacement_arg
=
pattern_map
[
replacement_arg_label
];
auto
begins_arg
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
begins_arg_label
]);
auto
ends_arg
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
ends_arg_label
]);
auto
strides_arg
=
static_pointer_cast
<
op
::
Constant
>
(
pattern_map
[
strides_arg_label
]);
auto
dyn_replace_slice
=
static_pointer_cast
<
op
::
DynReplaceSlice
>
(
m
.
get_match_root
());
if
(
data_arg
->
get_output_partial_shape
(
0
).
is_dynamic
()
||
replacement_arg
->
get_output_partial_shape
(
0
).
is_dynamic
()
||
begins_arg
->
get_element_type
()
!=
element
::
i64
||
ends_arg
->
get_element_type
()
!=
element
::
i64
||
strides_arg
->
get_element_type
()
!=
element
::
i64
)
{
return
false
;
}
SlicePlan
p
=
make_plan
(
data_arg
->
get_output_shape
(
0
),
begins_arg
->
get_vector
<
int64_t
>
(),
ends_arg
->
get_vector
<
int64_t
>
(),
strides_arg
->
get_vector
<
int64_t
>
(),
dyn_replace_slice
->
get_lower_bounds_mask
(),
dyn_replace_slice
->
get_upper_bounds_mask
(),
dyn_replace_slice
->
get_new_axis
(),
dyn_replace_slice
->
get_shrink_axis
(),
dyn_replace_slice
->
get_ellipsis_mask
());
shared_ptr
<
Node
>
substitute_replacement_arg
=
replacement_arg
;
if
(
!
p
.
reverse_axes
.
empty
())
{
substitute_replacement_arg
=
make_shared
<
op
::
Reverse
>
(
substitute_replacement_arg
,
p
.
reverse_axes
);
}
if
(
p
.
reshape_in_shape
!=
p
.
reshape_out_shape
)
{
substitute_replacement_arg
=
make_shared
<
op
::
Reshape
>
(
substitute_replacement_arg
,
ngraph
::
get_default_order
(
p
.
reshape_out_shape
),
p
.
reshape_in_shape
);
}
auto
substitute_rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
data_arg
,
substitute_replacement_arg
,
Coordinate
(
p
.
begins
.
begin
(),
p
.
begins
.
end
()),
Coordinate
(
p
.
ends
.
begin
(),
p
.
ends
.
end
()),
Strides
(
p
.
strides
.
begin
(),
p
.
strides
.
end
()));
replace_node
(
m
.
get_match_root
(),
substitute_rsl
);
return
true
;
};
auto
dyn_replace_slice_matcher
=
make_shared
<
pattern
::
Matcher
>
(
dyn_replace_slice_pat
,
"DynElimination.DynReplaceShape"
);
add_matcher
(
dyn_replace_slice_matcher
,
dyn_replace_slice_callback
,
all_pass_property_off
);
}
void
pass
::
DynElimination
::
construct_dyn_reshape
()
void
pass
::
DynElimination
::
construct_dyn_reshape
()
{
{
auto
data_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
data_arg_label
=
make_shared
<
pattern
::
op
::
Label
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
...
...
src/ngraph/pass/dyn_elimination.hpp
View file @
b534a674
...
@@ -30,7 +30,8 @@ namespace ngraph
...
@@ -30,7 +30,8 @@ namespace ngraph
private
:
private
:
void
construct_transpose
();
void
construct_transpose
();
void
construct_broadcast
();
void
construct_dyn_broadcast
();
void
construct_dyn_replace_slice
();
void
construct_dyn_slice
();
void
construct_dyn_slice
();
void
construct_dyn_reshape
();
void
construct_dyn_reshape
();
void
construct_range
();
void
construct_range
();
...
...
src/ngraph/runtime/gpu/gpu_backend.cpp
View file @
b534a674
...
@@ -218,6 +218,7 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
...
@@ -218,6 +218,7 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
{
{
set
<
string
>
unsupported_ops
=
{
"Quantize"
,
set
<
string
>
unsupported_ops
=
{
"Quantize"
,
"Dequantize"
,
"Dequantize"
,
"DynReplaceSlice"
,
"DynReshape"
,
"DynReshape"
,
"DynSlice"
,
"DynSlice"
,
"ShapeOf"
,
"ShapeOf"
,
...
...
src/ngraph/runtime/gpu/gpu_emitter.cpp
View file @
b534a674
...
@@ -62,6 +62,7 @@
...
@@ -62,6 +62,7 @@
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
...
@@ -612,6 +613,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
...
@@ -612,6 +613,11 @@ std::string runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
return
compiled_function
->
add_to_runtime
(
index
,
function_name
,
args
,
out
);
return
compiled_function
->
add_to_runtime
(
index
,
function_name
,
args
,
out
);
}
}
std
::
string
runtime
::
gpu
::
GPU_Emitter
::
emit_DynReplaceSlice
(
EMIT_ARGS
)
{
throw
unsupported_op
(
"Unsupported op '"
+
node
->
description
()
+
"'"
);
}
std
::
string
runtime
::
gpu
::
GPU_Emitter
::
emit_DynReshape
(
EMIT_ARGS
)
std
::
string
runtime
::
gpu
::
GPU_Emitter
::
emit_DynReshape
(
EMIT_ARGS
)
{
{
throw
unsupported_op
(
"Unsupported op '"
+
node
->
description
()
+
"'"
);
throw
unsupported_op
(
"Unsupported op '"
+
node
->
description
()
+
"'"
);
...
...
src/ngraph/runtime/intelgpu/intelgpu_backend.cpp
View file @
b534a674
...
@@ -2061,6 +2061,7 @@ shared_ptr<runtime::Executable>
...
@@ -2061,6 +2061,7 @@ shared_ptr<runtime::Executable>
case
OP_TYPEID
:
:
DepthToSpace
:
case
OP_TYPEID
:
:
DepthToSpace
:
case
OP_TYPEID
:
:
DynBroadcast
:
case
OP_TYPEID
:
:
DynBroadcast
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
DynReplaceSlice
:
case
OP_TYPEID
:
:
DynReshape
:
case
OP_TYPEID
:
:
DynReshape
:
case
OP_TYPEID
:
:
DynSlice
:
case
OP_TYPEID
:
:
DynSlice
:
case
OP_TYPEID
:
:
Elu
:
case
OP_TYPEID
:
:
Elu
:
...
...
src/ngraph/runtime/intelgpu/unit_test.manifest
View file @
b534a674
...
@@ -18,6 +18,7 @@ replace_slice_matrix
...
@@ -18,6 +18,7 @@ replace_slice_matrix
replace_slice_matrix_inplace
replace_slice_matrix_inplace
replace_slice_scalar
replace_slice_scalar
replace_slice_vector
replace_slice_vector
dyn_replace_slice
shape_of_5d
shape_of_5d
shape_of_matrix
shape_of_matrix
shape_of_scalar
shape_of_scalar
...
...
src/ngraph/runtime/interpreter/int_executable.hpp
View file @
b534a674
...
@@ -1503,7 +1503,8 @@ private:
...
@@ -1503,7 +1503,8 @@ private:
case
OP_TYPEID
:
:
Transpose
:
case
OP_TYPEID
:
:
Transpose
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
Tile
:
case
OP_TYPEID
:
:
Tile
:
default
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
case
OP_TYPEID
:
:
DynReplaceSlice
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#pragma GCC diagnostic pop
#endif
#endif
...
...
src/ngraph/serializer.cpp
View file @
b534a674
...
@@ -55,6 +55,7 @@
...
@@ -55,6 +55,7 @@
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
...
@@ -1130,6 +1131,25 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
...
@@ -1130,6 +1131,25 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node
=
make_shared
<
op
::
DynPad
>
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
node
=
make_shared
<
op
::
DynPad
>
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
break
;
break
;
}
}
case
OP_TYPEID
:
:
DynReplaceSlice
:
{
auto
lower_bounds_mask
=
node_js
.
at
(
"lower_bounds_mask"
).
get
<
set
<
size_t
>>
();
auto
upper_bounds_mask
=
node_js
.
at
(
"upper_bounds_mask"
).
get
<
set
<
size_t
>>
();
auto
new_axis
=
node_js
.
at
(
"new_axis"
).
get
<
set
<
size_t
>>
();
auto
shrink_axis
=
node_js
.
at
(
"shrink_axis"
).
get
<
set
<
size_t
>>
();
auto
ellipsis_mask
=
node_js
.
at
(
"ellipsis_mask"
).
get
<
set
<
size_t
>>
();
node
=
make_shared
<
op
::
DynReplaceSlice
>
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
],
lower_bounds_mask
,
upper_bounds_mask
,
new_axis
,
shrink_axis
,
ellipsis_mask
);
break
;
}
case
OP_TYPEID
:
:
DynReshape
:
case
OP_TYPEID
:
:
DynReshape
:
{
{
node
=
make_shared
<
op
::
DynReshape
>
(
args
[
0
],
args
[
1
]);
node
=
make_shared
<
op
::
DynReshape
>
(
args
[
0
],
args
[
1
]);
...
@@ -1137,7 +1157,20 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
...
@@ -1137,7 +1157,20 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
}
case
OP_TYPEID
:
:
DynSlice
:
case
OP_TYPEID
:
:
DynSlice
:
{
{
node
=
make_shared
<
op
::
DynSlice
>
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
auto
lower_bounds_mask
=
node_js
.
at
(
"lower_bounds_mask"
).
get
<
set
<
size_t
>>
();
auto
upper_bounds_mask
=
node_js
.
at
(
"upper_bounds_mask"
).
get
<
set
<
size_t
>>
();
auto
new_axis
=
node_js
.
at
(
"new_axis"
).
get
<
set
<
size_t
>>
();
auto
shrink_axis
=
node_js
.
at
(
"shrink_axis"
).
get
<
set
<
size_t
>>
();
auto
ellipsis_mask
=
node_js
.
at
(
"ellipsis_mask"
).
get
<
set
<
size_t
>>
();
node
=
make_shared
<
op
::
DynSlice
>
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
lower_bounds_mask
,
upper_bounds_mask
,
new_axis
,
shrink_axis
,
ellipsis_mask
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Elu
:
case
OP_TYPEID
:
:
Elu
:
...
@@ -2297,9 +2330,27 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2297,9 +2330,27 @@ json JSONSerializer::serialize_node(const Node& n)
}
}
case
OP_TYPEID
:
:
DynPad
:
{
break
;
case
OP_TYPEID
:
:
DynPad
:
{
break
;
}
}
case
OP_TYPEID
:
:
DynReplaceSlice
:
{
auto
tmp
=
dynamic_cast
<
const
op
::
DynReplaceSlice
*>
(
&
n
);
node
[
"lower_bounds_mask"
]
=
tmp
->
get_lower_bounds_mask
();
node
[
"upper_bounds_mask"
]
=
tmp
->
get_upper_bounds_mask
();
node
[
"new_axis"
]
=
tmp
->
get_new_axis
();
node
[
"shrink_axis"
]
=
tmp
->
get_shrink_axis
();
node
[
"ellipsis_mask"
]
=
tmp
->
get_ellipsis_mask
();
break
;
}
case
OP_TYPEID
:
:
DynReshape
:
{
break
;
case
OP_TYPEID
:
:
DynReshape
:
{
break
;
}
}
case
OP_TYPEID
:
:
DynSlice
:
{
break
;
case
OP_TYPEID
:
:
DynSlice
:
{
auto
tmp
=
dynamic_cast
<
const
op
::
DynSlice
*>
(
&
n
);
node
[
"lower_bounds_mask"
]
=
tmp
->
get_lower_bounds_mask
();
node
[
"upper_bounds_mask"
]
=
tmp
->
get_upper_bounds_mask
();
node
[
"new_axis"
]
=
tmp
->
get_new_axis
();
node
[
"shrink_axis"
]
=
tmp
->
get_shrink_axis
();
node
[
"ellipsis_mask"
]
=
tmp
->
get_ellipsis_mask
();
break
;
}
}
case
OP_TYPEID
:
:
Elu
:
{
break
;
case
OP_TYPEID
:
:
Elu
:
{
break
;
}
}
...
...
src/ngraph/validation_util.cpp
View file @
b534a674
...
@@ -142,8 +142,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
...
@@ -142,8 +142,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
ptrdiff_t
data_padded_dilated_dim
=
-
1
;
ptrdiff_t
data_padded_dilated_dim
=
-
1
;
if
(
data_dim_static
)
if
(
data_dim_static
)
{
{
data_padded_dilated_dim
=
(
static_cast
<
ptrdiff
_t
>
(
data_dilation
[
i
])
*
data_padded_dilated_dim
=
(
static_cast
<
int64
_t
>
(
data_dilation
[
i
])
*
(
static_cast
<
ptrdiff
_t
>
(
data_shape
[
i
])
-
1
))
+
(
static_cast
<
int64
_t
>
(
data_shape
[
i
])
-
1
))
+
1
+
data_padding_below
[
i
]
+
data_padding_above
[
i
];
1
+
data_padding_below
[
i
]
+
data_padding_above
[
i
];
NODE_VALIDATION_CHECK
(
NODE_VALIDATION_CHECK
(
node
,
node
,
...
@@ -158,8 +158,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
...
@@ -158,8 +158,8 @@ PartialShape ngraph::infer_windowed_reduction_output_shape(const Node* node,
ptrdiff_t
window_dilated_dim
=
-
1
;
ptrdiff_t
window_dilated_dim
=
-
1
;
if
(
window_dim_static
)
if
(
window_dim_static
)
{
{
window_dilated_dim
=
static_cast
<
ptrdiff
_t
>
(
window_dilation
[
i
])
*
window_dilated_dim
=
static_cast
<
int64
_t
>
(
window_dilation
[
i
])
*
(
static_cast
<
ptrdiff
_t
>
(
window_shape
[
i
])
-
1
)
+
(
static_cast
<
int64
_t
>
(
window_shape
[
i
])
-
1
)
+
1
;
1
;
NODE_VALIDATION_CHECK
(
node
,
NODE_VALIDATION_CHECK
(
node
,
...
@@ -628,3 +628,257 @@ void ngraph::infer_auto_padding(const Shape& image_shape,
...
@@ -628,3 +628,257 @@ void ngraph::infer_auto_padding(const Shape& image_shape,
padding_above
.
push_back
(
pad_type
==
op
::
PadType
::
SAME_UPPER
?
padding_rhs
:
padding_lhs
);
padding_above
.
push_back
(
pad_type
==
op
::
PadType
::
SAME_UPPER
?
padding_rhs
:
padding_lhs
);
}
}
}
}
PartialShape
ngraph
::
infer_slice_shape
(
const
Node
*
node
,
const
PartialShape
&
input_shape
,
const
std
::
vector
<
int64_t
>&
lb
,
const
std
::
vector
<
int64_t
>&
ub
,
const
std
::
vector
<
int64_t
>&
str
,
const
AxisSet
&
lb_mask
,
const
AxisSet
&
ub_mask
,
const
AxisSet
&
new_axis
,
const
AxisSet
&
shrink_axis
,
const
AxisSet
&
ellipsis_mask
)
{
if
(
lb
.
size
()
&&
ub
.
size
())
{
NODE_VALIDATION_CHECK
(
node
,
lb
.
size
()
==
ub
.
size
(),
"Lower bounds and Upper bounds needs to have same number of values"
);
}
if
(
lb
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
node
,
lb
.
size
()
==
str
.
size
(),
"Lower bounds and strides needs to have same number of values"
);
}
if
(
ub
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
node
,
ub
.
size
()
==
str
.
size
(),
"Upper bounds and strides needs to have same number of values"
);
}
if
(
input_shape
.
rank
().
is_dynamic
())
{
return
PartialShape
::
dynamic
();
}
int
max_dims
=
size_t
(
input_shape
.
rank
())
+
new_axis
.
size
();
int
bounds_size
=
lb
.
size
()
?
lb
.
size
()
:
(
ub
.
size
()
?
ub
.
size
()
:
(
str
.
size
()
?
str
.
size
()
:
0
));
int
ellipsis_pos1
=
ellipsis_mask
.
size
()
?
*
ellipsis_mask
.
begin
()
:
max_dims
;
int
ellipsis_pos2
=
max_dims
;
bounds_size
-=
ellipsis_pos1
;
if
(
bounds_size
>
0
&&
(
max_dims
-
bounds_size
)
>
ellipsis_pos1
)
{
ellipsis_pos2
=
max_dims
-
bounds_size
;
}
std
::
vector
<
Dimension
>
begin_dms
(
max_dims
,
0
);
std
::
vector
<
Dimension
>
end_dms
(
max_dims
,
-
1
);
std
::
vector
<
Dimension
>
stride_dms
(
max_dims
,
1
);
std
::
vector
<
Dimension
>
out_dims
;
int
j
=
0
;
int
k
=
0
;
int
bj
=
0
;
int
ej
=
0
;
int
sj
=
0
;
for
(
int
i
=
0
;
i
<
max_dims
;
i
++
)
{
if
(
i
>=
ellipsis_pos1
&&
i
<
ellipsis_pos2
)
{
if
(
new_axis
.
find
(
i
)
==
new_axis
.
end
())
{
if
(
end_dms
[
i
].
is_static
()
&&
int64_t
(
end_dms
[
i
])
<
0
)
{
end_dms
[
i
]
=
input_shape
[
j
++
]
+
end_dms
[
i
];
}
}
else
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
if
(
end_dms
[
i
].
is_dynamic
()
||
begin_dms
[
i
].
is_dynamic
()
||
stride_dms
[
i
].
is_dynamic
())
{
out_dims
.
push_back
(
Dimension
::
dynamic
());
}
else
{
out_dims
.
push_back
(
static_cast
<
int64_t
>
(
ceil
(
static_cast
<
float
>
(
abs
(
int64_t
(
end_dms
[
i
])
-
int64_t
(
begin_dms
[
i
]))
+
1
)
/
static_cast
<
float
>
(
abs
(
int64_t
(
stride_dms
[
i
]))))));
}
k
=
ellipsis_pos1
;
continue
;
}
stride_dms
[
i
]
=
(
str
.
size
()
>
sj
&&
str
[
sj
]
!=
0
)
?
str
[
sj
++
]
:
1
;
// Use lower_bounds if mask is not set
if
(
lb_mask
.
find
(
j
)
==
lb_mask
.
end
())
{
if
(
lb
.
size
()
>
bj
)
{
begin_dms
[
i
]
=
lb
[
bj
];
}
else
if
(
stride_dms
[
i
].
is_dynamic
())
{
begin_dms
[
i
]
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
stride_dms
[
i
])
>
0
)
{
begin_dms
[
i
]
=
0
;
}
else
{
begin_dms
[
i
]
=
-
1
;
}
}
else
if
(
stride_dms
[
i
].
is_dynamic
())
{
begin_dms
[
i
]
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
stride_dms
[
i
])
>
0
)
{
begin_dms
[
i
]
=
0
;
}
else
{
begin_dms
[
i
]
=
-
1
;
}
bj
++
;
if
(
begin_dms
[
i
].
is_static
()
&&
int64_t
(
begin_dms
[
i
])
<
0
)
{
begin_dms
[
i
]
=
input_shape
[
j
]
+
begin_dms
[
i
];
}
// Clipping 'begin'
if
(
begin_dms
[
i
].
is_static
())
{
if
(
int64_t
(
begin_dms
[
i
])
<
0
)
{
begin_dms
[
i
]
=
0
;
}
else
if
(
input_shape
[
j
].
is_dynamic
())
{
begin_dms
[
i
]
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
begin_dms
[
i
])
>=
int64_t
(
input_shape
[
j
]))
{
begin_dms
[
i
]
=
input_shape
[
j
]
-
1
;
}
}
// Use upper_bounds if mask is not set
if
(
ub_mask
.
find
(
j
)
==
ub_mask
.
end
())
{
Dimension
end_dms_tmp
;
if
(
ub
.
size
()
<=
ej
)
{
end_dms_tmp
=
end_dms
[
i
];
}
else
if
(
stride_dms
[
i
].
is_dynamic
())
{
end_dms_tmp
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
stride_dms
[
i
])
>
0
)
{
end_dms_tmp
=
ub
[
ej
]
-
1
;
}
else
{
end_dms_tmp
=
ub
[
ej
]
+
1
;
}
if
(
ub
.
size
()
>
ej
)
{
end_dms
[
i
]
=
end_dms_tmp
;
}
else
if
(
stride_dms
[
i
].
is_dynamic
())
{
end_dms
[
i
]
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
stride_dms
[
i
])
>
0
)
{
end_dms
[
i
]
=
-
1
;
}
else
{
end_dms
[
i
]
=
0
;
}
}
else
{
if
(
stride_dms
[
i
].
is_dynamic
())
{
end_dms
[
i
]
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
stride_dms
[
i
])
>
0
)
{
end_dms
[
i
]
=
-
1
;
}
else
{
end_dms
[
i
]
=
0
;
}
}
ej
++
;
if
(
end_dms
[
i
].
is_static
()
&&
int64_t
(
end_dms
[
i
])
<
0
)
{
end_dms
[
i
]
=
input_shape
[
j
]
+
end_dms
[
i
];
}
// Clipping 'end'
if
(
end_dms
[
i
].
is_static
())
{
if
(
int64_t
(
end_dms
[
i
])
<
0
)
{
end_dms
[
i
]
=
0
;
}
else
if
(
input_shape
[
j
].
is_dynamic
())
{
end_dms
[
i
]
=
Dimension
::
dynamic
();
}
else
if
(
int64_t
(
end_dms
[
i
])
>=
int64_t
(
input_shape
[
j
]))
{
end_dms
[
i
]
=
input_shape
[
j
]
-
1
;
}
}
if
(
new_axis
.
find
(
i
)
==
new_axis
.
end
())
{
j
++
;
}
else
{
end_dms
[
i
]
=
0
;
}
if
(
shrink_axis
.
find
(
k
)
!=
shrink_axis
.
end
())
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
else
if
(
end_dms
[
i
].
is_dynamic
()
||
begin_dms
[
i
].
is_dynamic
()
||
stride_dms
[
i
].
is_dynamic
())
{
out_dims
.
push_back
(
Dimension
::
dynamic
());
}
else
{
out_dims
.
push_back
(
static_cast
<
int64_t
>
(
ceil
(
static_cast
<
float
>
(
abs
(
int64_t
(
end_dms
[
i
])
-
int64_t
(
begin_dms
[
i
]))
+
1
)
/
static_cast
<
float
>
(
abs
(
int64_t
(
stride_dms
[
i
]))))));
}
k
++
;
}
return
out_dims
;
}
src/ngraph/validation_util.hpp
View file @
b534a674
...
@@ -94,4 +94,15 @@ namespace ngraph
...
@@ -94,4 +94,15 @@ namespace ngraph
const
op
::
PadType
pad_type
,
const
op
::
PadType
pad_type
,
CoordinateDiff
&
padding_above
,
CoordinateDiff
&
padding_above
,
CoordinateDiff
&
padding_below
);
CoordinateDiff
&
padding_below
);
PartialShape
infer_slice_shape
(
const
Node
*
node
,
const
PartialShape
&
input_shape
,
const
std
::
vector
<
int64_t
>&
lb
,
const
std
::
vector
<
int64_t
>&
ub
,
const
std
::
vector
<
int64_t
>&
str
,
const
AxisSet
&
lb_mask
,
const
AxisSet
&
ub_mask
,
const
AxisSet
&
new_axis
,
const
AxisSet
&
shrink_mask
,
const
AxisSet
&
ellipsis_mask
);
}
}
test/CMakeLists.txt
View file @
b534a674
...
@@ -167,6 +167,7 @@ set(MULTI_TEST_SRC
...
@@ -167,6 +167,7 @@ set(MULTI_TEST_SRC
backend_test.in.cpp
backend_test.in.cpp
backend_unary_elementwise.in.cpp
backend_unary_elementwise.in.cpp
convolution_test.in.cpp
convolution_test.in.cpp
dyn_replace_slice_test.in.cpp
dyn_slice_test.in.cpp
dyn_slice_test.in.cpp
dynamic.in.cpp
dynamic.in.cpp
)
)
...
...
test/dyn_elimination.cpp
View file @
b534a674
...
@@ -132,6 +132,56 @@ TEST(dyn_elimination, slice)
...
@@ -132,6 +132,56 @@ TEST(dyn_elimination, slice)
ASSERT_EQ
(
f
->
get_results
().
at
(
0
)
->
get_shape
(),
(
Shape
{
2
,
4
,
2
,
2
,
1
,
2
,
2
}));
ASSERT_EQ
(
f
->
get_results
().
at
(
0
)
->
get_shape
(),
(
Shape
{
2
,
4
,
2
,
2
,
1
,
2
,
2
}));
}
}
TEST
(
dyn_elimination
,
replace_slice
)
{
// input has shape [2,4,6,8,2,2,2]
// slice in numpy syntax is [0:,:4,2:6:2,7:3:-2,np.newaxis,...,1]
// slice shape should be [2,4,2,2,1,2,2] (so sayeth numpy!)
Shape
shape_in
{
2
,
4
,
6
,
8
,
2
,
2
,
2
};
Shape
shape_slice
{
2
,
4
,
2
,
2
,
1
,
2
,
2
};
auto
input
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_in
);
auto
replacement
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_slice
);
auto
constant_lb
=
make_shared
<
op
::
Constant
>
(
element
::
i64
,
Shape
{
7
},
vector
<
int64_t
>
{
0
,
3
,
2
,
7
,
0
,
0
,
1
});
auto
constant_ub
=
make_shared
<
op
::
Constant
>
(
element
::
i64
,
Shape
{
7
},
vector
<
int64_t
>
{
0
,
4
,
6
,
3
,
0
,
0
,
0
});
auto
constant_strides
=
make_shared
<
op
::
Constant
>
(
element
::
i64
,
Shape
{
7
},
vector
<
int64_t
>
{
1
,
1
,
2
,
-
2
,
0
,
0
,
0
});
AxisSet
lower_bounds_mask
{
1
};
AxisSet
upper_bounds_mask
{
0
};
AxisSet
new_axis_mask
{
4
};
AxisSet
shrink_mask
{
6
};
AxisSet
ellipsis_mask
{
5
};
auto
rsl
=
make_shared
<
op
::
DynReplaceSlice
>
(
input
,
replacement
,
constant_lb
,
constant_ub
,
constant_strides
,
lower_bounds_mask
,
upper_bounds_mask
,
new_axis_mask
,
shrink_mask
,
ellipsis_mask
);
ASSERT_EQ
(
rsl
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
rsl
->
get_shape
(),
(
Shape
{
2
,
4
,
6
,
8
,
2
,
2
,
2
}));
auto
f
=
make_shared
<
Function
>
(
rsl
,
ParameterVector
{
input
,
replacement
});
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
DynElimination
>
();
pass_manager
.
run_passes
(
f
);
ASSERT_EQ
(
count_ops_of_type
<
op
::
DynReplaceSlice
>
(
f
),
0
);
ASSERT_EQ
(
count_ops_of_type
<
op
::
ReplaceSlice
>
(
f
),
1
);
ASSERT_EQ
(
count_ops_of_type
<
op
::
Reshape
>
(
f
),
1
);
ASSERT_EQ
(
count_ops_of_type
<
op
::
Reverse
>
(
f
),
1
);
ASSERT_EQ
(
f
->
get_results
().
at
(
0
)
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
f
->
get_results
().
at
(
0
)
->
get_shape
(),
(
Shape
{
2
,
4
,
6
,
8
,
2
,
2
,
2
}));
}
TEST
(
dyn_elimination
,
reshape
)
TEST
(
dyn_elimination
,
reshape
)
{
{
auto
input_arg
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
,
6
,
8
});
auto
input_arg
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
,
6
,
8
});
...
...
test/dyn_replace_slice_test.in.cpp
0 → 100644
View file @
b534a674
This diff is collapsed.
Click to expand it.
test/ref_generators/generate_dyn_replace_slice_ref.py
0 → 100644
View file @
b534a674
This diff is collapsed.
Click to expand it.
test/type_prop.cpp
View file @
b534a674
This diff is collapsed.
Click to expand it.
test/update_dyn_replace_slice_reference.sh
0 → 100755
View file @
b534a674
#!/bin/bash
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
declare
THIS_SCRIPT_DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
python
${
THIS_SCRIPT_DIR
}
/ref_generators/generate_dyn_replace_slice_ref.py
${
THIS_SCRIPT_DIR
}
/dyn_replace_slice_test.in.cpp
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment