Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
778b6004
Commit
778b6004
authored
Jun 19, 2019
by
Jayaram Bobba
Committed by
Scott Cyphers
Jun 19, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Move GroupConv and Slice op to output-handle based constructors (#3083)
parent
13c05a47
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
66 additions
and
24 deletions
+66
-24
group_conv.cpp
src/ngraph/op/fused/group_conv.cpp
+20
-15
group_conv.hpp
src/ngraph/op/fused/group_conv.hpp
+6
-2
slice.cpp
src/ngraph/op/slice.cpp
+10
-4
slice.hpp
src/ngraph/op/slice.hpp
+7
-3
fused_op.cpp
src/ngraph/op/util/fused_op.cpp
+5
-0
fused_op.hpp
src/ngraph/op/util/fused_op.hpp
+2
-0
build_graph.cpp
test/build_graph.cpp
+16
-0
No files found.
src/ngraph/op/fused/group_conv.cpp
View file @
778b6004
...
@@ -27,8 +27,14 @@
...
@@ -27,8 +27,14 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
GroupConvolution
::
GroupConvolution
(
const
shared_ptr
<
Node
>&
data_batch
,
const
string
op
::
GroupConvolution
::
type_name
{
"GroupConvolution"
};
const
shared_ptr
<
Node
>&
filters
,
op
::
GroupConvolution
::
GroupConvolution
()
{
}
op
::
GroupConvolution
::
GroupConvolution
(
const
Output
<
Node
>&
data_batch
,
const
Output
<
Node
>&
filters
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_dilation_strides
,
const
Strides
&
window_dilation_strides
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_below
,
...
@@ -36,7 +42,7 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
...
@@ -36,7 +42,7 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
const
Strides
&
data_dilation_strides
,
const
Strides
&
data_dilation_strides
,
const
size_t
groups
,
const
size_t
groups
,
const
PadType
&
pad_type
)
const
PadType
&
pad_type
)
:
FusedOp
(
"GroupConvolution"
,
check_single_output_args
({
data_batch
,
filters
})
)
:
FusedOp
(
{
data_batch
,
filters
}
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_window_movement_strides
(
window_movement_strides
)
,
m_window_dilation_strides
(
window_dilation_strides
)
,
m_window_dilation_strides
(
window_dilation_strides
)
,
m_padding_below
(
padding_below
)
,
m_padding_below
(
padding_below
)
...
@@ -45,7 +51,6 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
...
@@ -45,7 +51,6 @@ op::GroupConvolution::GroupConvolution(const shared_ptr<Node>& data_batch,
,
m_groups
(
groups
)
,
m_groups
(
groups
)
,
m_pad_type
(
pad_type
)
,
m_pad_type
(
pad_type
)
{
{
// TODO: Move this out of constructor to validate_and_infer_types()
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
...
@@ -129,35 +134,35 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_
...
@@ -129,35 +134,35 @@ shared_ptr<Node> op::GroupConvolution::copy_with_new_args(const NodeVector& new_
NodeVector
op
::
GroupConvolution
::
decompose_op
()
const
NodeVector
op
::
GroupConvolution
::
decompose_op
()
const
{
{
auto
data
=
get_argumen
t
(
0
);
auto
data
=
inpu
t
(
0
);
auto
filters
=
get_argumen
t
(
1
);
auto
filters
=
inpu
t
(
1
);
// Split one convolution op to N ops where N is the number of groups
// Split one convolution op to N ops where N is the number of groups
// and concat results after computation.
// and concat results after computation.
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
// reference: https://github.com/NervanaSystems/ngraph-mxnet/blob/fdd692/src/ngraph/ngraph_emitter.cc#L822-L856
std
::
size_t
n_data_channels
{
data
->
get_shape
().
at
(
1
)};
std
::
size_t
n_data_channels
{
data
.
get_shape
().
at
(
1
)};
std
::
size_t
n_filters_channels
{
filters
->
get_shape
().
at
(
0
)};
std
::
size_t
n_filters_channels
{
filters
.
get_shape
().
at
(
0
)};
std
::
size_t
data_group_size
{
n_data_channels
/
m_groups
};
std
::
size_t
data_group_size
{
n_data_channels
/
m_groups
};
std
::
size_t
filters_group_size
{
n_filters_channels
/
m_groups
};
std
::
size_t
filters_group_size
{
n_filters_channels
/
m_groups
};
NodeVector
convolution_nodes
;
NodeVector
convolution_nodes
;
// initial bounds for splice
// initial bounds for splice
std
::
vector
<
std
::
size_t
>
data_lower_bounds
(
data
->
get_shape
().
size
());
std
::
vector
<
std
::
size_t
>
data_lower_bounds
(
data
.
get_shape
().
size
());
std
::
vector
<
std
::
size_t
>
data_upper_bounds
{
data
->
get_shape
()};
std
::
vector
<
std
::
size_t
>
data_upper_bounds
{
data
.
get_shape
()};
std
::
vector
<
std
::
size_t
>
filters_lower_bounds
(
filters
->
get_shape
().
size
());
std
::
vector
<
std
::
size_t
>
filters_lower_bounds
(
filters
.
get_shape
().
size
());
std
::
vector
<
std
::
size_t
>
filters_upper_bounds
{
filters
->
get_shape
()};
std
::
vector
<
std
::
size_t
>
filters_upper_bounds
{
filters
.
get_shape
()};
for
(
std
::
size_t
group
{
0
};
group
<
m_groups
;
++
group
)
for
(
std
::
size_t
group
{
0
};
group
<
m_groups
;
++
group
)
{
{
// slice data
// slice data
data_lower_bounds
[
1
]
=
group
*
data_group_size
;
data_lower_bounds
[
1
]
=
group
*
data_group_size
;
data_upper_bounds
[
1
]
=
(
group
+
1
)
*
data_group_size
;
data_upper_bounds
[
1
]
=
(
group
+
1
)
*
data_group_size
;
auto
sliced_data
=
auto
sliced_data
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
data
,
data_lower_bounds
,
data_upper_bounds
);
data
.
get_source_output
()
,
data_lower_bounds
,
data_upper_bounds
);
// slice filters
// slice filters
filters_lower_bounds
[
0
]
=
group
*
filters_group_size
;
filters_lower_bounds
[
0
]
=
group
*
filters_group_size
;
filters_upper_bounds
[
0
]
=
(
group
+
1
)
*
filters_group_size
;
filters_upper_bounds
[
0
]
=
(
group
+
1
)
*
filters_group_size
;
auto
sliced_filters
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
auto
sliced_filters
=
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
filters
,
filters_lower_bounds
,
filters_upper_bounds
);
filters
.
get_source_output
()
,
filters_lower_bounds
,
filters_upper_bounds
);
convolution_nodes
.
push_back
(
convolution_nodes
.
push_back
(
std
::
make_shared
<
ngraph
::
op
::
Convolution
>
(
sliced_data
,
std
::
make_shared
<
ngraph
::
op
::
Convolution
>
(
sliced_data
,
...
...
src/ngraph/op/fused/group_conv.hpp
View file @
778b6004
...
@@ -29,8 +29,12 @@ namespace ngraph
...
@@ -29,8 +29,12 @@ namespace ngraph
class
GroupConvolution
:
public
ngraph
::
op
::
util
::
FusedOp
class
GroupConvolution
:
public
ngraph
::
op
::
util
::
FusedOp
{
{
public
:
public
:
GroupConvolution
(
const
std
::
shared_ptr
<
Node
>&
data_batch
,
NGRAPH_API
const
std
::
shared_ptr
<
Node
>&
filters
,
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
GroupConvolution
();
GroupConvolution
(
const
Output
<
Node
>&
data_batch
,
const
Output
<
Node
>&
filters
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_movement_strides
,
const
Strides
&
window_dilation_strides
,
const
Strides
&
window_dilation_strides
,
const
CoordinateDiff
&
padding_below
,
const
CoordinateDiff
&
padding_below
,
...
...
src/ngraph/op/slice.cpp
View file @
778b6004
...
@@ -19,11 +19,17 @@
...
@@ -19,11 +19,17 @@
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
op
::
Slice
::
Slice
(
const
shared_ptr
<
Node
>&
arg
,
const
string
op
::
Slice
::
type_name
{
"Slice"
};
op
::
Slice
::
Slice
()
{
}
op
::
Slice
::
Slice
(
const
Output
<
Node
>&
arg
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
upper_bounds
,
const
Coordinate
&
upper_bounds
,
const
Strides
&
strides
)
const
Strides
&
strides
)
:
Op
(
"Slice"
,
check_single_output_args
({
arg
})
)
:
Op
(
{
arg
}
)
,
m_lower_bounds
(
lower_bounds
)
,
m_lower_bounds
(
lower_bounds
)
,
m_upper_bounds
(
upper_bounds
)
,
m_upper_bounds
(
upper_bounds
)
,
m_strides
(
strides
)
,
m_strides
(
strides
)
...
@@ -31,10 +37,10 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
...
@@ -31,10 +37,10 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
op
::
Slice
::
Slice
(
const
shared_ptr
<
Node
>&
arg
,
op
::
Slice
::
Slice
(
const
Output
<
Node
>&
arg
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
upper_bounds
)
const
Coordinate
&
upper_bounds
)
:
Op
(
"Slice"
,
check_single_output_args
({
arg
})
)
:
Op
(
{
arg
}
)
,
m_lower_bounds
(
lower_bounds
)
,
m_lower_bounds
(
lower_bounds
)
,
m_upper_bounds
(
upper_bounds
)
,
m_upper_bounds
(
upper_bounds
)
,
m_strides
(
Strides
())
,
m_strides
(
Strides
())
...
...
src/ngraph/op/slice.hpp
View file @
778b6004
...
@@ -28,6 +28,11 @@ namespace ngraph
...
@@ -28,6 +28,11 @@ namespace ngraph
class
Slice
:
public
Op
class
Slice
:
public
Op
{
{
public
:
public
:
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a tensor slice operation
Slice
();
/// \brief Constructs a tensor slice operation.
/// \brief Constructs a tensor slice operation.
///
///
/// \param arg The tensor to be sliced.
/// \param arg The tensor to be sliced.
...
@@ -35,17 +40,16 @@ namespace ngraph
...
@@ -35,17 +40,16 @@ namespace ngraph
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix.
/// every nth row and every mth column of the input matrix.
Slice
(
const
std
::
shared_ptr
<
Node
>&
arg
,
Slice
(
const
Output
<
Node
>&
arg
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
upper_bounds
,
const
Coordinate
&
upper_bounds
,
const
Strides
&
strides
);
const
Strides
&
strides
);
/// \brief Constructs a tensor slice operation with unit strides; i.e., every element inside the bounding box will be copied to the output slice.
/// \brief Constructs a tensor slice operation with unit strides; i.e., every element inside the bounding box will be copied to the output slice.
///
///
/// \param arg The tensor to be sliced.
/// \param arg The tensor to be sliced.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
Slice
(
const
std
::
shared_ptr
<
Node
>&
arg
,
Slice
(
const
Output
<
Node
>&
arg
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
lower_bounds
,
const
Coordinate
&
upper_bounds
);
const
Coordinate
&
upper_bounds
);
...
...
src/ngraph/op/util/fused_op.cpp
View file @
778b6004
...
@@ -30,6 +30,11 @@ op::util::FusedOp::FusedOp(const NodeVector& args)
...
@@ -30,6 +30,11 @@ op::util::FusedOp::FusedOp(const NodeVector& args)
{
{
}
}
op
::
util
::
FusedOp
::
FusedOp
(
const
OutputVector
&
args
)
:
Op
(
args
)
{
}
op
::
util
::
FusedOp
::
FusedOp
(
const
std
::
string
&
node_type
,
const
NodeVector
&
args
)
op
::
util
::
FusedOp
::
FusedOp
(
const
std
::
string
&
node_type
,
const
NodeVector
&
args
)
:
Op
(
node_type
,
args
)
:
Op
(
node_type
,
args
)
{
{
...
...
src/ngraph/op/util/fused_op.hpp
View file @
778b6004
...
@@ -51,6 +51,8 @@ namespace ngraph
...
@@ -51,6 +51,8 @@ namespace ngraph
/// \param args Nodes that produce the input tensors for the fused op
/// \param args Nodes that produce the input tensors for the fused op
FusedOp
(
const
NodeVector
&
args
);
FusedOp
(
const
NodeVector
&
args
);
FusedOp
(
const
OutputVector
&
args
);
/// \brief Constructs a FusedOp
/// \brief Constructs a FusedOp
///
///
/// \param args Nodes that produce the input tensors for the fused op
/// \param args Nodes that produce the input tensors for the fused op
...
...
test/build_graph.cpp
View file @
778b6004
...
@@ -150,3 +150,19 @@ TEST(build_graph, no_arg_construction)
...
@@ -150,3 +150,19 @@ TEST(build_graph, no_arg_construction)
validate_nodes_and_infer_types
(
ops
);
validate_nodes_and_infer_types
(
ops
);
ASSERT_EQ
(
add1
->
get_output_shape
(
0
),
Shape
{
7
});
ASSERT_EQ
(
add1
->
get_output_shape
(
0
),
Shape
{
7
});
}
}
TEST
(
build_graph
,
multi_output_split
)
{
const
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
64
,
8
,
100
,
150
});
auto
filters
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
128
,
2
,
10
,
20
});
const
auto
split
=
make_shared
<
op
::
Split
>
(
data
,
1
,
2
);
auto
conv
=
make_shared
<
op
::
GroupConvolution
>
(
split
->
output
(
1
),
filters
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
},
2
);
EXPECT_EQ
(
conv
->
get_shape
(),
(
Shape
{
64
,
128
,
91
,
131
}));
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment