Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
72bf9831
Unverified
Commit
72bf9831
authored
May 16, 2019
by
Michał Karzyński
Committed by
GitHub
May 16, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Fused] Unsqueeze op (#2916)
parent
b2ca3e79
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
179 additions
and
36 deletions
+179
-36
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-0
unsqueeze.cpp
src/ngraph/frontend/onnx_import/op/unsqueeze.cpp
+8
-35
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
unsqueeze.cpp
src/ngraph/op/fused/unsqueeze.cpp
+86
-0
unsqueeze.hpp
src/ngraph/op/fused/unsqueeze.hpp
+43
-0
fused_op_tbl.hpp
src/ngraph/op/fused_op_tbl.hpp
+1
-0
intelgpu_backend.cpp
src/ngraph/runtime/intelgpu/intelgpu_backend.cpp
+4
-1
serializer.cpp
src/ngraph/serializer.cpp
+8
-0
backend_fused_op.in.cpp
test/backend_fused_op.in.cpp
+15
-0
type_prop.cpp
test/type_prop.cpp
+11
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
72bf9831
...
...
@@ -312,6 +312,8 @@ set (SRC
op/fused/space_to_depth.hpp
op/fused/squeeze.cpp
op/fused/squeeze.hpp
op/fused/unsqueeze.cpp
op/fused/unsqueeze.hpp
op/util/arithmetic_reduction.cpp
op/util/arithmetic_reduction.hpp
op/util/binary_elementwise_arithmetic.cpp
...
...
src/ngraph/frontend/onnx_import/op/unsqueeze.cpp
View file @
72bf9831
...
...
@@ -14,16 +14,9 @@
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include <set>
#include <vector>
#include "exceptions.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/reshape.hpp"
#include "ngraph/util.hpp"
#include "unsqueeze.hpp"
#include "utils/reshape.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/constant.hpp"
#include "squeeze.hpp"
namespace
ngraph
{
...
...
@@ -35,31 +28,11 @@ namespace ngraph
{
NodeVector
unsqueeze
(
const
Node
&
node
)
{
NodeVector
inputs
{
node
.
get_ng_inputs
()};
auto
data
=
inputs
.
at
(
0
);
auto
data_shape
=
data
->
get_shape
();
auto
axes
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
int64_t
>>
(
"axes"
);
ASSERT_VALID_ARGUMENT
(
node
,
!
axes
.
empty
())
<<
"'axes' attribute is mandatory."
;
ASSERT_VALID_ARGUMENT
(
node
,
axes
.
size
()
==
std
::
set
<
std
::
int64_t
>
(
std
::
begin
(
axes
),
std
::
end
(
axes
)).
size
())
<<
"'axes' has a duplicate axis."
;
std
::
sort
(
std
::
begin
(
axes
),
std
::
end
(
axes
),
std
::
less
<
int64_t
>
());
AxisVector
input_order
{
ngraph
::
get_default_order
(
data_shape
.
size
())};
for
(
auto
axis
:
axes
)
{
ASSERT_VALID_ARGUMENT
(
node
,
axis
>=
0
&&
axis
<=
data_shape
.
size
())
<<
"provided 'axes' attribute is not valid."
;
data_shape
.
insert
(
std
::
next
(
std
::
begin
(
data_shape
),
axis
),
1
);
}
return
{
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
data
,
input_order
,
data_shape
)};
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
axes
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
int64_t
>>
(
"axes"
,
{});
auto
axes_node
=
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
element
::
i64
,
Shape
{
axes
.
size
()},
axes
);
return
{
std
::
make_shared
<
ngraph
::
op
::
Unsqueeze
>
(
data
,
axes_node
)};
}
}
// namespace set_1
...
...
src/ngraph/ngraph.hpp
View file @
72bf9831
...
...
@@ -109,6 +109,7 @@
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
...
...
src/ngraph/op/fused/unsqueeze.cpp
0 → 100644
View file @
72bf9831
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstddef>
#include <functional>
#include <iterator>
#include <set>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
using
namespace
std
;
using
namespace
ngraph
;
op
::
Unsqueeze
::
Unsqueeze
(
const
shared_ptr
<
Node
>&
data
,
const
shared_ptr
<
Node
>&
axes
)
:
FusedOp
(
"Unsqueeze"
,
{
data
,
axes
})
{
constructor_validate_and_infer_types
();
}
void
op
::
Unsqueeze
::
pre_validate_and_infer_types
()
{
auto
axes_node
=
get_argument
(
1
);
// Currently only support Constant node for axes.
NODE_VALIDATION_CHECK
(
this
,
axes_node
->
is_constant
(),
"doesn't support 'axes' input of other type than a Constant."
);
}
NodeVector
op
::
Unsqueeze
::
decompose_op
()
const
{
auto
data
=
get_argument
(
0
);
auto
axes_node
=
get_argument
(
1
);
// Get value of axes from Constant
auto
axes_constant
=
dynamic_pointer_cast
<
op
::
Constant
>
(
axes_node
);
auto
axes
=
axes_constant
->
get_vector
<
size_t
>
();
auto
data_shape
=
data
->
get_shape
();
NODE_VALIDATION_CHECK
(
this
,
!
axes
.
empty
(),
"'axes' input is mandatory."
);
NODE_VALIDATION_CHECK
(
this
,
axes
.
size
()
==
set
<
int64_t
>
(
begin
(
axes
),
end
(
axes
)).
size
(),
"'axes' input has a duplicate axis."
);
sort
(
begin
(
axes
),
end
(
axes
),
less
<
int64_t
>
());
AxisVector
input_order
{
ngraph
::
get_default_order
(
data_shape
.
size
())};
for
(
auto
axis
:
axes
)
{
NODE_VALIDATION_CHECK
(
this
,
axis
>=
0
&&
axis
<=
data_shape
.
size
(),
"provided 'axes' value "
,
axis
,
" is not valid."
);
data_shape
.
insert
(
next
(
begin
(
data_shape
),
axis
),
1
);
}
return
{
make_shared
<
ngraph
::
op
::
Reshape
>
(
data
,
input_order
,
data_shape
)};
}
shared_ptr
<
Node
>
op
::
Unsqueeze
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
if
(
new_args
.
size
()
!=
2
)
{
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
return
make_shared
<
Unsqueeze
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
));
}
src/ngraph/op/fused/unsqueeze.hpp
0 → 100644
View file @
72bf9831
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace
ngraph
{
namespace
op
{
class
Unsqueeze
:
public
ngraph
::
op
::
util
::
FusedOp
{
public
:
Unsqueeze
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
data
,
const
std
::
shared_ptr
<
ngraph
::
Node
>&
axes
);
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
};
}
}
src/ngraph/op/fused_op_tbl.hpp
View file @
72bf9831
...
...
@@ -33,3 +33,4 @@ NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP
(
ScaleShift
,
ngraph
::
op
)
NGRAPH_OP
(
SpaceToDepth
,
ngraph
::
op
)
NGRAPH_OP
(
Squeeze
,
ngraph
::
op
)
NGRAPH_OP
(
Unsqueeze
,
ngraph
::
op
)
src/ngraph/runtime/intelgpu/intelgpu_backend.cpp
View file @
72bf9831
...
...
@@ -90,6 +90,7 @@
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
...
...
@@ -2079,6 +2080,7 @@ shared_ptr<runtime::Executable>
case
OP_TYPEID
:
:
StopGradient
:
case
OP_TYPEID
:
:
Tile
:
case
OP_TYPEID
:
:
Transpose
:
case
OP_TYPEID
:
:
Unsqueeze
:
default
:
{
throw
unsupported_op
(
"Unsupported op '"
+
op
->
description
()
+
...
...
@@ -2171,7 +2173,8 @@ bool runtime::intelgpu::IntelGPUBackend::is_supported_impl(const Node& node)
case
OP_TYPEID
:
:
PRelu
:
case
OP_TYPEID
:
:
ScaleShift
:
case
OP_TYPEID
:
:
SpaceToDepth
:
case
OP_TYPEID
:
:
Squeeze
:
{
return
false
;
case
OP_TYPEID
:
:
Squeeze
:
case
OP_TYPEID
:
:
Unsqueeze
:
{
return
false
;
}
default
:
{
return
true
;
}
...
...
src/ngraph/serializer.cpp
View file @
72bf9831
...
...
@@ -80,6 +80,7 @@
#include "ngraph/op/fused/scale_shift.hpp"
#include "ngraph/op/fused/space_to_depth.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/op/get_output_element.hpp"
...
...
@@ -1501,6 +1502,11 @@ static shared_ptr<ngraph::Function>
node
=
make_shared
<
op
::
StopGradient
>
(
args
[
0
]);
break
;
}
case
OP_TYPEID
:
:
Unsqueeze
:
{
node
=
make_shared
<
op
::
Unsqueeze
>
(
args
[
0
],
args
[
1
]);
break
;
}
case
OP_TYPEID
:
:
UnknownOp
:
{
stringstream
ss
;
...
...
@@ -2227,6 +2233,8 @@ static json write(const Node& n, bool binary_constant_data)
}
case
OP_TYPEID
:
:
Transpose
:
{
break
;
}
case
OP_TYPEID
:
:
Unsqueeze
:
{
break
;
}
case
OP_TYPEID
:
:
UnknownOp
:
{
break
;
}
}
...
...
test/backend_fused_op.in.cpp
View file @
72bf9831
...
...
@@ -773,6 +773,21 @@ NGRAPH_TEST(${BACKEND_NAME}, grn_2d_with_bias)
test_case
.
run
();
}
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
unsqueeze
)
{
auto
data_node
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
4
,
2
});
auto
axes_node
=
make_shared
<
ngraph
::
op
::
Constant
>
(
element
::
u64
,
Shape
{
2
},
vector
<
int64_t
>
{
1
,
2
});
auto
squeeze
=
make_shared
<
op
::
Unsqueeze
>
(
data_node
,
axes_node
);
auto
function
=
make_shared
<
Function
>
(
NodeVector
{
squeeze
},
ParameterVector
{
data_node
});
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
function
,
"${BACKEND_NAME}"
);
auto
data
=
vector
<
float
>
{
1.0
f
,
2.0
f
,
3.0
f
,
4.0
f
,
5.0
f
,
6.0
f
,
7.0
f
,
8.0
f
};
test_case
.
add_input
(
data
);
test_case
.
add_expected_output
<
float
>
(
Shape
{
4
,
1
,
1
,
2
},
data
);
}
NGRAPH_TEST
(
$
{
BACKEND_NAME
},
scale_shift_no_broadcast
)
{
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f64
,
Shape
{
3
,
6
});
...
...
test/type_prop.cpp
View file @
72bf9831
...
...
@@ -14495,6 +14495,17 @@ TEST(type_prop, fused_clamp)
EXPECT_EQ
(
clamp
->
get_shape
(),
(
Shape
{
2
,
2
}));
}
TEST
(
type_prop
,
unsqueeze
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
4
,
1
,
4
,
1
,
8
});
auto
axes_node
=
make_shared
<
ngraph
::
op
::
Constant
>
(
element
::
u64
,
Shape
{
2
},
vector
<
int64_t
>
{
1
,
2
});
auto
squeeze
=
make_shared
<
op
::
Unsqueeze
>
(
param
,
axes_node
);
ASSERT_EQ
(
squeeze
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
squeeze
->
get_shape
(),
(
Shape
{
4
,
1
,
1
,
1
,
4
,
1
,
8
}));
}
TEST
(
type_prop
,
scale_shift_no_broadcast
)
{
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f64
,
Shape
{
3
,
6
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment