Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
90c70dde
Unverified
Commit
90c70dde
authored
Nov 18, 2019
by
Scott Cyphers
Committed by
GitHub
Nov 18, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
CropAndResize op (#3893)
* Stub for CropAndResize * Cut and pasteo * Need a cast
parent
1ac3e5c7
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
391 additions
and
0 deletions
+391
-0
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-0
ngraph.hpp
src/ngraph/ngraph.hpp
+1
-0
node.hpp
src/ngraph/node.hpp
+9
-0
crop_and_resize.cpp
src/ngraph/op/crop_and_resize.cpp
+198
-0
crop_and_resize.hpp
src/ngraph/op/crop_and_resize.hpp
+76
-0
op_tbl.hpp
src/ngraph/op/op_tbl.hpp
+1
-0
int_executable.hpp
src/ngraph/runtime/interpreter/int_executable.hpp
+5
-0
serializer.cpp
src/ngraph/serializer.cpp
+17
-0
CMakeLists.txt
test/CMakeLists.txt
+1
-0
crop_and_resize.cpp
test/type_prop/crop_and_resize.cpp
+81
-0
No files found.
src/ngraph/CMakeLists.txt
View file @
90c70dde
...
@@ -128,6 +128,8 @@ set (SRC
...
@@ -128,6 +128,8 @@ set (SRC
op/cos.hpp
op/cos.hpp
op/cosh.cpp
op/cosh.cpp
op/cosh.hpp
op/cosh.hpp
op/crop_and_resize.cpp
op/crop_and_resize.hpp
op/dequantize.cpp
op/dequantize.cpp
op/dequantize.hpp
op/dequantize.hpp
op/divide.cpp
op/divide.cpp
...
...
src/ngraph/ngraph.hpp
View file @
90c70dde
...
@@ -81,6 +81,7 @@
...
@@ -81,6 +81,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/dot.hpp"
...
...
src/ngraph/node.hpp
View file @
90c70dde
...
@@ -410,6 +410,9 @@ namespace ngraph
...
@@ -410,6 +410,9 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input
<
Node
>
input
(
size_t
input_index
);
Input
<
Node
>
input
(
size_t
input_index
);
// Simplify migration from 0.25.1
Output
<
Node
>
input_value
(
size_t
input_index
)
const
;
/// \return A handle to the `input_index`th input of this node.
/// \return A handle to the `input_index`th input of this node.
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input
<
const
Node
>
input
(
size_t
input_index
)
const
;
Input
<
const
Node
>
input
(
size_t
input_index
)
const
;
...
@@ -650,6 +653,12 @@ namespace ngraph
...
@@ -650,6 +653,12 @@ namespace ngraph
return
Input
<
const
Node
>
(
this
,
input_index
);
return
Input
<
const
Node
>
(
this
,
input_index
);
}
}
// Simplify migration from 0.25.1
inline
Output
<
Node
>
Node
::
input_value
(
size_t
input_index
)
const
{
return
input
(
input_index
).
get_source_output
();
}
inline
Output
<
Node
>
Node
::
output
(
size_t
output_index
)
inline
Output
<
Node
>
Node
::
output
(
size_t
output_index
)
{
{
if
(
output_index
>=
m_outputs
.
size
())
if
(
output_index
>=
m_outputs
.
size
())
...
...
src/ngraph/op/crop_and_resize.cpp
0 → 100644
View file @
90c70dde
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/crop_and_resize.hpp"
using
namespace
std
;
using
namespace
ngraph
;
const
string
op
::
CropAndResize
::
type_name
{
"CropAndResize"
};
op
::
CropAndResize
::
CropAndResize
(
const
Output
<
Node
>&
image
,
const
Output
<
Node
>&
boxes
,
const
Output
<
Node
>&
box_indices
,
const
Output
<
Node
>&
crop_size
,
ResizeMethod
resize_method
,
float
extrapolation_value
)
:
Op
({
image
,
boxes
,
box_indices
,
crop_size
})
,
m_resize_method
(
resize_method
)
,
m_extrapolation_value
(
extrapolation_value
)
{
constructor_validate_and_infer_types
();
}
void
op
::
CropAndResize
::
validate_and_infer_types
()
{
NODE_VALIDATION_CHECK
(
this
,
get_input_size
()
==
4
);
NODE_VALIDATION_CHECK
(
this
,
m_resize_method
!=
ResizeMethod
::
unspecified
,
"Resize method not specified"
);
auto
image
=
input_value
(
0
);
auto
&
image_et
=
image
.
get_element_type
();
// Will override if we can determine the shape
set_output_type
(
0
,
image_et
,
{});
auto
image_shape
=
image
.
get_partial_shape
();
Dimension
image_depth
;
if
(
image_shape
.
is_static
())
{
NODE_VALIDATION_CHECK
(
this
,
static_cast
<
int64_t
>
(
image_shape
.
rank
())
==
4
,
"Image must be NHWC"
);
image_depth
=
image_shape
[
3
];
}
auto
boxes
=
input_value
(
1
);
auto
boxes_shape
=
boxes
.
get_partial_shape
();
if
(
boxes_shape
.
is_static
())
{
auto
boxes_rank
=
boxes_shape
.
rank
();
NODE_VALIDATION_CHECK
(
this
,
static_cast
<
int64_t
>
(
boxes_rank
)
==
2
,
"Boxes must be 2d"
);
auto
boxes_dim1
=
boxes_shape
[
1
];
NODE_VALIDATION_CHECK
(
this
,
static_cast
<
int64_t
>
(
boxes_dim1
)
==
4
,
"Second boxes dimension must be 4"
);
}
NODE_VALIDATION_CHECK
(
this
,
boxes
.
get_element_type
().
is_real
(),
"Boxes must be real values in [0, 1]"
);
auto
box_indices
=
input_value
(
2
);
auto
box_indices_shape
=
box_indices
.
get_partial_shape
();
Dimension
num_boxes
;
if
(
box_indices_shape
.
is_static
())
{
NODE_VALIDATION_CHECK
(
this
,
static_cast
<
int64_t
>
(
box_indices_shape
.
rank
())
==
1
,
"Box indices must have rank 1"
);
num_boxes
=
box_indices_shape
[
0
];
}
NODE_VALIDATION_CHECK
(
this
,
box_indices
.
get_element_type
().
is_integral
(),
"Box indices must be integers"
);
auto
crop_size
=
input_value
(
3
);
auto
crop_size_shape
=
crop_size
.
get_partial_shape
();
auto
crop_size_rank
=
crop_size_shape
.
rank
();
NODE_VALIDATION_CHECK
(
this
,
crop_size_shape
.
is_static
()
||
crop_size_rank
.
is_dynamic
(),
"Dynamic crop_size not supported"
);
NODE_VALIDATION_CHECK
(
this
,
static_cast
<
int64_t
>
(
crop_size_rank
)
==
1
,
"crop_size must be a vector"
);
NODE_VALIDATION_CHECK
(
this
,
static_cast
<
int64_t
>
(
crop_size_shape
[
0
])
==
2
,
"crop_size must be a vector of length 2"
);
auto
&
crop_size_et
=
crop_size
.
get_element_type
();
NODE_VALIDATION_CHECK
(
this
,
crop_size_et
.
is_integral
(),
"crops_size must be integral"
);
auto
crop_size_node
=
crop_size
.
get_node_shared_ptr
();
NODE_VALIDATION_CHECK
(
this
,
crop_size_node
->
is_constant
(),
"crop_size must be a constant"
);
auto
crop_size_const
=
static_pointer_cast
<
op
::
Constant
>
(
crop_size_node
);
if
(
crop_size_et
==
element
::
i8
)
{
auto
v
=
crop_size_const
->
get_vector
<
int8_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
u8
)
{
auto
v
=
crop_size_const
->
get_vector
<
uint8_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
i16
)
{
auto
v
=
crop_size_const
->
get_vector
<
int16_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
u16
)
{
auto
v
=
crop_size_const
->
get_vector
<
uint16_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
i32
)
{
auto
v
=
crop_size_const
->
get_vector
<
int32_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
u32
)
{
auto
v
=
crop_size_const
->
get_vector
<
uint32_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
i64
)
{
auto
v
=
crop_size_const
->
get_vector
<
int64_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
v
[
0
],
v
[
1
],
image_depth
});
}
else
if
(
crop_size_et
==
element
::
u64
)
{
auto
v
=
crop_size_const
->
get_vector
<
uint64_t
>
();
set_output_type
(
0
,
image_et
,
{
num_boxes
,
static_cast
<
int64_t
>
(
v
[
0
]),
static_cast
<
int64_t
>
(
v
[
1
]),
image_depth
});
}
else
{
NODE_VALIDATION_CHECK
(
this
,
false
,
"Unknown integral type for crop size"
);
}
}
shared_ptr
<
Node
>
op
::
CropAndResize
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
CropAndResize
>
(
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
),
m_resize_method
,
m_extrapolation_value
);
}
static
const
vector
<
pair
<
string
,
op
::
CropAndResize
::
ResizeMethod
>>&
get_resize_pairs
()
{
static
vector
<
pair
<
string
,
op
::
CropAndResize
::
ResizeMethod
>>
pairs
{
{
"unspecified"
,
op
::
CropAndResize
::
ResizeMethod
::
unspecified
},
{
"bilinear"
,
op
::
CropAndResize
::
ResizeMethod
::
bilinear
},
{
"nearest"
,
op
::
CropAndResize
::
ResizeMethod
::
nearest
}};
return
pairs
;
}
const
string
&
ngraph
::
as_string
(
op
::
CropAndResize
::
ResizeMethod
resize_method
)
{
for
(
auto
&
p
:
get_resize_pairs
())
{
if
(
p
.
second
==
resize_method
)
{
return
p
.
first
;
}
}
throw
ngraph_error
(
"Internal error: unhandled resize method"
);
}
namespace
ngraph
{
template
<>
op
::
CropAndResize
::
ResizeMethod
as_type
<
op
::
CropAndResize
::
ResizeMethod
>
(
const
std
::
string
&
s
)
{
for
(
auto
&
p
:
get_resize_pairs
())
{
if
(
p
.
first
==
s
)
{
return
p
.
second
;
}
}
throw
ngraph_error
(
"Internal error: unhandled resize method name"
);
}
}
src/ngraph/op/crop_and_resize.hpp
0 → 100644
View file @
90c70dde
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace
ngraph
{
namespace
op
{
class
CropAndResize
:
public
Op
{
public
:
enum
class
ResizeMethod
{
unspecified
,
bilinear
,
nearest
};
NGRAPH_API
static
const
std
::
string
type_name
;
const
std
::
string
&
description
()
const
override
{
return
type_name
;
}
/// \brief Constructs a crop and resize operation.
CropAndResize
()
=
default
;
/// \param image [N, H, W, C]
/// \param boxes [NUM_BOXES, 4] where boxes[box] is [y1, x1, y2, x2] each in [0, 1]
/// \param box_indices [NUM_BOXES] in [0, N)
/// \param crop_size [crop_height, crop_width]
CropAndResize
(
const
Output
<
Node
>&
image
,
const
Output
<
Node
>&
boxes
,
const
Output
<
Node
>&
box_indices
,
const
Output
<
Node
>&
crop_size
,
ResizeMethod
resize_method
,
float
extrapolation_value
);
void
validate_and_infer_types
()
override
;
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
ResizeMethod
get_resize_method
()
const
{
return
m_resize_method
;
}
void
set_resize_method
(
ResizeMethod
resize_method
)
{
m_resize_method
=
resize_method
;
}
float
get_extrapolation_value
()
const
{
return
m_extrapolation_value
;
}
void
set_extrapolation_value
(
float
extrapolation_value
)
{
m_extrapolation_value
=
extrapolation_value
;
}
private
:
ResizeMethod
m_resize_method
{
ResizeMethod
::
unspecified
};
float
m_extrapolation_value
{
0
};
};
}
const
std
::
string
&
as_string
(
op
::
CropAndResize
::
ResizeMethod
);
template
<
typename
T
>
T
as_type
(
const
std
::
string
&
);
template
<>
op
::
CropAndResize
::
ResizeMethod
as_type
<
op
::
CropAndResize
::
ResizeMethod
>
(
const
std
::
string
&
);
}
src/ngraph/op/op_tbl.hpp
View file @
90c70dde
...
@@ -80,6 +80,7 @@ NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
...
@@ -80,6 +80,7 @@ NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
NGRAPH_OP
(
ConvolutionBackpropFilters
,
ngraph
::
op
)
NGRAPH_OP
(
ConvolutionBackpropFilters
,
ngraph
::
op
)
NGRAPH_OP
(
Cos
,
ngraph
::
op
)
NGRAPH_OP
(
Cos
,
ngraph
::
op
)
NGRAPH_OP
(
Cosh
,
ngraph
::
op
)
NGRAPH_OP
(
Cosh
,
ngraph
::
op
)
NGRAPH_OP
(
CropAndResize
,
ngraph
::
op
)
NGRAPH_OP
(
Dequantize
,
ngraph
::
op
)
NGRAPH_OP
(
Dequantize
,
ngraph
::
op
)
NGRAPH_OP
(
Divide
,
ngraph
::
op
)
NGRAPH_OP
(
Divide
,
ngraph
::
op
)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
...
...
src/ngraph/runtime/interpreter/int_executable.hpp
View file @
90c70dde
...
@@ -707,6 +707,11 @@ private:
...
@@ -707,6 +707,11 @@ private:
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
CropAndResize
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
Dequantize
:
case
OP_TYPEID
:
:
Dequantize
:
{
{
const
op
::
Dequantize
*
dequantize
=
static_cast
<
const
op
::
Dequantize
*>
(
&
node
);
const
op
::
Dequantize
*
dequantize
=
static_cast
<
const
op
::
Dequantize
*>
(
&
node
);
...
...
src/ngraph/serializer.cpp
View file @
90c70dde
...
@@ -45,6 +45,7 @@
...
@@ -45,6 +45,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/dot.hpp"
...
@@ -1117,6 +1118,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
...
@@ -1117,6 +1118,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node
=
make_shared
<
op
::
Cosh
>
(
args
[
0
]);
node
=
make_shared
<
op
::
Cosh
>
(
args
[
0
]);
break
;
break
;
}
}
case
OP_TYPEID
:
:
CropAndResize
:
{
auto
resize_method
=
as_type
<
op
::
CropAndResize
::
ResizeMethod
>
(
node_js
.
at
(
"resize_method"
).
get
<
string
>
());
auto
extrapolation_value
=
node_js
.
at
(
"extrapolation_value"
).
get
<
float
>
();
node
=
make_shared
<
op
::
CropAndResize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
resize_method
,
extrapolation_value
);
break
;
}
case
OP_TYPEID
:
:
DepthToSpace
:
case
OP_TYPEID
:
:
DepthToSpace
:
{
{
auto
block_size
=
node_js
.
at
(
"block_size"
).
get
<
size_t
>
();
auto
block_size
=
node_js
.
at
(
"block_size"
).
get
<
size_t
>
();
...
@@ -2363,6 +2373,13 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2363,6 +2373,13 @@ json JSONSerializer::serialize_node(const Node& n)
}
}
case
OP_TYPEID
:
:
Cosh
:
{
break
;
case
OP_TYPEID
:
:
Cosh
:
{
break
;
}
}
case
OP_TYPEID
:
:
CropAndResize
:
{
auto
tmp
=
static_cast
<
const
op
::
CropAndResize
*>
(
&
n
);
node
[
"resize_method"
]
=
as_string
(
tmp
->
get_resize_method
());
node
[
"extrapolation_value"
]
=
tmp
->
get_extrapolation_value
();
break
;
}
case
OP_TYPEID
:
:
Dequantize
:
case
OP_TYPEID
:
:
Dequantize
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Dequantize
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Dequantize
*>
(
&
n
);
...
...
test/CMakeLists.txt
View file @
90c70dde
...
@@ -89,6 +89,7 @@ set(SRC
...
@@ -89,6 +89,7 @@ set(SRC
type_prop/convert.cpp
type_prop/convert.cpp
type_prop/convolution.cpp
type_prop/convolution.cpp
type_prop/convolution_bias.cpp
type_prop/convolution_bias.cpp
type_prop/crop_and_resize.cpp
type_prop/depth_to_space.cpp
type_prop/depth_to_space.cpp
type_prop/dequantize.cpp
type_prop/dequantize.cpp
type_prop/dot.cpp
type_prop/dot.cpp
...
...
test/type_prop/crop_and_resize.cpp
0 → 100644
View file @
90c70dde
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using
namespace
std
;
using
namespace
ngraph
;
TEST
(
type_prop
,
crop_and_resize_valid
)
{
Dimension
N
=
4
;
Dimension
W_image
=
400
;
Dimension
H_image
=
300
;
Dimension
C_image
=
3
;
Dimension
num_boxes
=
20
;
int32_t
W_crop
=
30
;
int32_t
H_crop
=
40
;
PartialShape
result_shape
{
num_boxes
,
H_crop
,
W_crop
,
C_image
};
auto
image
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
N
,
H_image
,
W_image
,
C_image
});
auto
boxes
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
num_boxes
,
4
});
auto
box_indices
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
PartialShape
{
num_boxes
});
auto
crop_shape
=
op
::
Constant
::
create
(
element
::
i32
,
Shape
{
2
},
{
H_crop
,
W_crop
});
auto
crop_and_resize
=
make_shared
<
op
::
CropAndResize
>
(
image
,
boxes
,
box_indices
,
crop_shape
,
op
::
CropAndResize
::
ResizeMethod
::
bilinear
,
0
);
auto
result
=
crop_and_resize
->
output
(
0
);
ASSERT_EQ
(
result
.
get_shape
(),
result_shape
.
to_shape
());
ASSERT_EQ
(
result
.
get_element_type
(),
image
->
output
(
0
).
get_element_type
());
}
TEST
(
type_prop
,
crop_and_resize_not_constant
)
{
Dimension
N
=
4
;
Dimension
W_image
=
400
;
Dimension
H_image
=
300
;
Dimension
C_image
=
3
;
Dimension
num_boxes
=
20
;
int32_t
W_crop
=
30
;
int32_t
H_crop
=
40
;
PartialShape
result_shape
{
num_boxes
,
H_crop
,
W_crop
,
C_image
};
auto
image
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
N
,
H_image
,
W_image
,
C_image
});
auto
boxes
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
PartialShape
{
num_boxes
,
4
});
auto
box_indices
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
PartialShape
{
num_boxes
});
auto
crop_shape
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
PartialShape
{
2
});
try
{
auto
crop_and_resize
=
make_shared
<
op
::
CropAndResize
>
(
image
,
boxes
,
box_indices
,
crop_shape
,
op
::
CropAndResize
::
ResizeMethod
::
bilinear
,
0
);
FAIL
()
<<
"CropAndReshape without constant crop shape should fail"
;
}
catch
(
const
NodeValidationFailure
&
error
)
{
EXPECT_HAS_SUBSTRING
(
error
.
what
(),
std
::
string
(
"crop_size must be a constant"
));
}
catch
(...)
{
FAIL
()
<<
"Deduced type check failed for unexpected reason"
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment