Commit 80022ea1 authored by Ivan Tikhonov's avatar Ivan Tikhonov Committed by Scott Cyphers

New DynSlice (Strided Slice) realization (#3662)

* Strided slice

* Strided slice

* default value for strides

* Added new strided slice test, enabled old tests, refactoring

* Refactoring

* Autogenerated file: dyn_replace_slice tests

* Renaming

* Fix codestyle

* Fix build on MacOS

* Fix codestyle

* Add several tests in unit_test.manifest to skip it on PlaidML

* Disable all dyn_replace_slice tests on PlaidML
parent 04f212b7
......@@ -281,3 +281,6 @@ random_uniform_all_static_seed_used
random_uniform_seed_use_dynamic
random_uniform_all_static_range_dynamic
random_uniform_dynamic_shapes
# shapes with zeros dimensions like (5, 0, 5) not supported in PlaidML backend
dyn_replace_slice
\ No newline at end of file
This diff is collapsed.
......@@ -94,12 +94,12 @@ namespace ngraph
PartialShape infer_slice_shape(const Node* node,
const PartialShape& input_shape,
const std::vector<int64_t>& lb,
const std::vector<int64_t>& ub,
const std::vector<int64_t>& str,
const AxisSet& lb_mask,
const AxisSet& ub_mask,
const AxisSet& new_axis,
const AxisSet& shrink_mask,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& end,
const std::vector<int64_t>& strides,
const AxisSet& begin_mask,
const AxisSet& end_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask);
}
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -618,9 +618,7 @@ def main():
t[3::-2] = None
t[4::-2] = None
t[5::-2] = None
# TODO(amprocte): Failing due to bug in DynReplaceSlice inference.
# Re-enable when NGCORE-510 is fixed.
#t[-9000:-8000:2] = None
t[-9000:-8000:2] = None
t[-9000:8000:2] = None
t[-5:5:2] = None
t[np.newaxis] = None
......@@ -639,9 +637,7 @@ def main():
t[0:100:2] = None
t[4:0:-2] = None
t[4:0:-3] = None
# TODO(amprocte): Failing due to bug in DynReplaceSlice inference.
# Re-enable when NGCORE-510 is fixed.
#t[3:2:1] = None
t[3:2:1] = None
t[4::-2] = None
#
......@@ -690,6 +686,35 @@ def main():
t[7:0:-3] = None
t[7::-3] = None
t.set_dtype('int32')
t.set_shape((4, 5))
t[2:4, ...] = None
t[4:2, ...] = None
t[4:2:-3, ...] = None
t[-100:100, ...] = None
t[..., 2:] = None
t[..., 2:4] = None
t[..., :] = None
t[..., -100:100] = None
t.set_shape((5, 6, 10, 8))
t[2:4, ..., 1:7:3, 7:2:-2] = None
t[..., 1:7:3, 7:2:-2] = None
t[2:4, ..., :3, -3:2:-2] = None
t[2:4, ..., 1:7:-3, 7:2:-2] = None
t[2:4, ..., :, np.newaxis, 0] = None
t.set_shape((2, 2, 3, 2, 3, 3))
t[2:6:2, ..., :, 2:1:-1] = None
t[np.newaxis, 1, ..., np.newaxis, 2:1:-1] = None
t[1, ..., np.newaxis, 2:1:-1] = None
t[np.newaxis, np.newaxis, 2:1:-1, ...] = None
t.set_shape((3, 3, 3, 2, 3))
t[6:1:-2, ..., 1:2, 2:1:-1] = None
t.set_shape((3, 3, 3, 2, 3))
t[..., 1:2, 2:1:-1] = None
t.set_dtype('int32')
t[80000] = None # error expected (shrink-axis OOB)
t[-80000] = None # error expected (shrink-axis OOB)
......
......@@ -513,7 +513,7 @@ INSTANTIATE_TEST_CASE_P(
/* Axis Masks: New, Shrink, Ellipsis */
DynReplaceSliceParams{{10}, {1}, {1}, {0}, {1, 10}, {0}, {10}, {}, {}, {}, {0}, {}, {}},
DynReplaceSliceParams{
{1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}, {0, 0}, {1, 2}, {}, {}, {}, {}, {}, {1}},
{1, 2, 3}, {2}, {2}, {0}, {1, 2, 3}, {0, 0}, {1, 2}, {}, {}, {}, {}, {}, {1}},
DynReplaceSliceParams{{1, 2, 3},
{4},
{4},
......@@ -530,7 +530,53 @@ INSTANTIATE_TEST_CASE_P(
DynReplaceSliceParams{
{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}, {0, 0, 1}, {2, 2, 2}, {}, {}, {}, {0}, {}, {1}},
DynReplaceSliceParams{
{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}, {-1}, {0}, {-2}, {1}, {1}, {}, {1}, {}},
{1, 2, 2, 2}, {1}, {1}, {1}, {0, 2, 2, 2}, {-1}, {0}, {-2}, {1}, {1}, {}, {1}, {}},
DynReplaceSliceParams{{9, 10, 12, 2, 3}, /*arg_shape*/
{4}, /*lower_bounds_shape*/
{4}, /*upper_bounds_shape*/
{4}, /*strides_shape*/
{2, 10, 12, 2, 0}, /*replacement_shape*/
{2, 0, 0, 3}, /*lower_bounds_val*/
{6, 0, 0, 2}, /*upper_bounds_val*/
{2, 1, 1, -1}, /*strides_val*/
{2}, /*lower_bounds_mask*/
{2}, /*upper_bounds_mask*/
{}, /*new_axis*/
{}, /*shrink_axis*/
{1}}, /*ellipsis_mask*/
DynReplaceSliceParams{{9, 10, 12, 2, 3}, /*arg_shape*/
{4}, /*lower_bounds_shape*/
{4}, /*upper_bounds_shape*/
{4}, /*strides_shape*/
{3, 10, 12, 1, 0}, /*replacement_shape*/
{6, 0, 1, 3}, /*lower_bounds_val*/
{1, 0, 2, 2}, /*upper_bounds_val*/
{-2, 1, 1, -1}, /*strides_val*/
{}, /*lower_bounds_mask*/
{}, /*upper_bounds_mask*/
{}, /*new_axis*/
{}, /*shrink_axis*/
{1}}, /*ellipsis_mask*/
DynReplaceSliceParams{{9, 10, 12, 2, 3}, /*arg_shape*/
{3}, /*lower_bounds_shape*/
{3}, /*upper_bounds_shape*/
{3}, /*strides_shape*/
{9, 10, 12, 1, 0}, /*replacement_shape*/
{0, 1, 3}, /*lower_bounds_val*/
{0, 2, 2}, /*upper_bounds_val*/
{1, 1, -1}, /*strides_val*/
{}, /*lower_bounds_mask*/
{}, /*upper_bounds_mask*/
{}, /*new_axis*/
{}, /*shrink_axis*/
{0}}, /*ellipsis_mask*/
DynReplaceSliceParams{{1, 2, 2, 2},
{4},
{4},
......
......@@ -179,9 +179,8 @@ INSTANTIATE_TEST_CASE_P(
type_prop,
DeduceDynSliceTest,
::testing::Values(
// TODO(jbobba): These tests should pass.
// DynSliceParams({{4}, {1}, {1}, {1}, {0}}, {{-9000}, {-8000}, {2}}, {{}, {}, {}, {}, {}}),
// DynSliceParams({{5}, {1}, {1}, {1}, {0}}, {{3}, {2}, {1}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{4}, {1}, {1}, {1}, {0}}, {{-9000}, {-8000}, {2}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{5}, {1}, {1}, {1}, {0}}, {{3}, {2}, {1}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{2, 3, 4, 5, 6}, {5}, {5}, {5}, {1, 2, 1, 1, 3}},
{{0, 1, 2, 3, 1}, {1, 3, 3, 5, 6}, {1, 1, 1, 2, 2}},
{{}, {}, {}, {}, {}}),
......@@ -205,7 +204,7 @@ INSTANTIATE_TEST_CASE_P(
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{-1}, {0}, {-2}}, {{}, {0}, {}, {}, {}}),
// Axis Masks: New, Shrink, Ellipsis
DynSliceParams({{10}, {1}, {1}, {0}, {1, 10}}, {{0}, {10}, {}}, {{}, {}, {0}, {}, {}}),
DynSliceParams({{1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}},
DynSliceParams({{1, 2, 3}, {2}, {2}, {0}, {1, 2, 3}},
{{0, 0}, {1, 2}, {}},
{{}, {}, {}, {}, {1}}),
DynSliceParams({{1, 2, 3}, {4}, {4}, {0}, {1, 2, 1}},
......@@ -214,7 +213,7 @@ INSTANTIATE_TEST_CASE_P(
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}},
{{0, 0, 1}, {2, 2, 2}, {}},
{{}, {}, {0}, {}, {1}}),
DynSliceParams({{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}},
DynSliceParams({{1, 2, 2, 2}, {1}, {1}, {1}, {0, 2, 2, 2}},
{{-1}, {0}, {-2}},
{{1}, {1}, {}, {1}, {}}),
DynSliceParams({{1, 2, 2, 2}, {4}, {4}, {0}, {1, 2, 2}},
......@@ -222,7 +221,56 @@ INSTANTIATE_TEST_CASE_P(
{{1}, {1}, {}, {1}, {}}),
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2}},
{{0, 0, 1}, {2, 2, 2}, {}},
{{}, {}, {0}, {2}, {1}})));
{{}, {}, {0}, {2}, {1}}),
DynSliceParams({{9, 10, 12, 2, 3}, /*arg_shape*/
{4}, /*lower_bounds_shape*/
{4}, /*upper_bounds_shape*/
{4}, /*strides_shape*/
{2, 10, 12, 2, 0}}, /*replacement_shape*/
{{2, 0, 0, 3}, /*lower_bounds_val*/
{6, 0, 0, 2}, /*upper_bounds_val*/
{2, 1, 1, -1}}, /*strides_val*/
{{2}, /*lower_bounds_mask*/
{2}, /*upper_bounds_mask*/
{}, /*new_axis*/
{}, /*shrink_axis*/
{1}}), /*ellipsis_mask*/
DynSliceParams({{9, 10, 12, 2, 3}, /*arg_shape*/
{4}, /*lower_bounds_shape*/
{4}, /*upper_bounds_shape*/
{4}, /*strides_shape*/
{3, 10, 12, 1, 0}}, /*replacement_shape*/
{{6, 0, 1, 3}, /*lower_bounds_val*/
{1, 0, 2, 2}, /*upper_bounds_val*/
{-2, 1, 1, -1}}, /*strides_val*/
{{}, /*lower_bounds_mask*/
{}, /*upper_bounds_mask*/
{}, /*new_axis*/
{}, /*shrink_axis*/
{1}}), /*ellipsis_mask*/
DynSliceParams({{9, 10, 12, 2, 3}, /*arg_shape*/
{3}, /*lower_bounds_shape*/
{3}, /*upper_bounds_shape*/
{3}, /*strides_shape*/
{9, 10, 12, 1, 0}}, /*replacement_shape*/
{{0, 1, 3}, /*lower_bounds_val*/
{0, 2, 2}, /*upper_bounds_val*/
{1, 1, -1}}, /*strides_val*/
{{}, /*lower_bounds_mask*/
{}, /*upper_bounds_mask*/
{}, /*new_axis*/
{}, /*shrink_axis*/
{0}}) /*ellipsis_mask*/
));
void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment