Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8f102516
Unverified
Commit
8f102516
authored
Oct 19, 2018
by
Matthew Brookhart
Committed by
GitHub
Oct 19, 2018
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into ayzhuang/in-place-concat
parents
3feb4264
982889f5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
636 additions
and
111 deletions
+636
-111
broadcast.cpp
src/ngraph/op/broadcast.cpp
+12
-17
replace_slice.cpp
src/ngraph/op/replace_slice.cpp
+58
-43
replace_slice.hpp
src/ngraph/op/replace_slice.hpp
+4
-4
select.cpp
src/ngraph/op/select.cpp
+19
-16
select.hpp
src/ngraph/op/select.hpp
+1
-0
slice.cpp
src/ngraph/op/slice.cpp
+30
-30
CMakeLists.txt
src/ngraph/runtime/cpu/CMakeLists.txt
+15
-0
halide_op.cpp
src/ngraph/runtime/cpu/builder/halide_op.cpp
+129
-0
cpu_builder.cpp
src/ngraph/runtime/cpu/cpu_builder.cpp
+4
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+5
-0
cpu_external_function.hpp
src/ngraph/runtime/cpu/cpu_external_function.hpp
+31
-0
halide_op.cpp
src/ngraph/runtime/cpu/op/halide_op.cpp
+42
-0
halide_op.hpp
src/ngraph/runtime/cpu/op/halide_op.hpp
+54
-0
halide_subgraph_extraction.cpp
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cpp
+126
-0
halide_subgraph_extraction.hpp
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp
+39
-0
CMakeLists.txt
test/CMakeLists.txt
+3
-0
halide.cpp
test/halide.cpp
+64
-0
type_prop.cpp
test/type_prop.cpp
+0
-0
No files found.
src/ngraph/op/broadcast.cpp
View file @
8f102516
...
@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
...
@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void
op
::
Broadcast
::
validate_and_infer_types
()
void
op
::
Broadcast
::
validate_and_infer_types
()
{
{
if
(
validate_punt_if_dynamic
())
infer_shape
();
for
(
auto
axis
:
m_broadcast_axes
)
{
{
return
;
NODE_VALIDATION_ASSERT
(
this
,
axis
<
m_shape
.
size
())
<<
"Broadcast axis index ("
<<
axis
<<
") exceeds specified output shape rank "
<<
"(broadcast axes: "
<<
m_broadcast_axes
<<
", output shape: "
<<
m_shape
<<
")."
;
}
}
infer_shape
();
Shape
required_input_shape
=
m_shape
;
Shape
target_shape
=
m_shape
;
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
{
{
NODE_VALIDATION_ASSERT
(
this
,
*
i
<
target_shape
.
size
())
required_input_shape
.
erase
(
required_input_shape
.
begin
()
+
*
i
);
<<
"Broadcast axis index ("
<<
*
i
<<
") exceeds target shape rank "
<<
"(broadcast axes: "
<<
m_broadcast_axes
<<
", target shape: "
<<
target_shape
<<
")."
;
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
}
}
// TODO(amprocte): We can probably have a more helpful error message here.
// TODO(amprocte): We can probably have a more helpful error message here.
// There are two things that can go wrong, which are being picked up in
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// one fell swoop by this check: either the number of broadcast axes is not
// enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
// or there is a mismatch with one of the pre-broadcast axis lengths
NODE_VALIDATION_ASSERT
(
this
,
get_input_partial_shape
(
0
).
compatible
(
required_input_shape
))
// (i.e. target_shape.size() == arg->get_shape.size() but there is some i
<<
"Broadcast argument shape, specified output shape, and axes are incompatible "
// where target_shape[i] != arg->get_shape[i]).
<<
"(argument shape: "
<<
get_input_partial_shape
(
0
)
<<
", output shape: "
<<
m_shape
NODE_VALIDATION_ASSERT
(
this
,
target_shape
==
get_input_shape
(
0
))
<<
"Broadcast argument shape, target shape, and axes are incompatible "
<<
"(argument shape: "
<<
get_input_shape
(
0
)
<<
", target shape: "
<<
m_shape
<<
", broadcast axes: "
<<
m_broadcast_axes
<<
")."
;
<<
", broadcast axes: "
<<
m_broadcast_axes
<<
")."
;
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
...
...
src/ngraph/op/replace_slice.cpp
View file @
8f102516
...
@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
...
@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
,
m_strides
(
strides
)
,
m_strides
(
strides
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
check_args
();
}
}
op
::
ReplaceSlice
::
ReplaceSlice
(
const
shared_ptr
<
Node
>&
arg0
,
op
::
ReplaceSlice
::
ReplaceSlice
(
const
shared_ptr
<
Node
>&
arg0
,
...
@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
...
@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
,
m_strides
(
Strides
(
lower_bounds
.
size
(),
1
))
,
m_strides
(
Strides
(
lower_bounds
.
size
(),
1
))
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
check_args
();
}
}
void
op
::
ReplaceSlice
::
check_arg
s
()
void
op
::
ReplaceSlice
::
validate_and_infer_type
s
()
{
{
auto
&
input_0
=
get_inputs
().
at
(
0
);
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
auto
&
input_0_shape
=
input_0
.
get_shape
();
// construct the default value.
auto
&
input_0_element_type
=
input_0
.
get_element_type
();
if
(
m_strides
.
size
()
==
0
)
{
auto
&
input_1
=
get_inputs
().
at
(
1
);
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
auto
&
input_1_shape
=
input_1
.
get_shape
();
}
auto
&
input_1_element_type
=
input_1
.
get_element_type
();
NODE_VALIDATION_ASSERT
(
this
,
input_0_shape
.
size
()
==
input_1_shape
.
size
())
const
PartialShape
&
arg0_shape
=
get_input_partial_shape
(
0
);
<<
"Argument ranks do not match (arg0 shape: "
<<
input_0_shape
const
PartialShape
&
arg1_shape
=
get_input_partial_shape
(
1
);
<<
", arg1 shape: "
<<
input_1_shape
<<
")."
;
Dimension
merged_args_rank
;
NODE_VALIDATION_ASSERT
(
this
,
input_0_element_type
==
input_1_element_type
)
NODE_VALIDATION_ASSERT
(
this
,
<<
"Argument element types do not match (arg0 element type: "
<<
input_0_element_type
Dimension
::
merge
(
merged_args_rank
,
arg0_shape
.
rank
(),
arg1_shape
.
rank
()))
<<
", arg1 element type: "
<<
input_1_element_type
<<
")."
;
<<
"Argument ranks do not match (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
.
size
()
==
input_0_shape
.
size
())
element
::
Type
arg0_et
=
get_input_element_type
(
0
);
<<
"Rank of lower bounds ("
<<
m_lower_bounds
.
size
()
<<
") does not match rank "
element
::
Type
arg1_et
=
get_input_element_type
(
1
);
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (lower bounds: "
<<
m_lower_bounds
element
::
Type
merged_args_et
;
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
.
size
()
==
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
merged_args_et
,
arg0_et
,
arg1_et
))
<<
"Rank of upper bounds ("
<<
m_upper_bounds
.
size
()
<<
") does not match rank "
<<
"Argument element types do not match (arg0 element type: "
<<
arg0_et
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (upper bounds: "
<<
m_upper_bounds
<<
", arg1 element type: "
<<
arg1_et
<<
")."
;
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
.
size
()
==
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
<<
"Rank of strides ("
<<
m_strides
.
size
()
<<
") does not match rank "
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (strides: "
<<
m_strides
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
Shape
slice_shape
;
size_t
output_rank
=
m_upper_bounds
.
size
()
;
for
(
size_t
i
=
0
;
i
<
input_0_shape
.
size
()
;
i
++
)
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
{
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
[
i
]
<=
input_0_shape
[
i
])
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
<<
" (strides: "
<<
m_strides
<<
")."
;
}
size_t
slice_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
NODE_VALIDATION_ASSERT
(
this
,
slice_axis_size
=
merged_args_rank
.
is_dynamic
()
||
size_t
(
merged_args_rank
)
==
output_rank
)
slice_axis_size
/
m_strides
[
i
]
+
((
slice_axis_size
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
<<
"Argument ranks do not match the rank of the lower bounds ("
<<
m_lower_bounds
slice_shape
.
push_back
(
slice_axis_size
);
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
std
::
vector
<
Dimension
>
sliced_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
arg0_shape
.
rank
().
is_dynamic
()
||
arg0_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
arg0_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
arg0_shape
<<
")."
;
size_t
sliced_dim
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
sliced_dim
=
sliced_dim
/
m_strides
[
i
]
+
((
sliced_dim
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
sliced_dims
[
i
]
=
sliced_dim
;
}
}
NODE_VALIDATION_ASSERT
(
this
,
input_1_shape
==
slice_shape
)
PartialShape
slice_shape
{
sliced_dims
};
<<
"Shape of replacement tensor ("
<<
input_1_shape
<<
") does not match the slice shape "
NODE_VALIDATION_ASSERT
(
this
,
arg1_shape
.
compatible
(
slice_shape
))
<<
"Shape of replacement tensor ("
<<
arg1_shape
<<
") does not match the slice shape "
<<
"("
<<
slice_shape
<<
")."
;
<<
"("
<<
slice_shape
<<
")."
;
set_output_type
(
0
,
input_0_element_type
,
input_0_shape
);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
PartialShape
result_shape
=
(
arg0_shape
.
rank
().
is_static
())
?
arg0_shape
:
PartialShape
(
std
::
vector
<
Dimension
>
(
output_rank
,
Dimension
::
dynamic
()));
set_output_type
(
0
,
merged_args_et
,
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
ReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/op/replace_slice.hpp
View file @
8f102516
...
@@ -88,11 +88,11 @@ namespace ngraph
...
@@ -88,11 +88,11 @@ namespace ngraph
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
void
check_args
()
;
void
validate_and_infer_types
()
override
;
const
Coordinate
m_lower_bounds
;
Coordinate
m_lower_bounds
;
const
Coordinate
m_upper_bounds
;
Coordinate
m_upper_bounds
;
const
Strides
m_strides
;
Strides
m_strides
;
};
};
}
}
}
}
src/ngraph/op/select.cpp
View file @
8f102516
...
@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
...
@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
:
Op
(
"Select"
,
check_single_output_args
({
arg0
,
arg1
,
arg2
}))
:
Op
(
"Select"
,
check_single_output_args
({
arg0
,
arg1
,
arg2
}))
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
auto
&
input_0
=
get_inputs
().
at
(
0
);
void
op
::
Select
::
validate_and_infer_types
()
auto
&
input_1
=
get_inputs
().
at
(
1
);
{
auto
&
input_2
=
get_inputs
().
at
(
2
);
NODE_VALIDATION_ASSERT
(
this
,
get_input_element_type
(
0
).
is_dynamic
()
||
NODE_VALIDATION_ASSERT
(
this
,
input_0
.
get_element_type
(
)
==
element
::
boolean
)
get_input_element_type
(
0
)
==
element
::
boolean
)
<<
"Argument 0 does not have boolean element type (element type: "
<<
"Argument 0 does not have boolean element type (element type: "
<<
input_0
.
get_element_type
(
)
<<
")."
;
<<
get_input_element_type
(
0
)
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
result_shape
=
get_input_partial_shape
(
0
);
input_0
.
get_shape
()
==
input_1
.
get_shape
()
&&
input_0
.
get_shape
()
==
input_2
.
get_shape
())
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
1
)))
<<
"Arguments do not all have the same shape (arg0 shape: "
<<
input_0
.
get_shape
()
<<
"Argument shapes are inconsistent."
;
<<
", arg1 shape: "
<<
input_1
.
get_shape
()
<<
", arg2 shape: "
<<
input_2
.
get_shape
()
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
2
)))
<<
")."
;
<<
"Argument shapes are inconsistent."
;
element
::
Type
result_et
;
NODE_VALIDATION_ASSERT
(
this
,
input_1
.
get_element_type
()
==
input_2
.
get_element_type
())
NODE_VALIDATION_ASSERT
(
<<
"Arguments 1 and 2 do not have the same element type (arg1 type: "
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
1
),
get_input_element_type
(
2
)))
<<
input_1
.
get_element_type
()
<<
", arg2 type: "
<<
input_2
.
get_element_type
()
<<
")
."
;
<<
"Argument 1 and 2 element types are inconsistent
."
;
set_output_type
(
0
,
input_1
.
get_element_type
(),
input_1
.
get_shape
()
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
Select
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Select
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/op/select.hpp
View file @
8f102516
...
@@ -53,6 +53,7 @@ namespace ngraph
...
@@ -53,6 +53,7 @@ namespace ngraph
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
protected
:
void
validate_and_infer_types
()
override
;
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
};
};
...
...
src/ngraph/op/slice.cpp
View file @
8f102516
...
@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
...
@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void
op
::
Slice
::
validate_and_infer_types
()
void
op
::
Slice
::
validate_and_infer_types
()
{
{
if
(
validate_punt_if_dynamic
())
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
{
// construct the default value.
return
;
if
(
m_strides
.
size
()
==
0
)
}
if
(
0
==
m_strides
.
size
())
{
{
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
}
}
auto
&
input
=
get_inputs
().
at
(
0
);
auto
&
input_shape
=
input
.
get_shape
();
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
.
size
()
==
input_shape
.
size
())
<<
"Rank of lower bounds ("
<<
m_lower_bounds
.
size
()
<<
") does not match rank "
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (lower bounds: "
<<
m_lower_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
.
size
()
==
input_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
<<
"Rank of upper bounds ("
<<
m_upper_bounds
.
size
()
<<
") does not match rank "
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (upper bounds: "
<<
m_upper_bounds
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
", argument shape: "
<<
input_shape
<<
")."
;
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
.
size
()
==
input_shape
.
size
())
size_t
output_rank
=
m_upper_bounds
.
size
();
<<
"Rank of strides ("
<<
m_strides
.
size
()
<<
") does not match rank "
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (strides: "
<<
m_strides
<<
", argument shape: "
<<
input_shape
<<
")."
;
Shape
result_shape
;
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
{
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
[
i
]
<=
input_shape
[
i
])
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
<<
" (strides: "
<<
m_strides
<<
")."
;
}
const
PartialShape
&
input_shape
=
get_input_partial_shape
(
0
);
Dimension
input_rank
=
input_shape
.
rank
();
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
size_t
(
input_rank
)
==
output_rank
)
<<
"Input rank does not match the rank of the lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
std
::
vector
<
Dimension
>
result_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
input_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
input_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
result_axis_size
=
result_axis_size
=
result_axis_size
/
m_strides
[
i
]
+
((
result_axis_size
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
result_axis_size
/
m_strides
[
i
]
+
((
result_axis_size
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
result_
shape
.
push_back
(
result_axis_size
)
;
result_
dims
[
i
]
=
result_axis_size
;
}
}
set_output_type
(
0
,
input
.
get_element_type
(),
result_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
PartialShape
{
result_dims
}
);
}
}
shared_ptr
<
Node
>
op
::
Slice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Slice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/runtime/cpu/CMakeLists.txt
View file @
8f102516
...
@@ -81,6 +81,7 @@ set(SRC
...
@@ -81,6 +81,7 @@ set(SRC
op/batch_norm_relu.cpp
op/batch_norm_relu.cpp
op/bounded_relu.cpp
op/bounded_relu.cpp
op/group_conv.cpp
op/group_conv.cpp
op/halide_op.cpp
op/conv_bias.cpp
op/conv_bias.cpp
op/conv_relu.cpp
op/conv_relu.cpp
op/convert_layout.cpp
op/convert_layout.cpp
...
@@ -115,6 +116,14 @@ if (NOT NGRAPH_DEX_ONLY)
...
@@ -115,6 +116,14 @@ if (NOT NGRAPH_DEX_ONLY)
)
)
endif
()
endif
()
if
(
NGRAPH_HALIDE
)
set
(
SRC
${
SRC
}
builder/halide_op.cpp
pass/halide_subgraph_extraction.cpp
)
endif
()
if
(
NGRAPH_TBB_ENABLE
)
if
(
NGRAPH_TBB_ENABLE
)
include
(
${
TBB_ROOT
}
/cmake/TBBBuild.cmake
)
include
(
${
TBB_ROOT
}
/cmake/TBBBuild.cmake
)
tbb_build
(
TBB_ROOT
${
TBB_ROOT
}
MAKE_ARGS tbb_build_dir=
${
CMAKE_CURRENT_BINARY_DIR
}
/tbb_build
tbb_build
(
TBB_ROOT
${
TBB_ROOT
}
MAKE_ARGS tbb_build_dir=
${
CMAKE_CURRENT_BINARY_DIR
}
/tbb_build
...
@@ -152,6 +161,12 @@ if (NGRAPH_CPU_ENABLE)
...
@@ -152,6 +161,12 @@ if (NGRAPH_CPU_ENABLE)
if
(
NGRAPH_DEX_ONLY
)
if
(
NGRAPH_DEX_ONLY
)
target_compile_definitions
(
cpu_backend PRIVATE
"NGRAPH_DEX_ONLY"
)
target_compile_definitions
(
cpu_backend PRIVATE
"NGRAPH_DEX_ONLY"
)
endif
()
endif
()
if
(
NGRAPH_HALIDE
)
target_compile_definitions
(
cpu_backend PRIVATE
"NGRAPH_HALIDE"
)
ExternalProject_Get_Property
(
ext_halide BINARY_DIR
)
target_include_directories
(
cpu_backend SYSTEM PRIVATE
${
BINARY_DIR
}
/include
)
target_link_libraries
(
cpu_backend PRIVATE
${
BINARY_DIR
}
/lib/libHalide.so
)
endif
()
if
(
OPENMP_FOUND
)
if
(
OPENMP_FOUND
)
target_compile_options
(
cpu_backend PRIVATE
"
${
OpenMP_CXX_FLAGS
}
"
)
target_compile_options
(
cpu_backend PRIVATE
"
${
OpenMP_CXX_FLAGS
}
"
)
...
...
src/ngraph/runtime/cpu/builder/halide_op.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <Halide.h>
#include <HalideBuffer.h>
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using
namespace
std
;
using
namespace
ngraph
;
#define TI(x) type_index(typeid(x))
namespace
ngraph
{
namespace
runtime
{
namespace
cpu
{
namespace
halide
{
static
const
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
Halide
::
Func
(
vector
<
Halide
::
Func
>
)
>>
generators
{{
TI
(
ngraph
::
op
::
Add
),
[](
vector
<
Halide
::
Func
>
in
)
{
Halide
::
Var
x
;
Halide
::
Func
func
;
func
(
x
)
=
in
[
0
](
x
)
+
in
[
1
](
x
);
return
func
;
}},
{
TI
(
ngraph
::
op
::
Multiply
),
[](
vector
<
Halide
::
Func
>
in
)
{
Halide
::
Var
x
;
Halide
::
Func
func
;
func
(
x
)
=
in
[
0
](
x
)
*
in
[
1
](
x
);
return
func
;
}},
{
TI
(
ngraph
::
op
::
Relu
),
[](
vector
<
Halide
::
Func
>
in
)
{
Halide
::
Var
x
;
Halide
::
Func
func
;
func
(
x
)
=
Halide
::
max
(
in
[
0
](
x
),
0
);
return
func
;
}}};
}
template
<>
void
Builder
::
BUILDER_DECL
(
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
)
{
const
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
*
hs
=
static_cast
<
const
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
*>
(
node
);
auto
&
halide_functions
=
external_function
->
get_halide_functions
();
auto
&
subgraph_params
=
external_function
->
get_subgraph_params
();
auto
&
subgraph_param_sizes
=
external_function
->
get_subgraph_param_sizes
();
auto
&
subgraph_param_ptrs
=
external_function
->
get_subgraph_param_ptrs
();
for
(
const
auto
&
op
:
hs
->
get_ops
())
{
if
(
!
halide
::
generators
.
count
(
TI
(
*
op
)))
{
throw
ngraph_error
(
"Invalid op in halide subgraph"
);
}
vector
<
Halide
::
Func
>
inputs
;
for
(
const
auto
&
input
:
op
->
get_inputs
())
{
auto
tensor_name
=
input
.
get_output
().
get_tensor_ptr
()
->
get_name
();
if
(
halide_functions
.
count
(
tensor_name
))
{
inputs
.
emplace_back
(
halide_functions
[
tensor_name
]);
}
else
{
subgraph_params
[
tensor_name
]
=
Halide
::
ImageParam
(
Halide
::
Float
(
32
),
1
);
subgraph_param_sizes
[
tensor_name
]
=
shape_size
(
input
.
get_output
().
get_tensor_ptr
()
->
get_shape
());
subgraph_param_ptrs
.
emplace
(
tensor_name
,
external_function
->
get_tensor_data
(
tensor_name
));
inputs
.
emplace_back
(
subgraph_params
[
tensor_name
]);
}
}
halide_functions
[
op
->
get_output_tensor_ptr
()
->
get_name
()]
=
halide
::
generators
.
at
(
TI
(
*
op
))(
inputs
);
}
auto
out_tensor_name
=
hs
->
get_ops
().
back
()
->
get_output_tensor_ptr
()
->
get_name
();
auto
&
functors
=
external_function
->
get_functors
();
auto
&
out_tensor
=
external_function
->
get_tensor_data
(
out
[
0
].
get_name
());
auto
&
terminal_func
=
halide_functions
[
out_tensor_name
];
auto
out_size
=
out
[
0
].
get_size
();
auto
functor
=
[
&
,
out_size
](
CPURuntimeContext
*
ctx
)
{
for
(
auto
&
param
:
subgraph_params
)
{
Halide
::
Buffer
<
float
>
param_buffer
(
static_cast
<
float
*>
(
subgraph_param_ptrs
.
at
(
param
.
first
).
get
()),
subgraph_param_sizes
.
at
(
param
.
first
));
param
.
second
.
set
(
param_buffer
);
}
Halide
::
Buffer
<
float
>
out_buffer
(
static_cast
<
float
*>
(
out_tensor
),
out_size
);
terminal_func
.
realize
(
out_buffer
);
};
functors
.
emplace_back
(
functor
);
}
}
}
}
src/ngraph/runtime/cpu/cpu_builder.cpp
View file @
8f102516
...
@@ -98,6 +98,7 @@
...
@@ -98,6 +98,7 @@
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
...
@@ -367,7 +368,9 @@ namespace ngraph
...
@@ -367,7 +368,9 @@ namespace ngraph
static
BuildOpMap
build_dispatcher
{
static
BuildOpMap
build_dispatcher
{
{
TI
(
ngraph
::
op
::
Parameter
),
&
runtime
::
cpu
::
Builder
::
nop
},
{
TI
(
ngraph
::
op
::
Parameter
),
&
runtime
::
cpu
::
Builder
::
nop
},
{
TI
(
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
),
{
TI
(
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
),
&
runtime
::
cpu
::
Builder
::
build
<
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
>
}};
&
runtime
::
cpu
::
Builder
::
build
<
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
>
},
{
TI
(
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
),
&
runtime
::
cpu
::
Builder
::
build
<
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
>
}};
return
build_dispatcher
;
return
build_dispatcher
;
}
}
...
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
8f102516
...
@@ -170,6 +170,7 @@
...
@@ -170,6 +170,7 @@
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
#ifdef NGRAPH_DISTRIBUTED
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/allreduce.hpp"
...
@@ -1023,6 +1024,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
...
@@ -1023,6 +1024,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUFusion
>
();
// pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
// pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUCollapseDims
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUCollapseDims
>
();
#if defined(NGRAPH_HALIDE)
pass_manager
.
register_pass
<
ngraph
::
runtime
::
cpu
::
pass
::
HalideSubgraphExtraction
>
();
#endif
NodeVector
nv_cwi
;
// We dont need CPUWorkspaceInsertion to return list of indices
NodeVector
nv_cwi
;
// We dont need CPUWorkspaceInsertion to return list of indices
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUWorkspaceInsertion
>
(
nv_cwi
,
false
);
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUWorkspaceInsertion
>
(
nv_cwi
,
false
);
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUAssignment
>
(
this
);
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUAssignment
>
(
this
);
...
...
src/ngraph/runtime/cpu/cpu_external_function.hpp
View file @
8f102516
...
@@ -27,6 +27,10 @@
...
@@ -27,6 +27,10 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#if defined(NGRAPH_HALIDE)
#include <Halide.h>
#endif
#if !defined(NGRAPH_DEX_ONLY)
#if !defined(NGRAPH_DEX_ONLY)
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/codegen/code_writer.hpp"
...
@@ -135,6 +139,26 @@ namespace ngraph
...
@@ -135,6 +139,26 @@ namespace ngraph
const
std
::
string
&
directory
,
const
std
::
string
&
directory
,
const
std
::
string
&
filename
);
const
std
::
string
&
filename
);
#if defined(NGRAPH_HALIDE)
std
::
unordered_map
<
std
::
string
,
Halide
::
Func
>&
get_halide_functions
()
{
return
halide_functions
;
}
std
::
unordered_map
<
std
::
string
,
Halide
::
ImageParam
>&
get_subgraph_params
()
{
return
subgraph_params
;
}
std
::
unordered_map
<
std
::
string
,
int
>&
get_subgraph_param_sizes
()
{
return
subgraph_param_sizes
;
}
std
::
unordered_map
<
std
::
string
,
std
::
reference_wrapper
<
void
*>>&
get_subgraph_param_ptrs
()
{
return
subgraph_param_ptrs
;
}
#endif
protected
:
protected
:
void
build
();
void
build
();
...
@@ -240,6 +264,13 @@ namespace ngraph
...
@@ -240,6 +264,13 @@ namespace ngraph
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
CPU_ExternalFunction
>>
callees
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
CPU_ExternalFunction
>>
callees
;
bool
m_is_built
;
bool
m_is_built
;
bool
m_direct_execution
;
bool
m_direct_execution
;
#if defined(NGRAPH_HALIDE)
std
::
unordered_map
<
std
::
string
,
Halide
::
Func
>
halide_functions
;
std
::
unordered_map
<
std
::
string
,
Halide
::
ImageParam
>
subgraph_params
;
std
::
unordered_map
<
std
::
string
,
int
>
subgraph_param_sizes
;
std
::
unordered_map
<
std
::
string
,
std
::
reference_wrapper
<
void
*>>
subgraph_param_ptrs
;
#endif
};
};
}
}
}
}
...
...
src/ngraph/runtime/cpu/op/halide_op.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using
namespace
std
;
using
namespace
ngraph
;
// Clone this fused op onto a new set of input nodes. The fused-op list and
// the output type/shape descriptor are carried over unchanged; only the
// live-in arguments are replaced.
shared_ptr<Node>
    runtime::cpu::op::HalideOp::copy_with_new_args(const NodeVector& new_args) const
{
    auto clone = make_shared<HalideOp>(new_args, ops, output_type, output_shape);
    return clone;
}
// Construct a HalideOp wrapping an extracted subgraph.
//   args      - live-in inputs (nodes produced outside the subgraph)
//   ops       - the subgraph's nodes; expected in producer-first order —
//               TODO confirm the builder's ordering requirement
//   out_type  - element type of the single fused output
//   out_shape - shape of the single fused output
runtime::cpu::op::HalideOp::HalideOp(const NodeVector& args,
                                     const std::list<std::shared_ptr<Node>>& ops,
                                     const element::Type& out_type,
                                     const Shape& out_shape)
    : Op("HalideOp", check_single_output_args(args))
    , ops(ops)
    , output_type(out_type)
    , output_shape(out_shape)
{
    constructor_validate_and_infer_types();
}
// The output is fully described by the stored type/shape captured at
// construction; inputs impose no further constraints here.
void runtime::cpu::op::HalideOp::validate_and_infer_types()
{
    set_output_type(0, output_type, output_shape);
}
src/ngraph/runtime/cpu/op/halide_op.hpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <list>
#include <vector>
#include "ngraph/op/op.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace op
            {
                /// \brief Fused op standing in for a subgraph that will be
                ///        compiled and executed via Halide.
                ///
                /// Produced by the HalideSubgraphExtraction pass; carries the
                /// original subgraph nodes so the CPU builder can lower them
                /// to Halide functions. Single-output only.
                class HalideOp : public ngraph::op::Op
                {
                public:
                    /// \param args      live-in inputs of the subgraph
                    /// \param ops       nodes of the subgraph
                    /// \param out_type  element type of the fused output
                    /// \param out_shape shape of the fused output
                    HalideOp(const NodeVector& args,
                             const std::list<std::shared_ptr<Node>>& ops,
                             const element::Type& out_type,
                             const Shape& out_shape);

                    virtual void validate_and_infer_types() override;

                    virtual std::shared_ptr<Node>
                        copy_with_new_args(const NodeVector& new_args) const override;

                    /// Nodes of the wrapped subgraph, in the order supplied at
                    /// construction.
                    const std::list<std::shared_ptr<Node>>& get_ops() const { return ops; }
                private:
                    std::list<std::shared_ptr<Node>> ops;
                    element::Type output_type;
                    Shape output_shape;
                };
            }
        }
    }
}
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iostream>
#include <list>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
using
namespace
std
;
using
namespace
ngraph
;
#define TI(x) type_index(typeid(x))
namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace halide
            {
                // Ops the Halide emitter knows how to lower; only these may be
                // absorbed into an extracted subgraph.
                static const std::unordered_set<std::type_index> whitelist{
                    TI(ngraph::op::Add), TI(ngraph::op::Multiply), TI(ngraph::op::Relu)};
                // Graph-boundary ops: traversal passes through them without
                // adding them to the subgraph.
                static const std::unordered_set<std::type_index> skiplist{
                    TI(ngraph::op::Parameter), TI(ngraph::op::Result)};
            }
        }
    }
}
// Support for multiple results, multiple outputs and getoutputelement, and multiple subgraphs in a single
// pipeline is not implemented since this should go away in favor of the "hybrid" transformer approach of
// carving out subgraphs in core ngraph
//
// Walks backwards from the function result collecting the maximal connected
// run of whitelisted ops, then replaces it with a single HalideOp whose
// inputs are the subgraph's live-ins. Returns true iff the graph was changed.
bool runtime::cpu::pass::HalideSubgraphExtraction::run_on_function(
    std::shared_ptr<ngraph::Function> function)
{
    auto results = function->get_results();

    // Artificial limitation: single-result functions only.
    if (results.size() > 1)
    {
        return false;
    }
    // Halide lowering currently only handles f32.
    if (function->get_result()->get_element_type() != element::f32)
    {
        return false;
    }

    list<shared_ptr<Node>> worklist;
    for (const auto& result : results)
    {
        worklist.emplace_back(result);
    }

    unordered_set<shared_ptr<Node>> ops; // membership test for the subgraph
    list<shared_ptr<Node>> ordered_ops;  // consumer-first visit order

    while (!worklist.empty())
    {
        // Copy before popping: front() returns a reference into the list.
        auto node = worklist.front();
        worklist.pop_front();

        if (!halide::skiplist.count(TI(*node)))
        {
            if (!halide::whitelist.count(TI(*node)))
            {
                // First unsupported op terminates the (single) subgraph.
                break;
            }
            // Bug fix: a node reachable through multiple paths (diamond)
            // used to be appended to ordered_ops once per path and have its
            // arguments re-expanded; record and expand each node only once.
            if (!ops.emplace(node).second)
            {
                continue;
            }
            ordered_ops.emplace_back(node);
        }
        for (const auto& arg : node->get_arguments())
        {
            worklist.emplace_back(arg);
        }
    }

    // Bug fix: if no whitelisted op was found there is nothing to offload;
    // previously an empty HalideOp was still spliced over the result's
    // argument, corrupting the graph.
    if (ops.empty())
    {
        return false;
    }

    // Live-ins: arguments produced outside the subgraph. Iterate ordered_ops
    // (not the unordered set) and de-duplicate, so the HalideOp's input list
    // is deterministic and free of repeats.
    NodeVector liveins;
    unordered_set<shared_ptr<Node>> livein_seen;
    for (const auto& op : ordered_ops)
    {
        for (const auto& arg : op->get_arguments())
        {
            if (!ops.count(arg) && livein_seen.emplace(arg).second)
            {
                liveins.emplace_back(arg);
            }
        }
    }

    // ordered_ops was built consumer-first; the Halide builder wants
    // producer-first order.
    ordered_ops.reverse();

    auto subgraph = make_shared<cpu::op::HalideOp>(liveins,
                                                   ordered_ops,
                                                   function->get_result()->get_element_type(),
                                                   function->get_result()->get_shape());
    replace_node(function->get_result()->get_argument(0), subgraph);
    return true;
}
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
namespace ngraph
{
    namespace runtime
    {
        namespace cpu
        {
            namespace pass
            {
                /// \brief Function pass that carves a run of Halide-lowerable
                ///        ops out of the graph and replaces it with a single
                ///        cpu::op::HalideOp (see halide_subgraph_extraction.cpp
                ///        for the supported-op whitelist and limitations).
                class HalideSubgraphExtraction : public ngraph::pass::FunctionPass
                {
                public:
                    HalideSubgraphExtraction() {}
                    /// \return true iff the function graph was modified
                    bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
                };
            }
        }
    }
}
test/CMakeLists.txt
View file @
8f102516
...
@@ -68,6 +68,9 @@ endif()
...
@@ -68,6 +68,9 @@ endif()
if
(
NGRAPH_CPU_ENABLE
)
if
(
NGRAPH_CPU_ENABLE
)
list
(
APPEND SRC core_fusion.cpp quantize_cpu.cpp
)
list
(
APPEND SRC core_fusion.cpp quantize_cpu.cpp
)
list
(
APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp
)
list
(
APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp
)
if
(
NGRAPH_HALIDE
)
list
(
APPEND SRC halide.cpp
)
endif
()
set
(
ACTIVE_BACKEND_LIST
${
ACTIVE_BACKEND_LIST
}
CPU
)
set
(
ACTIVE_BACKEND_LIST
${
ACTIVE_BACKEND_LIST
}
CPU
)
endif
()
endif
()
...
...
test/halide.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/test_tools.hpp"
using
namespace
ngraph
;
using
namespace
std
;
// End-to-end check that Halide subgraph extraction preserves results for
// relu((A + B) * C) + D: the relu/add/multiply chain is extractable, the
// trailing add consumes the fused output.
TEST(halide, halide_subgraph)
{
    Shape shape{8};
    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto B = make_shared<op::Parameter>(element::f32, shape);
    auto C = make_shared<op::Parameter>(element::f32, shape);
    auto D = make_shared<op::Parameter>(element::f32, shape);

    auto relu = make_shared<op::Relu>((A + B) * C);
    auto f = make_shared<Function>(relu + D, op::ParameterVector{A, B, C, D});

    auto backend = runtime::Backend::create("CPU");

    // Same data fed to all four inputs.
    vector<float> data{-1, 4, -2, 5, 1, 5, 7, 9};
    vector<shared_ptr<runtime::Tensor>> inputs;
    for (size_t i = 0; i < 4; ++i)
    {
        auto tensor = backend->create_tensor(element::f32, shape);
        copy_data(tensor, data);
        inputs.push_back(tensor);
    }
    shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);

    // Hand-computed relu((x + x) * x) + x for each element of `data`.
    vector<float> expected{1, 36, 6, 55, 3, 55, 105, 171};

    backend->call_with_validate(f, {result}, inputs);

    EXPECT_TRUE(test::all_close(read_vector<float>(result), expected, 1.0e-4f, 1.0e-4f));
}
test/type_prop.cpp
View file @
8f102516
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment