Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8f102516
Unverified
Commit
8f102516
authored
Oct 19, 2018
by
Matthew Brookhart
Committed by
GitHub
Oct 19, 2018
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into ayzhuang/in-place-concat
parents
3feb4264
982889f5
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1477 additions
and
137 deletions
+1477
-137
broadcast.cpp
src/ngraph/op/broadcast.cpp
+12
-17
replace_slice.cpp
src/ngraph/op/replace_slice.cpp
+58
-43
replace_slice.hpp
src/ngraph/op/replace_slice.hpp
+4
-4
select.cpp
src/ngraph/op/select.cpp
+19
-16
select.hpp
src/ngraph/op/select.hpp
+1
-0
slice.cpp
src/ngraph/op/slice.cpp
+30
-30
CMakeLists.txt
src/ngraph/runtime/cpu/CMakeLists.txt
+15
-0
halide_op.cpp
src/ngraph/runtime/cpu/builder/halide_op.cpp
+129
-0
cpu_builder.cpp
src/ngraph/runtime/cpu/cpu_builder.cpp
+4
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+5
-0
cpu_external_function.hpp
src/ngraph/runtime/cpu/cpu_external_function.hpp
+31
-0
halide_op.cpp
src/ngraph/runtime/cpu/op/halide_op.cpp
+42
-0
halide_op.hpp
src/ngraph/runtime/cpu/op/halide_op.hpp
+54
-0
halide_subgraph_extraction.cpp
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cpp
+126
-0
halide_subgraph_extraction.hpp
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp
+39
-0
CMakeLists.txt
test/CMakeLists.txt
+3
-0
halide.cpp
test/halide.cpp
+64
-0
type_prop.cpp
test/type_prop.cpp
+841
-26
No files found.
src/ngraph/op/broadcast.cpp
View file @
8f102516
...
@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
...
@@ -40,33 +40,28 @@ op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
void
op
::
Broadcast
::
validate_and_infer_types
()
void
op
::
Broadcast
::
validate_and_infer_types
()
{
{
if
(
validate_punt_if_dynamic
())
infer_shape
();
for
(
auto
axis
:
m_broadcast_axes
)
{
{
return
;
NODE_VALIDATION_ASSERT
(
this
,
axis
<
m_shape
.
size
())
<<
"Broadcast axis index ("
<<
axis
<<
") exceeds specified output shape rank "
<<
"(broadcast axes: "
<<
m_broadcast_axes
<<
", output shape: "
<<
m_shape
<<
")."
;
}
}
infer_shape
();
Shape
required_input_shape
=
m_shape
;
Shape
target_shape
=
m_shape
;
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
for
(
auto
i
=
m_broadcast_axes
.
rbegin
();
i
!=
m_broadcast_axes
.
rend
();
++
i
)
{
{
NODE_VALIDATION_ASSERT
(
this
,
*
i
<
target_shape
.
size
())
required_input_shape
.
erase
(
required_input_shape
.
begin
()
+
*
i
);
<<
"Broadcast axis index ("
<<
*
i
<<
") exceeds target shape rank "
<<
"(broadcast axes: "
<<
m_broadcast_axes
<<
", target shape: "
<<
target_shape
<<
")."
;
target_shape
.
erase
(
target_shape
.
begin
()
+
*
i
);
}
}
// TODO(amprocte): We can probably have a more helpful error message here.
// TODO(amprocte): We can probably have a more helpful error message here.
// There are two things that can go wrong, which are being picked up in
// There are two things that can go wrong, which are being picked up in
// one fell swoop by this check: either the number of broadcast axes is not
// one fell swoop by this check: either the number of broadcast axes is not
// enough (arg->get_shape().size() + broadcast_axes.size() != shape.size())
// enough, or there is a mismatch with one of the pre-broadcast axis lengths.
// or there is a mismatch with one of the pre-broadcast axis lengths
NODE_VALIDATION_ASSERT
(
this
,
get_input_partial_shape
(
0
).
compatible
(
required_input_shape
))
// (i.e. target_shape.size() == arg->get_shape.size() but there is some i
<<
"Broadcast argument shape, specified output shape, and axes are incompatible "
// where target_shape[i] != arg->get_shape[i]).
<<
"(argument shape: "
<<
get_input_partial_shape
(
0
)
<<
", output shape: "
<<
m_shape
NODE_VALIDATION_ASSERT
(
this
,
target_shape
==
get_input_shape
(
0
))
<<
"Broadcast argument shape, target shape, and axes are incompatible "
<<
"(argument shape: "
<<
get_input_shape
(
0
)
<<
", target shape: "
<<
m_shape
<<
", broadcast axes: "
<<
m_broadcast_axes
<<
")."
;
<<
", broadcast axes: "
<<
m_broadcast_axes
<<
")."
;
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
m_shape
);
...
...
src/ngraph/op/replace_slice.cpp
View file @
8f102516
...
@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
...
@@ -32,8 +32,6 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
,
m_strides
(
strides
)
,
m_strides
(
strides
)
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
check_args
();
}
}
op
::
ReplaceSlice
::
ReplaceSlice
(
const
shared_ptr
<
Node
>&
arg0
,
op
::
ReplaceSlice
::
ReplaceSlice
(
const
shared_ptr
<
Node
>&
arg0
,
...
@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
...
@@ -46,69 +44,86 @@ op::ReplaceSlice::ReplaceSlice(const shared_ptr<Node>& arg0,
,
m_strides
(
Strides
(
lower_bounds
.
size
(),
1
))
,
m_strides
(
Strides
(
lower_bounds
.
size
(),
1
))
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
check_args
();
}
}
void
op
::
ReplaceSlice
::
check_arg
s
()
void
op
::
ReplaceSlice
::
validate_and_infer_type
s
()
{
{
auto
&
input_0
=
get_inputs
().
at
(
0
);
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
auto
&
input_0_shape
=
input_0
.
get_shape
();
// construct the default value.
auto
&
input_0_element_type
=
input_0
.
get_element_type
();
if
(
m_strides
.
size
()
==
0
)
{
auto
&
input_1
=
get_inputs
().
at
(
1
);
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
auto
&
input_1_shape
=
input_1
.
get_shape
();
}
auto
&
input_1_element_type
=
input_1
.
get_element_type
();
NODE_VALIDATION_ASSERT
(
this
,
input_0_shape
.
size
()
==
input_1_shape
.
size
())
const
PartialShape
&
arg0_shape
=
get_input_partial_shape
(
0
);
<<
"Argument ranks do not match (arg0 shape: "
<<
input_0_shape
const
PartialShape
&
arg1_shape
=
get_input_partial_shape
(
1
);
<<
", arg1 shape: "
<<
input_1_shape
<<
")."
;
Dimension
merged_args_rank
;
NODE_VALIDATION_ASSERT
(
this
,
input_0_element_type
==
input_1_element_type
)
NODE_VALIDATION_ASSERT
(
this
,
<<
"Argument element types do not match (arg0 element type: "
<<
input_0_element_type
Dimension
::
merge
(
merged_args_rank
,
arg0_shape
.
rank
(),
arg1_shape
.
rank
()))
<<
", arg1 element type: "
<<
input_1_element_type
<<
")."
;
<<
"Argument ranks do not match (arg0 shape: "
<<
arg0_shape
<<
", arg1 shape: "
<<
arg1_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
.
size
()
==
input_0_shape
.
size
())
element
::
Type
arg0_et
=
get_input_element_type
(
0
);
<<
"Rank of lower bounds ("
<<
m_lower_bounds
.
size
()
<<
") does not match rank "
element
::
Type
arg1_et
=
get_input_element_type
(
1
);
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (lower bounds: "
<<
m_lower_bounds
element
::
Type
merged_args_et
;
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
.
size
()
==
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
element
::
Type
::
merge
(
merged_args_et
,
arg0_et
,
arg1_et
))
<<
"Rank of upper bounds ("
<<
m_upper_bounds
.
size
()
<<
") does not match rank "
<<
"Argument element types do not match (arg0 element type: "
<<
arg0_et
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (upper bounds: "
<<
m_upper_bounds
<<
", arg1 element type: "
<<
arg1_et
<<
")."
;
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
.
size
()
==
input_0_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
<<
"Rank of strides ("
<<
m_strides
.
size
()
<<
") does not match rank "
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
<<
"of argument ("
<<
input_0_shape
.
size
()
<<
") (strides: "
<<
m_strides
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
Shape
slice_shape
;
size_t
output_rank
=
m_upper_bounds
.
size
()
;
for
(
size_t
i
=
0
;
i
<
input_0_shape
.
size
()
;
i
++
)
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
{
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
[
i
]
<=
input_0_shape
[
i
])
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_0_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
<<
" (strides: "
<<
m_strides
<<
")."
;
}
size_t
slice_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
NODE_VALIDATION_ASSERT
(
this
,
slice_axis_size
=
merged_args_rank
.
is_dynamic
()
||
size_t
(
merged_args_rank
)
==
output_rank
)
slice_axis_size
/
m_strides
[
i
]
+
((
slice_axis_size
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
<<
"Argument ranks do not match the rank of the lower bounds ("
<<
m_lower_bounds
slice_shape
.
push_back
(
slice_axis_size
);
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
std
::
vector
<
Dimension
>
sliced_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
arg0_shape
.
rank
().
is_dynamic
()
||
arg0_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
arg0_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
arg0_shape
<<
")."
;
size_t
sliced_dim
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
sliced_dim
=
sliced_dim
/
m_strides
[
i
]
+
((
sliced_dim
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
sliced_dims
[
i
]
=
sliced_dim
;
}
}
NODE_VALIDATION_ASSERT
(
this
,
input_1_shape
==
slice_shape
)
PartialShape
slice_shape
{
sliced_dims
};
<<
"Shape of replacement tensor ("
<<
input_1_shape
<<
") does not match the slice shape "
NODE_VALIDATION_ASSERT
(
this
,
arg1_shape
.
compatible
(
slice_shape
))
<<
"Shape of replacement tensor ("
<<
arg1_shape
<<
") does not match the slice shape "
<<
"("
<<
slice_shape
<<
")."
;
<<
"("
<<
slice_shape
<<
")."
;
set_output_type
(
0
,
input_0_element_type
,
input_0_shape
);
// Slight corner case here: if arg0 was rank-unknown, we can go ahead and set the output rank
// because the attribs will have given us enough info.
PartialShape
result_shape
=
(
arg0_shape
.
rank
().
is_static
())
?
arg0_shape
:
PartialShape
(
std
::
vector
<
Dimension
>
(
output_rank
,
Dimension
::
dynamic
()));
set_output_type
(
0
,
merged_args_et
,
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
ReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
ReplaceSlice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/op/replace_slice.hpp
View file @
8f102516
...
@@ -88,11 +88,11 @@ namespace ngraph
...
@@ -88,11 +88,11 @@ namespace ngraph
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
void
check_args
()
;
void
validate_and_infer_types
()
override
;
const
Coordinate
m_lower_bounds
;
Coordinate
m_lower_bounds
;
const
Coordinate
m_upper_bounds
;
Coordinate
m_upper_bounds
;
const
Strides
m_strides
;
Strides
m_strides
;
};
};
}
}
}
}
src/ngraph/op/select.cpp
View file @
8f102516
...
@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
...
@@ -32,27 +32,30 @@ op::Select::Select(const shared_ptr<Node>& arg0,
:
Op
(
"Select"
,
check_single_output_args
({
arg0
,
arg1
,
arg2
}))
:
Op
(
"Select"
,
check_single_output_args
({
arg0
,
arg1
,
arg2
}))
{
{
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
auto
&
input_0
=
get_inputs
().
at
(
0
);
void
op
::
Select
::
validate_and_infer_types
()
auto
&
input_1
=
get_inputs
().
at
(
1
);
{
auto
&
input_2
=
get_inputs
().
at
(
2
);
NODE_VALIDATION_ASSERT
(
this
,
get_input_element_type
(
0
).
is_dynamic
()
||
NODE_VALIDATION_ASSERT
(
this
,
input_0
.
get_element_type
(
)
==
element
::
boolean
)
get_input_element_type
(
0
)
==
element
::
boolean
)
<<
"Argument 0 does not have boolean element type (element type: "
<<
"Argument 0 does not have boolean element type (element type: "
<<
input_0
.
get_element_type
(
)
<<
")."
;
<<
get_input_element_type
(
0
)
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
result_shape
=
get_input_partial_shape
(
0
);
input_0
.
get_shape
()
==
input_1
.
get_shape
()
&&
input_0
.
get_shape
()
==
input_2
.
get_shape
())
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
1
)))
<<
"Arguments do not all have the same shape (arg0 shape: "
<<
input_0
.
get_shape
()
<<
"Argument shapes are inconsistent."
;
<<
", arg1 shape: "
<<
input_1
.
get_shape
()
<<
", arg2 shape: "
<<
input_2
.
get_shape
()
NODE_VALIDATION_ASSERT
(
this
,
PartialShape
::
merge_into
(
result_shape
,
get_input_partial_shape
(
2
)))
<<
")."
;
<<
"Argument shapes are inconsistent."
;
element
::
Type
result_et
;
NODE_VALIDATION_ASSERT
(
this
,
input_1
.
get_element_type
()
==
input_2
.
get_element_type
())
NODE_VALIDATION_ASSERT
(
<<
"Arguments 1 and 2 do not have the same element type (arg1 type: "
this
,
element
::
Type
::
merge
(
result_et
,
get_input_element_type
(
1
),
get_input_element_type
(
2
)))
<<
input_1
.
get_element_type
()
<<
", arg2 type: "
<<
input_2
.
get_element_type
()
<<
")
."
;
<<
"Argument 1 and 2 element types are inconsistent
."
;
set_output_type
(
0
,
input_1
.
get_element_type
(),
input_1
.
get_shape
()
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
shared_ptr
<
Node
>
op
::
Select
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Select
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/op/select.hpp
View file @
8f102516
...
@@ -53,6 +53,7 @@ namespace ngraph
...
@@ -53,6 +53,7 @@ namespace ngraph
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
protected
:
protected
:
void
validate_and_infer_types
()
override
;
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
;
const
NodeVector
&
deltas
)
override
;
};
};
...
...
src/ngraph/op/slice.cpp
View file @
8f102516
...
@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
...
@@ -44,55 +44,55 @@ op::Slice::Slice(const shared_ptr<Node>& arg,
void
op
::
Slice
::
validate_and_infer_types
()
void
op
::
Slice
::
validate_and_infer_types
()
{
{
if
(
validate_punt_if_dynamic
())
// An empty stride vector with lower_bounds/upper_bounds filled in means that we need to
{
// construct the default value.
return
;
if
(
m_strides
.
size
()
==
0
)
}
if
(
0
==
m_strides
.
size
())
{
{
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
m_strides
=
Strides
(
m_lower_bounds
.
size
(),
1
);
}
}
auto
&
input
=
get_inputs
().
at
(
0
);
auto
&
input_shape
=
input
.
get_shape
();
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
.
size
()
==
input_shape
.
size
())
<<
"Rank of lower bounds ("
<<
m_lower_bounds
.
size
()
<<
") does not match rank "
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (lower bounds: "
<<
m_lower_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
.
size
()
==
input_shape
.
size
())
NODE_VALIDATION_ASSERT
(
this
,
<<
"Rank of upper bounds ("
<<
m_upper_bounds
.
size
()
<<
") does not match rank "
m_lower_bounds
.
size
()
==
m_upper_bounds
.
size
()
&&
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (upper bounds: "
<<
m_upper_bounds
m_lower_bounds
.
size
()
==
m_strides
.
size
())
<<
", argument shape: "
<<
input_shape
<<
")."
;
<<
"Ranks of lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
") and strides ("
<<
m_strides
<<
") do not match."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
.
size
()
==
input_shape
.
size
())
size_t
output_rank
=
m_upper_bounds
.
size
();
<<
"Rank of strides ("
<<
m_strides
.
size
()
<<
") does not match rank "
<<
"of argument ("
<<
input_shape
.
size
()
<<
") (strides: "
<<
m_strides
<<
", argument shape: "
<<
input_shape
<<
")."
;
Shape
result_shape
;
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
{
NODE_VALIDATION_ASSERT
(
this
,
m_upper_bounds
[
i
]
<=
input_shape
[
i
])
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
NODE_VALIDATION_ASSERT
(
this
,
m_lower_bounds
[
i
]
<=
m_upper_bounds
[
i
])
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
"Lower bound for slice is greater than upper bound at axis "
<<
i
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
<<
" (lower bounds: "
<<
m_lower_bounds
<<
", upper bounds: "
<<
m_upper_bounds
<<
")."
;
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
NODE_VALIDATION_ASSERT
(
this
,
m_strides
[
i
]
!=
0
)
<<
"Stride for slice is zero at axis "
<<
i
<<
" (strides: "
<<
m_strides
<<
")."
;
<<
" (strides: "
<<
m_strides
<<
")."
;
}
const
PartialShape
&
input_shape
=
get_input_partial_shape
(
0
);
Dimension
input_rank
=
input_shape
.
rank
();
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
size_t
(
input_rank
)
==
output_rank
)
<<
"Input rank does not match the rank of the lower bounds ("
<<
m_lower_bounds
<<
"), upper bounds ("
<<
m_upper_bounds
<<
"), and strides ("
<<
m_strides
<<
")."
;
std
::
vector
<
Dimension
>
result_dims
(
output_rank
);
for
(
size_t
i
=
0
;
i
<
output_rank
;
i
++
)
{
NODE_VALIDATION_ASSERT
(
this
,
input_rank
.
is_dynamic
()
||
input_shape
[
i
].
is_dynamic
()
||
m_upper_bounds
[
i
]
<=
size_t
(
input_shape
[
i
]))
<<
"Upper bound for slice at axis "
<<
i
<<
" is out of range "
<<
"(upper bounds: "
<<
m_upper_bounds
<<
", argument shape: "
<<
input_shape
<<
")."
;
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
size_t
result_axis_size
=
m_upper_bounds
[
i
]
-
m_lower_bounds
[
i
];
result_axis_size
=
result_axis_size
=
result_axis_size
/
m_strides
[
i
]
+
((
result_axis_size
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
result_axis_size
/
m_strides
[
i
]
+
((
result_axis_size
%
m_strides
[
i
]
==
0
)
?
0
:
1
);
result_
shape
.
push_back
(
result_axis_size
)
;
result_
dims
[
i
]
=
result_axis_size
;
}
}
set_output_type
(
0
,
input
.
get_element_type
(),
result_shape
);
set_output_type
(
0
,
get_input_element_type
(
0
),
PartialShape
{
result_dims
}
);
}
}
shared_ptr
<
Node
>
op
::
Slice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Slice
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
...
...
src/ngraph/runtime/cpu/CMakeLists.txt
View file @
8f102516
...
@@ -81,6 +81,7 @@ set(SRC
...
@@ -81,6 +81,7 @@ set(SRC
op/batch_norm_relu.cpp
op/batch_norm_relu.cpp
op/bounded_relu.cpp
op/bounded_relu.cpp
op/group_conv.cpp
op/group_conv.cpp
op/halide_op.cpp
op/conv_bias.cpp
op/conv_bias.cpp
op/conv_relu.cpp
op/conv_relu.cpp
op/convert_layout.cpp
op/convert_layout.cpp
...
@@ -115,6 +116,14 @@ if (NOT NGRAPH_DEX_ONLY)
...
@@ -115,6 +116,14 @@ if (NOT NGRAPH_DEX_ONLY)
)
)
endif
()
endif
()
if
(
NGRAPH_HALIDE
)
set
(
SRC
${
SRC
}
builder/halide_op.cpp
pass/halide_subgraph_extraction.cpp
)
endif
()
if
(
NGRAPH_TBB_ENABLE
)
if
(
NGRAPH_TBB_ENABLE
)
include
(
${
TBB_ROOT
}
/cmake/TBBBuild.cmake
)
include
(
${
TBB_ROOT
}
/cmake/TBBBuild.cmake
)
tbb_build
(
TBB_ROOT
${
TBB_ROOT
}
MAKE_ARGS tbb_build_dir=
${
CMAKE_CURRENT_BINARY_DIR
}
/tbb_build
tbb_build
(
TBB_ROOT
${
TBB_ROOT
}
MAKE_ARGS tbb_build_dir=
${
CMAKE_CURRENT_BINARY_DIR
}
/tbb_build
...
@@ -152,6 +161,12 @@ if (NGRAPH_CPU_ENABLE)
...
@@ -152,6 +161,12 @@ if (NGRAPH_CPU_ENABLE)
if
(
NGRAPH_DEX_ONLY
)
if
(
NGRAPH_DEX_ONLY
)
target_compile_definitions
(
cpu_backend PRIVATE
"NGRAPH_DEX_ONLY"
)
target_compile_definitions
(
cpu_backend PRIVATE
"NGRAPH_DEX_ONLY"
)
endif
()
endif
()
if
(
NGRAPH_HALIDE
)
target_compile_definitions
(
cpu_backend PRIVATE
"NGRAPH_HALIDE"
)
ExternalProject_Get_Property
(
ext_halide BINARY_DIR
)
target_include_directories
(
cpu_backend SYSTEM PRIVATE
${
BINARY_DIR
}
/include
)
target_link_libraries
(
cpu_backend PRIVATE
${
BINARY_DIR
}
/lib/libHalide.so
)
endif
()
if
(
OPENMP_FOUND
)
if
(
OPENMP_FOUND
)
target_compile_options
(
cpu_backend PRIVATE
"
${
OpenMP_CXX_FLAGS
}
"
)
target_compile_options
(
cpu_backend PRIVATE
"
${
OpenMP_CXX_FLAGS
}
"
)
...
...
src/ngraph/runtime/cpu/builder/halide_op.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <Halide.h>
#include <HalideBuffer.h>
#include <functional>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using
namespace
std
;
using
namespace
ngraph
;
#define TI(x) type_index(typeid(x))
namespace
ngraph
{
namespace
runtime
{
namespace
cpu
{
namespace
halide
{
static
const
std
::
unordered_map
<
std
::
type_index
,
std
::
function
<
Halide
::
Func
(
vector
<
Halide
::
Func
>
)
>>
generators
{{
TI
(
ngraph
::
op
::
Add
),
[](
vector
<
Halide
::
Func
>
in
)
{
Halide
::
Var
x
;
Halide
::
Func
func
;
func
(
x
)
=
in
[
0
](
x
)
+
in
[
1
](
x
);
return
func
;
}},
{
TI
(
ngraph
::
op
::
Multiply
),
[](
vector
<
Halide
::
Func
>
in
)
{
Halide
::
Var
x
;
Halide
::
Func
func
;
func
(
x
)
=
in
[
0
](
x
)
*
in
[
1
](
x
);
return
func
;
}},
{
TI
(
ngraph
::
op
::
Relu
),
[](
vector
<
Halide
::
Func
>
in
)
{
Halide
::
Var
x
;
Halide
::
Func
func
;
func
(
x
)
=
Halide
::
max
(
in
[
0
](
x
),
0
);
return
func
;
}}};
}
template
<>
void
Builder
::
BUILDER_DECL
(
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
)
{
const
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
*
hs
=
static_cast
<
const
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
*>
(
node
);
auto
&
halide_functions
=
external_function
->
get_halide_functions
();
auto
&
subgraph_params
=
external_function
->
get_subgraph_params
();
auto
&
subgraph_param_sizes
=
external_function
->
get_subgraph_param_sizes
();
auto
&
subgraph_param_ptrs
=
external_function
->
get_subgraph_param_ptrs
();
for
(
const
auto
&
op
:
hs
->
get_ops
())
{
if
(
!
halide
::
generators
.
count
(
TI
(
*
op
)))
{
throw
ngraph_error
(
"Invalid op in halide subgraph"
);
}
vector
<
Halide
::
Func
>
inputs
;
for
(
const
auto
&
input
:
op
->
get_inputs
())
{
auto
tensor_name
=
input
.
get_output
().
get_tensor_ptr
()
->
get_name
();
if
(
halide_functions
.
count
(
tensor_name
))
{
inputs
.
emplace_back
(
halide_functions
[
tensor_name
]);
}
else
{
subgraph_params
[
tensor_name
]
=
Halide
::
ImageParam
(
Halide
::
Float
(
32
),
1
);
subgraph_param_sizes
[
tensor_name
]
=
shape_size
(
input
.
get_output
().
get_tensor_ptr
()
->
get_shape
());
subgraph_param_ptrs
.
emplace
(
tensor_name
,
external_function
->
get_tensor_data
(
tensor_name
));
inputs
.
emplace_back
(
subgraph_params
[
tensor_name
]);
}
}
halide_functions
[
op
->
get_output_tensor_ptr
()
->
get_name
()]
=
halide
::
generators
.
at
(
TI
(
*
op
))(
inputs
);
}
auto
out_tensor_name
=
hs
->
get_ops
().
back
()
->
get_output_tensor_ptr
()
->
get_name
();
auto
&
functors
=
external_function
->
get_functors
();
auto
&
out_tensor
=
external_function
->
get_tensor_data
(
out
[
0
].
get_name
());
auto
&
terminal_func
=
halide_functions
[
out_tensor_name
];
auto
out_size
=
out
[
0
].
get_size
();
auto
functor
=
[
&
,
out_size
](
CPURuntimeContext
*
ctx
)
{
for
(
auto
&
param
:
subgraph_params
)
{
Halide
::
Buffer
<
float
>
param_buffer
(
static_cast
<
float
*>
(
subgraph_param_ptrs
.
at
(
param
.
first
).
get
()),
subgraph_param_sizes
.
at
(
param
.
first
));
param
.
second
.
set
(
param_buffer
);
}
Halide
::
Buffer
<
float
>
out_buffer
(
static_cast
<
float
*>
(
out_tensor
),
out_size
);
terminal_func
.
realize
(
out_buffer
);
};
functors
.
emplace_back
(
functor
);
}
}
}
}
src/ngraph/runtime/cpu/cpu_builder.cpp
View file @
8f102516
...
@@ -98,6 +98,7 @@
...
@@ -98,6 +98,7 @@
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tan.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/kernel/tanh.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
...
@@ -367,7 +368,9 @@ namespace ngraph
...
@@ -367,7 +368,9 @@ namespace ngraph
static
BuildOpMap
build_dispatcher
{
static
BuildOpMap
build_dispatcher
{
{
TI
(
ngraph
::
op
::
Parameter
),
&
runtime
::
cpu
::
Builder
::
nop
},
{
TI
(
ngraph
::
op
::
Parameter
),
&
runtime
::
cpu
::
Builder
::
nop
},
{
TI
(
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
),
{
TI
(
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
),
&
runtime
::
cpu
::
Builder
::
build
<
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
>
}};
&
runtime
::
cpu
::
Builder
::
build
<
ngraph
::
runtime
::
cpu
::
op
::
ConvertLayout
>
},
{
TI
(
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
),
&
runtime
::
cpu
::
Builder
::
build
<
ngraph
::
runtime
::
cpu
::
op
::
HalideOp
>
}};
return
build_dispatcher
;
return
build_dispatcher
;
}
}
...
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
8f102516
...
@@ -170,6 +170,7 @@
...
@@ -170,6 +170,7 @@
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
#ifdef NGRAPH_DISTRIBUTED
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/allreduce.hpp"
...
@@ -1023,6 +1024,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
...
@@ -1023,6 +1024,10 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUFusion
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUFusion
>
();
// pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
// pass_manager.register_pass<runtime::cpu::pass::CPUHorizontalFusion>();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUCollapseDims
>
();
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUCollapseDims
>
();
#if defined(NGRAPH_HALIDE)
pass_manager
.
register_pass
<
ngraph
::
runtime
::
cpu
::
pass
::
HalideSubgraphExtraction
>
();
#endif
NodeVector
nv_cwi
;
// We dont need CPUWorkspaceInsertion to return list of indices
NodeVector
nv_cwi
;
// We dont need CPUWorkspaceInsertion to return list of indices
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUWorkspaceInsertion
>
(
nv_cwi
,
false
);
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUWorkspaceInsertion
>
(
nv_cwi
,
false
);
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUAssignment
>
(
this
);
pass_manager
.
register_pass
<
runtime
::
cpu
::
pass
::
CPUAssignment
>
(
this
);
...
...
src/ngraph/runtime/cpu/cpu_external_function.hpp
View file @
8f102516
...
@@ -27,6 +27,10 @@
...
@@ -27,6 +27,10 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#if defined(NGRAPH_HALIDE)
#include <Halide.h>
#endif
#if !defined(NGRAPH_DEX_ONLY)
#if !defined(NGRAPH_DEX_ONLY)
#include "ngraph/codegen/code_writer.hpp"
#include "ngraph/codegen/code_writer.hpp"
...
@@ -135,6 +139,26 @@ namespace ngraph
...
@@ -135,6 +139,26 @@ namespace ngraph
const
std
::
string
&
directory
,
const
std
::
string
&
directory
,
const
std
::
string
&
filename
);
const
std
::
string
&
filename
);
#if defined(NGRAPH_HALIDE)
std
::
unordered_map
<
std
::
string
,
Halide
::
Func
>&
get_halide_functions
()
{
return
halide_functions
;
}
std
::
unordered_map
<
std
::
string
,
Halide
::
ImageParam
>&
get_subgraph_params
()
{
return
subgraph_params
;
}
std
::
unordered_map
<
std
::
string
,
int
>&
get_subgraph_param_sizes
()
{
return
subgraph_param_sizes
;
}
std
::
unordered_map
<
std
::
string
,
std
::
reference_wrapper
<
void
*>>&
get_subgraph_param_ptrs
()
{
return
subgraph_param_ptrs
;
}
#endif
protected
:
protected
:
void
build
();
void
build
();
...
@@ -240,6 +264,13 @@ namespace ngraph
...
@@ -240,6 +264,13 @@ namespace ngraph
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
CPU_ExternalFunction
>>
callees
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
CPU_ExternalFunction
>>
callees
;
bool
m_is_built
;
bool
m_is_built
;
bool
m_direct_execution
;
bool
m_direct_execution
;
#if defined(NGRAPH_HALIDE)
std
::
unordered_map
<
std
::
string
,
Halide
::
Func
>
halide_functions
;
std
::
unordered_map
<
std
::
string
,
Halide
::
ImageParam
>
subgraph_params
;
std
::
unordered_map
<
std
::
string
,
int
>
subgraph_param_sizes
;
std
::
unordered_map
<
std
::
string
,
std
::
reference_wrapper
<
void
*>>
subgraph_param_ptrs
;
#endif
};
};
}
}
}
}
...
...
src/ngraph/runtime/cpu/op/halide_op.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/cpu/op/halide_op.hpp"
using
namespace
std
;
using
namespace
ngraph
;
shared_ptr
<
Node
>
runtime
::
cpu
::
op
::
HalideOp
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
return
make_shared
<
HalideOp
>
(
new_args
,
ops
,
output_type
,
output_shape
);
}
runtime
::
cpu
::
op
::
HalideOp
::
HalideOp
(
const
NodeVector
&
args
,
const
std
::
list
<
std
::
shared_ptr
<
Node
>>&
ops
,
const
element
::
Type
&
out_type
,
const
Shape
&
out_shape
)
:
Op
(
"HalideOp"
,
check_single_output_args
(
args
))
,
ops
(
ops
)
,
output_type
(
out_type
)
,
output_shape
(
out_shape
)
{
constructor_validate_and_infer_types
();
}
void
runtime
::
cpu
::
op
::
HalideOp
::
validate_and_infer_types
()
{
set_output_type
(
0
,
output_type
,
output_shape
);
}
src/ngraph/runtime/cpu/op/halide_op.hpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <list>
#include <vector>
#include "ngraph/op/op.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
cpu
{
namespace
op
{
class
HalideOp
:
public
ngraph
::
op
::
Op
{
public
:
HalideOp
(
const
NodeVector
&
args
,
const
std
::
list
<
std
::
shared_ptr
<
Node
>>&
ops
,
const
element
::
Type
&
out_type
,
const
Shape
&
out_shape
);
virtual
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
const
std
::
list
<
std
::
shared_ptr
<
Node
>>&
get_ops
()
const
{
return
ops
;
}
private
:
std
::
list
<
std
::
shared_ptr
<
Node
>>
ops
;
element
::
Type
output_type
;
Shape
output_shape
;
};
}
}
}
}
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <iostream>
#include <list>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include "ngraph/op/add.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/runtime/cpu/op/halide_op.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
using
namespace
std
;
using
namespace
ngraph
;
#define TI(x) type_index(typeid(x))
namespace
ngraph
{
namespace
runtime
{
namespace
cpu
{
namespace
halide
{
static
const
std
::
unordered_set
<
std
::
type_index
>
whitelist
{
TI
(
ngraph
::
op
::
Add
),
TI
(
ngraph
::
op
::
Multiply
),
TI
(
ngraph
::
op
::
Relu
)};
static
const
std
::
unordered_set
<
std
::
type_index
>
skiplist
{
TI
(
ngraph
::
op
::
Parameter
),
TI
(
ngraph
::
op
::
Result
)};
}
}
}
}
// Support for multiple results, multiple outputs and getoutputelement, and multiple subgraphs in a single
// pipeline is not implemented since this should go away in favor of the "hybrid" transformer approach of
// carving out subgraphs in core ngraph
bool
runtime
::
cpu
::
pass
::
HalideSubgraphExtraction
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
function
)
{
list
<
shared_ptr
<
Node
>>
worklist
;
auto
results
=
function
->
get_results
();
// Artificial limitation
if
(
results
.
size
()
>
1
)
{
return
false
;
}
if
(
function
->
get_result
()
->
get_element_type
()
!=
element
::
f32
)
{
return
false
;
}
for
(
const
auto
&
result
:
results
)
{
worklist
.
emplace_back
(
result
);
}
unordered_set
<
shared_ptr
<
Node
>>
ops
;
list
<
shared_ptr
<
Node
>>
ordered_ops
;
while
(
!
worklist
.
empty
())
{
const
auto
&
node
=
worklist
.
front
();
if
(
!
halide
::
skiplist
.
count
(
TI
(
*
node
)))
{
if
(
halide
::
whitelist
.
count
(
TI
(
*
node
)))
{
ops
.
emplace
(
node
);
ordered_ops
.
emplace_back
(
node
);
}
else
{
break
;
}
}
const
auto
&
args
=
node
->
get_arguments
();
for
(
const
auto
&
arg
:
args
)
{
worklist
.
emplace_back
(
arg
);
}
worklist
.
pop_front
();
}
NodeVector
liveins
;
for
(
const
auto
&
op
:
ops
)
{
const
auto
&
args
=
op
->
get_arguments
();
for
(
const
auto
&
arg
:
args
)
{
if
(
!
ops
.
count
(
arg
))
{
liveins
.
emplace_back
(
arg
);
}
}
}
ordered_ops
.
reverse
();
auto
subgraph
=
make_shared
<
cpu
::
op
::
HalideOp
>
(
liveins
,
ordered_ops
,
function
->
get_result
()
->
get_element_type
(),
function
->
get_result
()
->
get_shape
());
replace_node
(
function
->
get_result
()
->
get_argument
(
0
),
subgraph
);
return
true
;
}
src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
cpu
{
namespace
pass
{
class
HalideSubgraphExtraction
:
public
ngraph
::
pass
::
FunctionPass
{
public
:
HalideSubgraphExtraction
()
{}
bool
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
function
)
override
;
};
}
}
}
}
test/CMakeLists.txt
View file @
8f102516
...
@@ -68,6 +68,9 @@ endif()
...
@@ -68,6 +68,9 @@ endif()
if
(
NGRAPH_CPU_ENABLE
)
if
(
NGRAPH_CPU_ENABLE
)
list
(
APPEND SRC core_fusion.cpp quantize_cpu.cpp
)
list
(
APPEND SRC core_fusion.cpp quantize_cpu.cpp
)
list
(
APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp
)
list
(
APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp
)
if
(
NGRAPH_HALIDE
)
list
(
APPEND SRC halide.cpp
)
endif
()
set
(
ACTIVE_BACKEND_LIST
${
ACTIVE_BACKEND_LIST
}
CPU
)
set
(
ACTIVE_BACKEND_LIST
${
ACTIVE_BACKEND_LIST
}
CPU
)
endif
()
endif
()
...
...
test/halide.cpp
0 → 100644
View file @
8f102516
//*****************************************************************************
// Copyright 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/test_tools.hpp"
using
namespace
ngraph
;
using
namespace
std
;
TEST
(
halide
,
halide_subgraph
)
{
Shape
shape
{
8
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
C
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
D
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
relu
=
make_shared
<
op
::
Relu
>
((
A
+
B
)
*
C
);
auto
f
=
make_shared
<
Function
>
(
relu
+
D
,
op
::
ParameterVector
{
A
,
B
,
C
,
D
});
auto
backend
=
runtime
::
Backend
::
create
(
"CPU"
);
shared_ptr
<
runtime
::
Tensor
>
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
shared_ptr
<
runtime
::
Tensor
>
b
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
shared_ptr
<
runtime
::
Tensor
>
c
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
shared_ptr
<
runtime
::
Tensor
>
d
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
shared_ptr
<
runtime
::
Tensor
>
result
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
vector
<
float
>
data
{
-
1
,
4
,
-
2
,
5
,
1
,
5
,
7
,
9
};
copy_data
(
a
,
data
);
copy_data
(
b
,
data
);
copy_data
(
c
,
data
);
copy_data
(
d
,
data
);
vector
<
float
>
expected
{
1
,
36
,
6
,
55
,
3
,
55
,
105
,
171
};
backend
->
call_with_validate
(
f
,
{
result
},
{
a
,
b
,
c
,
d
});
EXPECT_TRUE
(
test
::
all_close
(
read_vector
<
float
>
(
result
),
expected
,
1.0e-4
f
,
1.0e-4
f
));
}
test/type_prop.cpp
View file @
8f102516
...
@@ -25,12 +25,8 @@ using namespace ngraph;
...
@@ -25,12 +25,8 @@ using namespace ngraph;
#define EXPECT_HAS_SUBSTRING(haystack, needle) \
#define EXPECT_HAS_SUBSTRING(haystack, needle) \
EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
//
// Tests for broadcast.
//
TEST
(
type_prop
,
broadcast_deduce
)
TEST
(
type_prop
,
broadcast_deduce
)
{
{
// Deduce type
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
});
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
});
Shape
bc_shape
{
2
,
3
,
4
};
Shape
bc_shape
{
2
,
3
,
4
};
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
...
@@ -38,6 +34,175 @@ TEST(type_prop, broadcast_deduce)
...
@@ -38,6 +34,175 @@ TEST(type_prop, broadcast_deduce)
ASSERT_EQ
(
bc
->
get_shape
(),
bc_shape
);
ASSERT_EQ
(
bc
->
get_shape
(),
bc_shape
);
}
}
TEST
(
type_prop
,
broadcast_axes_oob
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
});
auto
bc_shape
=
Shape
{
2
,
3
,
4
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
,
3
});
FAIL
()
<<
"Broadcast axis out of bounds not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast axis index (3) exceeds specified output shape rank"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
broadcast_shape_mismatch_wrong_rank
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
});
auto
bc_shape
=
Shape
{
2
,
3
,
4
,
5
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
FAIL
()
<<
"Output shape mismatch (wrong rank) not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast argument shape, specified output shape, and axes are incompatible"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
broadcast_shape_mismatch_wrong_size
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
});
auto
bc_shape
=
Shape
{
2
,
3
,
5
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
FAIL
()
<<
"Output shape mismatch (wrong size) not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast argument shape, specified output shape, and axes are incompatible"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
broadcast_partial_rank_dynamic_ok
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
Shape
bc_shape
{
2
,
3
,
4
};
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
ASSERT_EQ
(
bc
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
bc
->
get_shape
(),
bc_shape
);
}
TEST
(
type_prop
,
broadcast_partial_rank_dynamic_axes_oob
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
bc_shape
=
Shape
{
2
,
3
,
4
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
,
3
});
FAIL
()
<<
"Broadcast axis out of bounds not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast axis index (3) exceeds specified output shape rank"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
broadcast_partial_rank_static_dynamic_ok
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
4
});
Shape
bc_shape
{
2
,
3
,
4
};
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
ASSERT_EQ
(
bc
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
bc
->
get_shape
(),
bc_shape
);
}
TEST
(
type_prop
,
broadcast_partial_rank_static_dynamic_axes_oob
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
4
});
auto
bc_shape
=
Shape
{
2
,
3
,
4
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
,
3
});
FAIL
()
<<
"Broadcast axis out of bounds not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast axis index (3) exceeds specified output shape rank"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_rank
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
4
});
auto
bc_shape
=
Shape
{
2
,
3
,
4
,
5
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
FAIL
()
<<
"Output shape mismatch (wrong rank) not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast argument shape, specified output shape, and axes are incompatible"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
broadcast_partial_rank_static_dynamic_shape_mismatch_wrong_size
)
{
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
4
});
auto
bc_shape
=
Shape
{
2
,
3
,
5
};
try
{
auto
bc
=
make_shared
<
op
::
Broadcast
>
(
param
,
bc_shape
,
AxisSet
{
1
});
FAIL
()
<<
"Output shape mismatch (wrong size) not detected"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
"Broadcast argument shape, specified output shape, and axes are incompatible"
);
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
batchnorm_rank_less_than_2
)
TEST
(
type_prop
,
batchnorm_rank_less_than_2
)
{
{
auto
dummy
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
});
auto
dummy
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
});
...
@@ -949,7 +1114,7 @@ TEST(type_prop, select_shape_mismatch_a)
...
@@ -949,7 +1114,7 @@ TEST(type_prop, select_shape_mismatch_a)
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument
s do not all have the same shape
"
));
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument
shapes are inconsistent
"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -970,7 +1135,7 @@ TEST(type_prop, select_shape_mismatch_b)
...
@@ -970,7 +1135,7 @@ TEST(type_prop, select_shape_mismatch_b)
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument
s do not all have the same shape
"
));
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument
shapes are inconsistent
"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -991,7 +1156,7 @@ TEST(type_prop, select_shape_mismatch_c)
...
@@ -991,7 +1156,7 @@ TEST(type_prop, select_shape_mismatch_c)
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument
s do not all have the same shape
"
));
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument
shapes are inconsistent
"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1035,7 +1200,160 @@ TEST(type_prop, select_elem_mismatch_bc)
...
@@ -1035,7 +1200,160 @@ TEST(type_prop, select_elem_mismatch_bc)
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Arguments 1 and 2 do not have the same element type"
));
std
::
string
(
"Argument 1 and 2 element types are inconsistent"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
select_partial_all_rank_dynamic
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
rank
().
is_dynamic
());
}
TEST
(
type_prop
,
select_partial_all_rank_dynamic_arg0_et_dynamic_arg1_arg2_et_mismatch
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
PartialShape
::
dynamic
());
try
{
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
FAIL
()
<<
"Did not detect mismatched element types for args 1 and 2 (element type-dynamic "
"arg0)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument 1 and 2 element types are inconsistent"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
select_partial_all_rank_dynamic_arg0_arg1_et_dynamic
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
rank
().
is_dynamic
());
}
TEST
(
type_prop
,
select_partial_all_rank_dynamic_arg0_arg2_et_dynamic
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
rank
().
is_dynamic
());
}
TEST
(
type_prop
,
select_partial_all_rank_dynamic_arg0_arg1_arg2_et_dynamic
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
dynamic
,
PartialShape
::
dynamic
());
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
dynamic
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
rank
().
is_dynamic
());
}
TEST
(
type_prop
,
select_partial_arg0_rank_dynamic_static_arg1_arg2_rank_dynamic_ok
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
PartialShape
{
2
,
Dimension
::
dynamic
(),
3
});
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
2
,
Dimension
::
dynamic
(),
3
}));
}
TEST
(
type_prop
,
select_partial_arg1_rank_dynamic_static_arg0_arg2_rank_dynamic_ok
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
2
,
Dimension
::
dynamic
(),
3
});
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
2
,
Dimension
::
dynamic
(),
3
}));
}
TEST
(
type_prop
,
select_partial_arg2_rank_dynamic_static_arg0_arg1_rank_dynamic_ok
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
PartialShape
::
dynamic
());
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
::
dynamic
());
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
2
,
Dimension
::
dynamic
(),
3
});
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
2
,
Dimension
::
dynamic
(),
3
}));
}
TEST
(
type_prop
,
select_partial_all_rank_static_dynamic_ok
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
PartialShape
{
2
,
Dimension
::
dynamic
(),
Dimension
::
dynamic
()});
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
8
,
Dimension
::
dynamic
()});
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
3
});
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
ASSERT_EQ
(
sel
->
get_output_element_type
(
0
),
element
::
f32
);
ASSERT_TRUE
(
sel
->
get_output_partial_shape
(
0
).
is_static
());
ASSERT_EQ
(
sel
->
get_output_shape
(
0
),
(
Shape
{
2
,
8
,
3
}));
}
TEST
(
type_prop
,
select_partial_all_rank_static_intransitive_incompatibility
)
{
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
PartialShape
{
2
,
Dimension
::
dynamic
(),
Dimension
::
dynamic
()});
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
Dimension
::
dynamic
(),
8
,
Dimension
::
dynamic
()});
auto
param2
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
3
,
Dimension
::
dynamic
(),
3
});
try
{
auto
sel
=
make_shared
<
op
::
Select
>
(
param0
,
param1
,
param2
);
FAIL
()
<<
"Did not detect intransitive partial-shape incompatibility"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument shapes are inconsistent"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1654,7 +1972,9 @@ TEST(type_prop, slice_deduce_vector_invalid_strides)
...
@@ -1654,7 +1972,9 @@ TEST(type_prop, slice_deduce_vector_invalid_strides)
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Rank of strides (2) does not match rank of argument (1)"
));
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{7}) and strides (Strides{1, 2}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1757,7 +2077,8 @@ TEST(type_prop, slice_deduce_matrix_lower_missing)
...
@@ -1757,7 +2077,8 @@ TEST(type_prop, slice_deduce_matrix_lower_missing)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Rank of lower bounds (1) does not match rank of argument (2)"
));
std
::
string
(
"Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{5, 5}) and strides (Strides{1}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1778,7 +2099,8 @@ TEST(type_prop, slice_deduce_matrix_upper_missing)
...
@@ -1778,7 +2099,8 @@ TEST(type_prop, slice_deduce_matrix_upper_missing)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Rank of upper bounds (1) does not match rank of argument (2)"
));
std
::
string
(
"Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
"(Coordinate{5}) and strides (Strides{1, 1}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1797,9 +2119,10 @@ TEST(type_prop, slice_deduce_matrix_lower_extra)
...
@@ -1797,9 +2119,10 @@ TEST(type_prop, slice_deduce_matrix_lower_extra)
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{0, 0, "
std
::
string
(
"Rank of lower bounds (3) does not match rank of argument (2)"
));
"0}), upper bounds (Coordinate{5, 5}) and "
"strides (Strides{1, 1, 1}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1817,10 +2140,169 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
...
@@ -1817,10 +2140,169 @@ TEST(type_prop, slice_deduce_matrix_upper_extra)
FAIL
()
<<
"Extra upper bound coordinate not detected"
;
FAIL
()
<<
"Extra upper bound coordinate not detected"
;
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
slice_partial_arg_input_rank_dynamic_attribs_ok
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
sl
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
sl
->
get_shape
(),
(
Shape
{
0
,
1
,
2
,
2
}));
}
TEST
(
type_prop
,
slice_partial_arg_rank_dynamic_attribs_rank_mismatch
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
try
{
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
slice_partial_arg_rank_dynamic_attribs_bounds_crossing
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
8
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
try
{
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Crossing lower/upper bounds not detected (argument rank-dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Rank of upper bounds (3) does not match rank of argument (2)"
));
std
::
string
(
"Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
slice_partial_arg_rank_static_dynamic_ok
)
{
PartialShape
input_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
sl
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
sl
->
get_shape
(),
(
Shape
{
0
,
1
,
2
,
2
}));
}
TEST
(
type_prop
,
slice_partial_arg_rank_static_dynamic_some_dims_known_ok
)
{
PartialShape
input_shape
{
2
,
4
,
10
,
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
sl
->
get_element_type
(),
element
::
f32
);
ASSERT_EQ
(
sl
->
get_shape
(),
(
Shape
{
0
,
1
,
2
,
2
}));
}
TEST
(
type_prop
,
slice_partial_arg_rank_static_dynamic_attribs_rank_mismatches_arg
)
{
PartialShape
input_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
try
{
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Input rank does not match the "
"rank of the lower bounds (Coordinate{1, 2, "
"3, 4}), upper bounds (Coordinate{1, 3, 5, "
"7}), and strides (Strides{1, 1, 1, 2})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
slice_partial_arg_rank_static_dynamic_some_dims_known_upper_bounds_oob
)
{
PartialShape
input_shape
{
2
,
2
,
10
,
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
try
{
auto
sl
=
make_shared
<
op
::
Slice
>
(
param
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Upper bounds out of bounds not detected (argument rank-static dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -1964,7 +2446,9 @@ TEST(type_prop, replace_slice_deduce_vector_invalid_strides)
...
@@ -1964,7 +2446,9 @@ TEST(type_prop, replace_slice_deduce_vector_invalid_strides)
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Rank of strides (2) does not match rank of argument (1)"
));
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{7}) and strides (Strides{1, 2}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -2027,9 +2511,10 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch)
...
@@ -2027,9 +2511,10 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch)
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
EXPECT_HAS_SUBSTRING
(
std
::
string
(
"Shape of replacement tensor (Shape{3, 6}) does not match "
error
.
what
(),
"the slice shape (Shape{4, 6})"
));
std
::
string
(
"Shape of replacement tensor ({3,6}) does not match the slice shape ({4,6})"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -2053,7 +2538,7 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch_strided)
...
@@ -2053,7 +2538,7 @@ TEST(type_prop, replace_slice_deduce_matrix_slice_shape_mismatch_strided)
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
std
::
string
(
"Shape of replacement tensor (
Shape{4, 6}) does not match the slice shape
"
));
"Shape of replacement tensor (
{4,6}) does not match the slice shape ({4,3})
"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -2163,7 +2648,8 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_missing)
...
@@ -2163,7 +2648,8 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_missing)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Rank of lower bounds (1) does not match rank of argument (2)"
));
std
::
string
(
"Ranks of lower bounds (Coordinate{0}), upper bounds "
"(Coordinate{5, 5}) and strides (Strides{1}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -2185,7 +2671,8 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_missing)
...
@@ -2185,7 +2671,8 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_missing)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Rank of upper bounds (1) does not match rank of argument (2)"
));
std
::
string
(
"Ranks of lower bounds (Coordinate{0, 0}), upper bounds "
"(Coordinate{5}) and strides (Strides{1, 1}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -2206,9 +2693,10 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_extra)
...
@@ -2206,9 +2693,10 @@ TEST(type_prop, replace_slice_deduce_matrix_lower_extra)
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{0, 0, "
std
::
string
(
"Rank of lower bounds (3) does not match rank of argument (2)"
));
"0}), upper bounds (Coordinate{5, 5}) and "
"strides (Strides{1, 1, 1}) do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
@@ -2228,10 +2716,337 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra)
...
@@ -2228,10 +2716,337 @@ TEST(type_prop, replace_slice_deduce_matrix_upper_extra)
FAIL
()
<<
"Extra upper bound coordinate not detected"
;
FAIL
()
<<
"Extra upper bound coordinate not detected"
;
}
}
catch
(
const
NodeValidationError
&
error
)
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{0, 0}), "
"upper bounds (Coordinate{5, 5, 5}) and "
"strides (Strides{1, 1}) do not match"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_ok
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
rsl
->
get_element_type
(),
element
::
f32
);
ASSERT_TRUE
(
rsl
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()}));
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_rank_mismatch
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatch of lower-bounds/upper-bounds/strides ranks not detected (argument "
"rank-dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Ranks of lower bounds (Coordinate{1, 2, 3, 4}), upper bounds "
"(Coordinate{1, 3, 5}) and strides (Strides{1, 1, 1, 2}) do not match"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_dynamic_attribs_bounds_crossing
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
8
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Crossing lower/upper bounds not detected (argument rank-dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
{
EXPECT_HAS_SUBSTRING
(
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
error
.
what
(),
std
::
string
(
"Rank of upper bounds (3) does not match rank of argument (2)"
));
std
::
string
(
"Lower bound for slice is greater than upper bound at axis 3 (lower "
"bounds: Coordinate{1, 2, 3, 8}, upper bounds: Coordinate{1, 3, 5, 7})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_ok
)
{
PartialShape
input_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
rsl
->
get_element_type
(),
element
::
f32
);
ASSERT_TRUE
(
rsl
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()}));
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_ok
)
{
PartialShape
input_shape
{
2
,
4
,
10
,
Dimension
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
rsl
->
get_element_type
(),
element
::
f32
);
ASSERT_TRUE
(
rsl
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
2
,
4
,
10
,
Dimension
::
dynamic
()}));
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_dynamic_attribs_rank_mismatches_input
)
{
PartialShape
input_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatch of attrib ranks with arg ranks not detected (argument rank-static "
"dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_static_dynamic_some_dims_known_replacement_rank_dynamic_upper_bounds_oob
)
{
PartialShape
input_shape
{
2
,
2
,
10
,
Dimension
::
dynamic
()};
PartialShape
replacement_shape
{
PartialShape
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Upper bounds out of bounds not detected (argument rank-static dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Upper bound for slice at axis 1 is out of "
"range (upper bounds: Coordinate{1, 3, 5, "
"7}, argument shape: {2,2,10,?})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_ok
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
rsl
->
get_element_type
(),
element
::
f32
);
ASSERT_TRUE
(
rsl
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()}));
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_ok
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
0
,
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
2
};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
ASSERT_EQ
(
rsl
->
get_element_type
(),
element
::
f32
);
ASSERT_TRUE
(
rsl
->
get_output_partial_shape
(
0
).
same_scheme
(
PartialShape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()}));
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_some_dims_known_attribs_mismatch_replacement_shape
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
1
,
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
2
};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatch of shape inferred from attributes with provided replacement shape not "
"detected (rank-dynamic/rank-static dynamic inputs)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Shape of replacement tensor ({1,?,?,2}) does not match "
"the slice shape ({0,1,2,2})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_dynamic_replacement_rank_static_dynamic_attribs_rank_mismatches_replacement
)
{
PartialShape
input_shape
{
PartialShape
::
dynamic
()};
PartialShape
replacement_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatch of attrib ranks with arg ranks not detected (arguments "
"rank-dynamic/rank-static "
"dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument ranks do not match the rank of the lower bounds "
"(Coordinate{1, 2, 3, 4}), upper bounds (Coordinate{1, 3, "
"5, 7}), and strides (Strides{1, 1, 1, 2})"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
TEST
(
type_prop
,
replace_slice_partial_input_rank_static_dynamic_replacement_rank_static_dynamic_argument_ranks_mismatch
)
{
PartialShape
input_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
PartialShape
replacement_shape
{
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
(),
Dimension
::
dynamic
()};
Coordinate
lower_bounds
{
1
,
2
,
3
,
4
};
Coordinate
upper_bounds
{
1
,
3
,
5
,
7
};
Strides
strides
{
1
,
1
,
1
,
2
};
auto
param0
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
input_shape
);
auto
param1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
replacement_shape
);
try
{
auto
rsl
=
make_shared
<
op
::
ReplaceSlice
>
(
param0
,
param1
,
lower_bounds
,
upper_bounds
,
strides
);
// Should have thrown, so fail if it didn't
FAIL
()
<<
"Mismatching input/replacement ranks not detected (arguments both rank-static "
"dynamic)"
;
}
catch
(
const
NodeValidationError
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"Argument ranks do not match"
));
}
}
catch
(...)
catch
(...)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment