Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
b9a599a1
Commit
b9a599a1
authored
Jun 12, 2019
by
Adam Procter
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
wip
parent
beb8c442
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
194 additions
and
146 deletions
+194
-146
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-0
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
dyn_slice.cpp
src/ngraph/op/experimental/dyn_slice.cpp
+18
-146
validation_util.cpp
src/ngraph/validation_util.cpp
+162
-0
validation_util.hpp
src/ngraph/validation_util.hpp
+11
-0
type_prop.cpp
test/type_prop.cpp
+0
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
b9a599a1
...
...
@@ -142,6 +142,8 @@ set (SRC
op/experimental/dyn_broadcast.hpp
op/experimental/dyn_pad.cpp
op/experimental/dyn_pad.hpp
op/experimental/dyn_replace_slice.cpp
op/experimental/dyn_replace_slice.hpp
op/experimental/dyn_reshape.cpp
op/experimental/dyn_reshape.hpp
op/experimental/dyn_slice.cpp
...
...
src/ngraph/ngraph.hpp
View file @
b9a599a1
...
...
@@ -89,6 +89,7 @@
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
...
...
src/ngraph/op/experimental/dyn_slice.cpp
View file @
b9a599a1
...
...
@@ -17,6 +17,7 @@
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory>
...
...
@@ -42,142 +43,6 @@ op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types
();
}
Shape
op
::
DynSlice
::
compute_output_shape
()
const
{
auto
input_shape
=
get_input_partial_shape
(
0
).
to_shape
();
auto
lower_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
1
));
auto
upper_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
2
));
auto
strides
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
));
if
(
lower_bounds
&&
upper_bounds
&&
strides
)
{
auto
lb
=
lower_bounds
->
get_vector
<
int64_t
>
();
auto
ub
=
upper_bounds
->
get_vector
<
int64_t
>
();
auto
str
=
strides
->
get_vector
<
int64_t
>
();
int
max_dims
=
input_shape
.
size
()
+
m_new_axis
.
size
();
if
(
lb
.
size
()
&&
ub
.
size
())
{
NODE_VALIDATION_CHECK
(
this
,
lb
.
size
()
==
ub
.
size
(),
"Lower bounds and Upper bounds needs to have same number of values"
);
}
if
(
lb
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
this
,
lb
.
size
()
==
str
.
size
(),
"Lower bounds and strides needs to have same number of values"
);
}
if
(
ub
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
this
,
ub
.
size
()
==
str
.
size
(),
"Upper bounds and strides needs to have same number of values"
);
}
int
bounds_size
=
lb
.
size
()
?
lb
.
size
()
:
(
ub
.
size
()
?
ub
.
size
()
:
(
str
.
size
()
?
str
.
size
()
:
0
));
NODE_VALIDATION_CHECK
(
this
,
m_ellipsis_mask
.
size
()
<=
1
,
"Ellipsis mask cannot specify more than one axis"
);
int
ellipsis_pos1
=
m_ellipsis_mask
.
size
()
?
*
m_ellipsis_mask
.
begin
()
:
max_dims
;
int
ellipsis_pos2
=
max_dims
;
bounds_size
-=
ellipsis_pos1
;
if
(
bounds_size
>
0
&&
(
max_dims
-
bounds_size
)
>
ellipsis_pos1
)
{
ellipsis_pos2
=
max_dims
-
bounds_size
;
}
std
::
vector
<
int
>
begin_dms
(
max_dims
,
0
);
std
::
vector
<
int
>
end_dms
(
max_dims
,
-
1
);
std
::
vector
<
int
>
stride_dms
(
max_dims
,
1
);
int
i
,
j
,
k
,
bj
,
ej
,
sj
;
Shape
out_dims
;
for
(
i
=
0
,
j
=
0
,
k
=
0
,
bj
=
0
,
ej
=
0
,
sj
=
0
;
i
<
max_dims
;
i
++
)
{
if
(
i
>=
ellipsis_pos1
&&
i
<
ellipsis_pos2
)
{
if
(
m_new_axis
.
find
(
i
)
==
m_new_axis
.
end
())
{
end_dms
[
i
]
=
end_dms
[
i
]
>=
0
?
end_dms
[
i
]
:
input_shape
[
j
++
]
+
end_dms
[
i
];
}
else
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
out_dims
.
push_back
(
static_cast
<
int
>
(
ceil
(
static_cast
<
float
>
(
abs
(
end_dms
[
i
]
-
begin_dms
[
i
])
+
1
)
/
static_cast
<
float
>
(
abs
(
stride_dms
[
i
])))));
k
=
ellipsis_pos1
;
continue
;
}
stride_dms
[
i
]
=
(
str
.
size
()
>
sj
&&
str
[
sj
]
!=
0
)
?
str
[
sj
++
]
:
1
;
// Use lower_bounds if mask is not set
if
(
m_lower_bounds_mask
.
find
(
j
)
==
m_lower_bounds_mask
.
end
())
{
begin_dms
[
i
]
=
lb
.
size
()
>
bj
?
lb
[
bj
]
:
(
stride_dms
[
i
]
>
0
?
0
:
-
1
);
}
else
{
begin_dms
[
i
]
=
stride_dms
[
i
]
>
0
?
0
:
-
1
;
}
bj
++
;
begin_dms
[
i
]
=
begin_dms
[
i
]
>=
0
?
begin_dms
[
i
]
:
input_shape
[
j
]
+
begin_dms
[
i
];
// Clipping 'begin'
begin_dms
[
i
]
=
(
begin_dms
[
i
]
<
0
)
?
0
:
(
begin_dms
[
i
]
>=
input_shape
[
j
]
?
input_shape
[
j
]
-
1
:
begin_dms
[
i
]);
// Use upper_bounds if mask is not set
if
(
m_upper_bounds_mask
.
find
(
j
)
==
m_upper_bounds_mask
.
end
())
{
int
end_dms_tmp
=
ub
.
size
()
>
ej
?
(
stride_dms
[
i
]
>
0
?
ub
[
ej
]
-
1
:
ub
[
ej
]
+
1
)
:
end_dms
[
i
];
end_dms
[
i
]
=
ub
.
size
()
>
ej
?
end_dms_tmp
:
(
stride_dms
[
i
]
>
0
?
-
1
:
0
);
}
else
{
end_dms
[
i
]
=
stride_dms
[
i
]
>
0
?
-
1
:
0
;
}
ej
++
;
end_dms
[
i
]
=
end_dms
[
i
]
>=
0
?
end_dms
[
i
]
:
input_shape
[
j
]
+
end_dms
[
i
];
// Clipping 'end'
end_dms
[
i
]
=
(
end_dms
[
i
]
<
0
)
?
0
:
(
end_dms
[
i
]
>=
input_shape
[
j
]
?
input_shape
[
j
]
-
1
:
end_dms
[
i
]);
if
(
m_new_axis
.
find
(
i
)
==
m_new_axis
.
end
())
{
j
++
;
}
else
{
end_dms
[
i
]
=
0
;
}
if
(
m_shrink_axis
.
find
(
k
)
!=
m_shrink_axis
.
end
())
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
else
{
out_dims
.
push_back
(
static_cast
<
int
>
(
ceil
(
static_cast
<
float
>
(
abs
(
end_dms
[
i
]
-
begin_dms
[
i
])
+
1
)
/
static_cast
<
float
>
(
abs
(
stride_dms
[
i
])))));
}
k
++
;
}
return
out_dims
;
}
return
Shape
{};
}
void
op
::
DynSlice
::
validate_and_infer_types
()
{
auto
lower_bounds_et
=
get_input_element_type
(
1
);
...
...
@@ -219,17 +84,24 @@ void op::DynSlice::validate_and_infer_types()
set_input_is_relevant_to_shape
(
2
);
set_input_is_relevant_to_shape
(
3
);
if
(
get_input_partial_shape
(
0
).
is_static
())
auto
lower_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
1
));
auto
upper_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
2
));
auto
strides
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
));
if
(
lower_bounds
&&
upper_bounds
&&
strides
)
{
auto
shape
=
compute_output_shape
();
if
(
shape
!=
Shape
{})
{
set_output_type
(
0
,
get_input_element_type
(
0
),
shape
);
}
else
{
set_output_type
(
0
,
get_input_element_type
(
0
),
PartialShape
::
dynamic
(
arg_shape
.
rank
()));
}
set_output_type
(
0
,
get_input_element_type
(
0
),
infer_slice_shape
(
this
,
get_input_partial_shape
(
0
),
lower_bounds
->
get_vector
<
int64_t
>
(),
upper_bounds
->
get_vector
<
int64_t
>
(),
strides
->
get_vector
<
int64_t
>
(),
m_lower_bounds_mask
,
m_upper_bounds_mask
,
m_new_axis
,
m_shrink_axis
,
m_ellipsis_mask
));
}
else
{
...
...
src/ngraph/validation_util.cpp
View file @
b9a599a1
...
...
@@ -614,3 +614,165 @@ void ngraph::infer_auto_padding(const Shape& image_shape,
padding_above
.
push_back
(
pad_type
==
op
::
PadType
::
SAME_UPPER
?
padding_rhs
:
padding_lhs
);
}
}
PartialShape
ngraph
::
infer_slice_shape
(
const
Node
*
node
,
const
PartialShape
&
input_shape
,
const
std
::
vector
<
int64_t
>&
lb
,
const
std
::
vector
<
int64_t
>&
ub
,
const
std
::
vector
<
int64_t
>&
str
,
const
AxisSet
&
lb_mask
,
const
AxisSet
&
ub_mask
,
const
AxisSet
&
new_axis
,
const
AxisSet
&
shrink_axis
,
const
AxisSet
&
ellipsis_mask
)
{
// TODO(amprocte): double-check that these checks are needed.
if
(
lb
.
size
()
&&
ub
.
size
())
{
NODE_VALIDATION_CHECK
(
node
,
lb
.
size
()
==
ub
.
size
(),
"Lower bounds and Upper bounds needs to have same number of values"
);
}
if
(
lb
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
node
,
lb
.
size
()
==
str
.
size
(),
"Lower bounds and strides needs to have same number of values"
);
}
if
(
ub
.
size
()
&&
str
.
size
())
{
NODE_VALIDATION_CHECK
(
node
,
ub
.
size
()
==
str
.
size
(),
"Upper bounds and strides needs to have same number of values"
);
}
if
(
input_shape
.
rank
().
is_dynamic
())
{
return
PartialShape
::
dynamic
();
}
int
max_dims
=
size_t
(
input_shape
.
rank
())
+
new_axis
.
size
();
int
bounds_size
=
lb
.
size
()
?
lb
.
size
()
:
(
ub
.
size
()
?
ub
.
size
()
:
(
str
.
size
()
?
str
.
size
()
:
0
));
int
ellipsis_pos1
=
ellipsis_mask
.
size
()
?
*
ellipsis_mask
.
begin
()
:
max_dims
;
int
ellipsis_pos2
=
max_dims
;
bounds_size
-=
ellipsis_pos1
;
if
(
bounds_size
>
0
&&
(
max_dims
-
bounds_size
)
>
ellipsis_pos1
)
{
ellipsis_pos2
=
max_dims
-
bounds_size
;
}
std
::
vector
<
Dimension
>
begin_dms
(
max_dims
,
0
);
std
::
vector
<
Dimension
>
end_dms
(
max_dims
,
-
1
);
std
::
vector
<
Dimension
>
stride_dms
(
max_dims
,
1
);
int
i
,
j
,
k
,
bj
,
ej
,
sj
;
std
::
vector
<
Dimension
>
out_dims
;
for
(
i
=
0
,
j
=
0
,
k
=
0
,
bj
=
0
,
ej
=
0
,
sj
=
0
;
i
<
max_dims
;
i
++
)
{
if
(
i
>=
ellipsis_pos1
&&
i
<
ellipsis_pos2
)
{
if
(
new_axis
.
find
(
i
)
==
new_axis
.
end
())
{
end_dms
[
i
]
=
end_dms
[
i
].
is_static
()
&&
int64_t
(
end_dms
[
i
])
>=
0
?
end_dms
[
i
]
:
input_shape
[
j
++
]
+
end_dms
[
i
];
}
else
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
out_dims
.
push_back
(
(
end_dms
[
i
].
is_dynamic
()
||
begin_dms
[
i
].
is_dynamic
()
||
stride_dms
[
i
].
is_dynamic
())
?
Dimension
::
dynamic
()
:
static_cast
<
int64_t
>
(
ceil
(
static_cast
<
float
>
(
abs
(
int64_t
(
end_dms
[
i
])
-
int64_t
(
begin_dms
[
i
]))
+
1
)
/
static_cast
<
float
>
(
abs
(
int64_t
(
stride_dms
[
i
]))))));
k
=
ellipsis_pos1
;
continue
;
}
stride_dms
[
i
]
=
(
str
.
size
()
>
sj
&&
str
[
sj
]
!=
0
)
?
str
[
sj
++
]
:
1
;
// Use lower_bounds if mask is not set
if
(
lb_mask
.
find
(
j
)
==
lb_mask
.
end
())
{
begin_dms
[
i
]
=
lb
.
size
()
>
bj
?
lb
[
bj
]
:
(
stride_dms
[
i
].
is_dynamic
()
?
Dimension
::
dynamic
()
:
(
int64_t
(
stride_dms
[
i
])
>
0
?
0
:
-
1
));
}
else
{
begin_dms
[
i
]
=
stride_dms
[
i
].
is_dynamic
()
?
Dimension
::
dynamic
()
:
(
int64_t
(
stride_dms
[
i
])
>
0
?
0
:
-
1
);
}
bj
++
;
begin_dms
[
i
]
=
(
begin_dms
[
i
].
is_static
()
&&
int64_t
(
begin_dms
[
i
])
>=
0
)
?
begin_dms
[
i
]
:
input_shape
[
j
]
+
begin_dms
[
i
];
// Clipping 'begin'
begin_dms
[
i
]
=
(
begin_dms
[
i
].
is_static
()
&&
int64_t
(
begin_dms
[
i
])
<
0
)
?
0
:
(
begin_dms
[
i
].
is_static
()
&&
input_shape
[
j
].
is_static
()
&&
int64_t
(
begin_dms
[
i
])
>=
int64_t
(
input_shape
[
j
])
?
input_shape
[
j
]
-
1
:
begin_dms
[
i
]);
// Use upper_bounds if mask is not set
if
(
ub_mask
.
find
(
j
)
==
ub_mask
.
end
())
{
Dimension
end_dms_tmp
=
ub
.
size
()
>
ej
?
(
stride_dms
[
i
].
is_static
()
&&
int64_t
(
stride_dms
[
i
])
>
0
?
ub
[
ej
]
-
1
:
ub
[
ej
]
+
1
)
:
end_dms
[
i
];
end_dms
[
i
]
=
ub
.
size
()
>
ej
?
end_dms_tmp
:
(
stride_dms
[
i
].
is_static
()
&&
int64_t
(
stride_dms
[
i
])
>
0
?
-
1
:
0
);
}
else
{
end_dms
[
i
]
=
stride_dms
[
i
].
is_static
()
&&
int64_t
(
stride_dms
[
i
])
>
0
?
-
1
:
0
;
}
ej
++
;
end_dms
[
i
]
=
end_dms
[
i
].
is_static
()
&&
int64_t
(
end_dms
[
i
])
>=
0
?
end_dms
[
i
]
:
input_shape
[
j
]
+
end_dms
[
i
];
// Clipping 'end'
end_dms
[
i
]
=
(
end_dms
[
i
].
is_static
()
&&
int64_t
(
end_dms
[
i
])
<
0
)
?
0
:
(
end_dms
[
i
].
is_static
()
&&
input_shape
[
j
].
is_static
()
&&
int64_t
(
end_dms
[
i
])
>=
int64_t
(
input_shape
[
j
])
?
input_shape
[
j
]
-
1
:
end_dms
[
i
]);
if
(
new_axis
.
find
(
i
)
==
new_axis
.
end
())
{
j
++
;
}
else
{
end_dms
[
i
]
=
0
;
}
if
(
shrink_axis
.
find
(
k
)
!=
shrink_axis
.
end
())
{
end_dms
[
i
]
=
begin_dms
[
i
];
}
else
{
out_dims
.
push_back
(
end_dms
[
i
].
is_dynamic
()
||
begin_dms
[
i
].
is_dynamic
()
||
stride_dms
[
i
].
is_dynamic
()
?
Dimension
::
dynamic
()
:
static_cast
<
int64_t
>
(
ceil
(
static_cast
<
float
>
(
abs
(
int64_t
(
end_dms
[
i
])
-
int64_t
(
begin_dms
[
i
]))
+
1
)
/
static_cast
<
float
>
(
abs
(
int64_t
(
stride_dms
[
i
]))))));
}
k
++
;
}
return
out_dims
;
}
src/ngraph/validation_util.hpp
View file @
b9a599a1
...
...
@@ -92,4 +92,15 @@ namespace ngraph
const
op
::
PadType
pad_type
,
CoordinateDiff
&
padding_above
,
CoordinateDiff
&
padding_below
);
PartialShape
infer_slice_shape
(
const
Node
*
node
,
const
PartialShape
&
input_shape
,
const
std
::
vector
<
int64_t
>&
lb
,
const
std
::
vector
<
int64_t
>&
ub
,
const
std
::
vector
<
int64_t
>&
str
,
const
AxisSet
&
lb_mask
,
const
AxisSet
&
ub_mask
,
const
AxisSet
&
new_axis
,
const
AxisSet
&
shrink_mask
,
const
AxisSet
&
ellipsis_mask
);
}
test/type_prop.cpp
View file @
b9a599a1
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment