Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
c99d65a0
Commit
c99d65a0
authored
Jun 12, 2019
by
Adam Procter
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add wip files
parent
b9a599a1
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
226 additions
and
0 deletions
+226
-0
dyn_replace_slice.cpp
src/ngraph/op/experimental/dyn_replace_slice.cpp
+148
-0
dyn_replace_slice.hpp
src/ngraph/op/experimental/dyn_replace_slice.hpp
+78
-0
No files found.
src/ngraph/op/experimental/dyn_replace_slice.cpp
0 → 100644
View file @
c99d65a0
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include <memory>
using
namespace
std
;
using
namespace
ngraph
;
op
::
DynReplaceSlice
::
DynReplaceSlice
(
const
shared_ptr
<
Node
>&
arg
,
const
shared_ptr
<
Node
>&
replacement
,
const
shared_ptr
<
Node
>&
lower_bounds
,
const
shared_ptr
<
Node
>&
upper_bounds
,
const
shared_ptr
<
Node
>&
strides
,
const
AxisSet
&
lower_bounds_mask
,
const
AxisSet
&
upper_bounds_mask
,
const
AxisSet
&
new_axis
,
const
AxisSet
&
shrink_axis
,
const
AxisSet
&
ellipsis_mask
)
:
Op
(
"DynReplaceSlice"
,
check_single_output_args
({
arg
,
replacement
,
lower_bounds
,
upper_bounds
,
strides
}))
,
m_lower_bounds_mask
(
lower_bounds_mask
)
,
m_upper_bounds_mask
(
upper_bounds_mask
)
,
m_new_axis
(
new_axis
)
,
m_shrink_axis
(
shrink_axis
)
,
m_ellipsis_mask
(
ellipsis_mask
)
{
constructor_validate_and_infer_types
();
}
void
op
::
DynReplaceSlice
::
validate_and_infer_types
()
{
auto
arg_et
=
get_input_element_type
(
0
);
auto
replacement_et
=
get_input_element_type
(
1
);
auto
lower_bounds_et
=
get_input_element_type
(
2
);
auto
upper_bounds_et
=
get_input_element_type
(
3
);
auto
strides_et
=
get_input_element_type
(
4
);
element
::
Type
result_et
;
// check data types
NODE_VALIDATION_CHECK
(
this
,
element
::
Type
::
merge
(
result_et
,
arg_et
,
replacement_et
),
"Argument element type is not compatible with replacement element type"
);
NODE_VALIDATION_CHECK
(
this
,
lower_bounds_et
.
compatible
(
element
::
Type_t
::
i64
),
"Lower bounds must have element type i64."
);
NODE_VALIDATION_CHECK
(
this
,
upper_bounds_et
.
compatible
(
element
::
Type_t
::
i64
),
"Upper bounds must have element type i64."
);
NODE_VALIDATION_CHECK
(
this
,
strides_et
.
compatible
(
element
::
Type_t
::
i64
),
"Strides must have element type i64"
);
// check shapes
auto
arg_shape
=
get_input_partial_shape
(
0
);
auto
replacement_shape
=
get_input_partial_shape
(
1
);
auto
lower_bounds_shape
=
get_input_partial_shape
(
2
);
auto
upper_bounds_shape
=
get_input_partial_shape
(
3
);
auto
strides_shape
=
get_input_partial_shape
(
4
);
NODE_VALIDATION_CHECK
(
this
,
lower_bounds_shape
.
rank
().
compatible
(
1
),
"Lower bounds shape must have rank 1, got "
,
lower_bounds_shape
.
rank
(),
"."
);
NODE_VALIDATION_CHECK
(
this
,
upper_bounds_shape
.
rank
().
compatible
(
1
),
"Upper bounds shape must have rank 1, got "
,
upper_bounds_shape
.
rank
(),
"."
);
NODE_VALIDATION_CHECK
(
this
,
strides_shape
.
rank
().
compatible
(
1
),
"Strides shape must have rank 1, got "
,
strides_shape
.
rank
(),
"."
);
set_input_is_relevant_to_shape
(
2
);
set_input_is_relevant_to_shape
(
3
);
set_input_is_relevant_to_shape
(
4
);
auto
lower_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
2
));
auto
upper_bounds
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
3
));
auto
strides
=
dynamic_pointer_cast
<
op
::
Constant
>
(
get_argument
(
4
));
if
(
lower_bounds
&&
upper_bounds
&&
strides
)
{
auto
inferred_slice_shape
=
infer_slice_shape
(
this
,
get_input_partial_shape
(
0
),
lower_bounds
->
get_vector
<
int64_t
>
(),
upper_bounds
->
get_vector
<
int64_t
>
(),
strides
->
get_vector
<
int64_t
>
(),
m_lower_bounds_mask
,
m_upper_bounds_mask
,
m_new_axis
,
m_shrink_axis
,
m_ellipsis_mask
);
NODE_VALIDATION_CHECK
(
this
,
replacement_shape
.
compatible
(
inferred_slice_shape
),
"Shape of the replacement is not compatible with the shape of the "
"slice (shape of slice: "
,
inferred_slice_shape
,
")"
);
}
PartialShape
output_shape
=
arg_shape
;
NODE_VALIDATION_CHECK
(
this
,
PartialShape
::
merge_into
(
output_shape
,
PartialShape
::
dynamic
(
replacement_shape
.
rank
())),
"Rank of the replacement is not compatible with rank of the argument tensor"
);
set_output_type
(
0
,
get_input_element_type
(
0
),
output_shape
);
}
shared_ptr
<
Node
>
op
::
DynReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
DynReplaceSlice
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
),
new_args
.
at
(
4
),
m_lower_bounds_mask
,
m_upper_bounds_mask
,
m_new_axis
,
m_shrink_axis
,
m_ellipsis_mask
);
}
void
op
::
DynReplaceSlice
::
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
{
throw
ngraph_error
(
"generate_adjoints not implemented for DynReplaceSlice"
);
}
src/ngraph/op/experimental/dyn_replace_slice.hpp
0 → 100644
View file @
c99d65a0
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace
ngraph
{
namespace
op
{
/// \brief Takes a slice of an input tensor, i.e., the sub-tensor that resides within a bounding box, optionally with stride.
class
DynReplaceSlice
:
public
Op
{
public
:
/// \brief Constructs a dynamic tensor replace-slice operation.
///
/// \param arg The tensor in which to replace the slice.
/// \param replacement Data to copy to the slice for replacement.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix.
/// \param lower_bounds_mask Ignores lower_bounds for axis with the mask set
/// \param upper_bounds_mask Ignores upper_bounds for axis with the mask set
/// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position
DynReplaceSlice
(
const
std
::
shared_ptr
<
Node
>&
arg
,
const
std
::
shared_ptr
<
Node
>&
replacement
,
const
std
::
shared_ptr
<
Node
>&
lower_bounds
,
const
std
::
shared_ptr
<
Node
>&
upper_bounds
,
const
std
::
shared_ptr
<
Node
>&
strides
,
const
AxisSet
&
lower_bounds_mask
=
AxisSet
{},
const
AxisSet
&
upper_bounds_mask
=
AxisSet
{},
const
AxisSet
&
new_axis
=
AxisSet
{},
const
AxisSet
&
shrink_axis
=
AxisSet
{},
const
AxisSet
&
ellipsis_mask
=
AxisSet
{});
const
AxisSet
&
get_lower_bounds_mask
()
const
{
return
m_lower_bounds_mask
;
}
const
AxisSet
&
get_upper_bounds_mask
()
const
{
return
m_upper_bounds_mask
;
}
const
AxisSet
&
get_new_axis
()
const
{
return
m_new_axis
;
}
const
AxisSet
&
get_shrink_axis
()
const
{
return
m_shrink_axis
;
}
const
AxisSet
&
get_ellipsis_mask
()
const
{
return
m_ellipsis_mask
;
}
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
void
validate_and_infer_types
()
override
;
private
:
/// Helper method to compute output shape
Shape
compute_output_shape
()
const
;
AxisSet
m_lower_bounds_mask
;
AxisSet
m_upper_bounds_mask
;
AxisSet
m_new_axis
;
AxisSet
m_shrink_axis
;
AxisSet
m_ellipsis_mask
;
};
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment