Unverified Commit 90c70dde authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

CropAndResize op (#3893)

* Stub for CropAndResize

* Cut and pasteo

* Need a cast
parent 1ac3e5c7
......@@ -128,6 +128,8 @@ set (SRC
op/cos.hpp
op/cosh.cpp
op/cosh.hpp
op/crop_and_resize.cpp
op/crop_and_resize.hpp
op/dequantize.cpp
op/dequantize.hpp
op/divide.cpp
......
......@@ -81,6 +81,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
......
......@@ -410,6 +410,9 @@ namespace ngraph
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input<Node> input(size_t input_index);
// Simplify migration from 0.25.1
/// \return The output that is connected to the `input_index`th input of this node, i.e.
///         the source output of `input(input_index)`.
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Output<Node> input_value(size_t input_index) const;
/// \return A handle to the `input_index`th input of this node.
/// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
Input<const Node> input(size_t input_index) const;
......@@ -650,6 +653,12 @@ namespace ngraph
return Input<const Node>(this, input_index);
}
// Simplify migration from 0.25.1
inline Output<Node> Node::input_value(size_t input_index) const
{
    // Resolve the input handle first, then ask it which output is feeding it.
    auto input_handle = input(input_index);
    return input_handle.get_source_output();
}
inline Output<Node> Node::output(size_t output_index)
{
if (output_index >= m_outputs.size())
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <vector>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/crop_and_resize.hpp"
using namespace std;
using namespace ngraph;
const string op::CropAndResize::type_name{"CropAndResize"};
// Constructs a CropAndResize node from its four inputs and two attributes.
//
// The inputs are registered with the Op base class in the order
// {image, boxes, box_indices, crop_size}; validate_and_infer_types() relies on that order
// when it reads input_value(0) .. input_value(3).
op::CropAndResize::CropAndResize(const Output<Node>& image,
const Output<Node>& boxes,
const Output<Node>& box_indices,
const Output<Node>& crop_size,
ResizeMethod resize_method,
float extrapolation_value)
: Op({image, boxes, box_indices, crop_size})
, m_resize_method(resize_method)
, m_extrapolation_value(extrapolation_value)
{
// Validates the inputs/attributes and sets the output type for the freshly built node.
constructor_validate_and_infer_types();
}
void op::CropAndResize::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this, get_input_size() == 4);
NODE_VALIDATION_CHECK(
this, m_resize_method != ResizeMethod::unspecified, "Resize method not specified");
auto image = input_value(0);
auto& image_et = image.get_element_type();
// Will override if we can determine the shape
set_output_type(0, image_et, {});
auto image_shape = image.get_partial_shape();
Dimension image_depth;
if (image_shape.is_static())
{
NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(image_shape.rank()) == 4, "Image must be NHWC");
image_depth = image_shape[3];
}
auto boxes = input_value(1);
auto boxes_shape = boxes.get_partial_shape();
if (boxes_shape.is_static())
{
auto boxes_rank = boxes_shape.rank();
NODE_VALIDATION_CHECK(this, static_cast<int64_t>(boxes_rank) == 2, "Boxes must be 2d");
auto boxes_dim1 = boxes_shape[1];
NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(boxes_dim1) == 4, "Second boxes dimension must be 4");
}
NODE_VALIDATION_CHECK(
this, boxes.get_element_type().is_real(), "Boxes must be real values in [0, 1]");
auto box_indices = input_value(2);
auto box_indices_shape = box_indices.get_partial_shape();
Dimension num_boxes;
if (box_indices_shape.is_static())
{
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(box_indices_shape.rank()) == 1,
"Box indices must have rank 1");
num_boxes = box_indices_shape[0];
}
NODE_VALIDATION_CHECK(
this, box_indices.get_element_type().is_integral(), "Box indices must be integers");
auto crop_size = input_value(3);
auto crop_size_shape = crop_size.get_partial_shape();
auto crop_size_rank = crop_size_shape.rank();
NODE_VALIDATION_CHECK(this,
crop_size_shape.is_static() || crop_size_rank.is_dynamic(),
"Dynamic crop_size not supported");
NODE_VALIDATION_CHECK(
this, static_cast<int64_t>(crop_size_rank) == 1, "crop_size must be a vector");
NODE_VALIDATION_CHECK(this,
static_cast<int64_t>(crop_size_shape[0]) == 2,
"crop_size must be a vector of length 2");
auto& crop_size_et = crop_size.get_element_type();
NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral");
auto crop_size_node = crop_size.get_node_shared_ptr();
NODE_VALIDATION_CHECK(this, crop_size_node->is_constant(), "crop_size must be a constant");
auto crop_size_const = static_pointer_cast<op::Constant>(crop_size_node);
if (crop_size_et == element::i8)
{
auto v = crop_size_const->get_vector<int8_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u8)
{
auto v = crop_size_const->get_vector<uint8_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::i16)
{
auto v = crop_size_const->get_vector<int16_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u16)
{
auto v = crop_size_const->get_vector<uint16_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::i32)
{
auto v = crop_size_const->get_vector<int32_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u32)
{
auto v = crop_size_const->get_vector<uint32_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::i64)
{
auto v = crop_size_const->get_vector<int64_t>();
set_output_type(0, image_et, {num_boxes, v[0], v[1], image_depth});
}
else if (crop_size_et == element::u64)
{
auto v = crop_size_const->get_vector<uint64_t>();
set_output_type(
0,
image_et,
{num_boxes, static_cast<int64_t>(v[0]), static_cast<int64_t>(v[1]), image_depth});
}
else
{
NODE_VALIDATION_CHECK(this, false, "Unknown integral type for crop size");
}
}
/// \brief Clones this op onto a new set of arguments, keeping the resize method and the
///        extrapolation value.
/// \param new_args Replacement inputs in the order image, boxes, box_indices, crop_size.
shared_ptr<Node> op::CropAndResize::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);

    const auto& image_arg = new_args.at(0);
    const auto& boxes_arg = new_args.at(1);
    const auto& box_indices_arg = new_args.at(2);
    const auto& crop_size_arg = new_args.at(3);

    return make_shared<CropAndResize>(image_arg,
                                      boxes_arg,
                                      box_indices_arg,
                                      crop_size_arg,
                                      m_resize_method,
                                      m_extrapolation_value);
}
/// \brief Returns the table mapping each ResizeMethod enumerator to its serialized name.
///
/// The table is a function-local static, so it is built on first use and shared by
/// as_string() and as_type<ResizeMethod>().
static const vector<pair<string, op::CropAndResize::ResizeMethod>>& get_resize_pairs()
{
    using ResizeMethod = op::CropAndResize::ResizeMethod;
    using NamedMethod = pair<string, ResizeMethod>;

    static vector<NamedMethod> name_table{NamedMethod{"unspecified", ResizeMethod::unspecified},
                                          NamedMethod{"bilinear", ResizeMethod::bilinear},
                                          NamedMethod{"nearest", ResizeMethod::nearest}};
    return name_table;
}
/// \brief Returns the serialized name of a resize method.
/// \throw ngraph_error if the enumerator has no entry in the name table.
const string& ngraph::as_string(op::CropAndResize::ResizeMethod resize_method)
{
    const auto& name_table = get_resize_pairs();
    for (auto entry = name_table.begin(); entry != name_table.end(); ++entry)
    {
        if (entry->second != resize_method)
        {
            continue;
        }
        // The reference stays valid: it points into the static name table.
        return entry->first;
    }
    throw ngraph_error("Internal error: unhandled resize method");
}
namespace ngraph
{
    /// \brief Parses the serialized name of a resize method (inverse of as_string()).
    /// \throw ngraph_error if `s` does not match any name in the table.
    template <>
    op::CropAndResize::ResizeMethod as_type<op::CropAndResize::ResizeMethod>(const std::string& s)
    {
        const auto& name_table = get_resize_pairs();
        for (auto entry = name_table.begin(); entry != name_table.end(); ++entry)
        {
            if (entry->first != s)
            {
                continue;
            }
            return entry->second;
        }
        throw ngraph_error("Internal error: unhandled resize method name");
    }
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Crop-and-resize operation: takes an image batch, a list of boxes, the batch
///        index each box applies to, and a constant crop size.
///
/// Shape inference (see validate_and_infer_types) produces an output of shape
/// [NUM_BOXES, crop_height, crop_width, C] with the element type of the image.
class CropAndResize : public Op
{
public:
/// \brief Resampling method used when resizing a crop.
///
/// `unspecified` is the value held by a default-constructed op; validation rejects it.
enum class ResizeMethod
{
unspecified,
bilinear,
nearest
};
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Default-constructs the op with resize method `unspecified` and
///        extrapolation value 0.
CropAndResize() = default;
/// \brief Constructs a crop and resize operation.
/// \param image [N, H, W, C]
/// \param boxes [NUM_BOXES, 4] where boxes[box] is [y1, x1, y2, x2] each in [0, 1]
/// \param box_indices [NUM_BOXES] in [0, N)
/// \param crop_size [crop_height, crop_width]
/// \param resize_method Resampling method; must not be `unspecified`.
/// \param extrapolation_value Stored as an attribute of the op.
///        NOTE(review): presumably the fill value for samples that fall outside the
///        image, as in TensorFlow's crop_and_resize -- confirm.
CropAndResize(const Output<Node>& image,
const Output<Node>& boxes,
const Output<Node>& box_indices,
const Output<Node>& crop_size,
ResizeMethod resize_method,
float extrapolation_value);
/// \brief Validates the inputs and attributes and sets the output type.
void validate_and_infer_types() override;
/// \brief Creates a copy of this op on `new_args` with the same attributes.
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
/// \return The resize method attribute.
ResizeMethod get_resize_method() const { return m_resize_method; }
/// \brief Sets the resize method attribute; does not re-run validation by itself.
void set_resize_method(ResizeMethod resize_method) { m_resize_method = resize_method; }
/// \return The extrapolation value attribute.
float get_extrapolation_value() const { return m_extrapolation_value; }
/// \brief Sets the extrapolation value attribute.
void set_extrapolation_value(float extrapolation_value)
{
m_extrapolation_value = extrapolation_value;
}
private:
// Defaults describe the state of a default-constructed op.
ResizeMethod m_resize_method{ResizeMethod::unspecified};
float m_extrapolation_value{0};
};
}
/// \brief Returns the name used to serialize a resize method
///        ("unspecified", "bilinear" or "nearest").
/// \throw ngraph_error if the enumerator is not handled.
const std::string& as_string(op::CropAndResize::ResizeMethod);
/// \brief Converts a serialized name to a value of type T.
///
/// Only the primary template is declared here; each supported T provides an explicit
/// specialization.
template <typename T>
T as_type(const std::string&);
/// \brief Parses a resize method name; inverse of as_string().
/// \throw ngraph_error if the name is not recognized.
template <>
op::CropAndResize::ResizeMethod as_type<op::CropAndResize::ResizeMethod>(const std::string&);
}
......@@ -80,6 +80,7 @@ NGRAPH_OP(ConvolutionBackpropData, ngraph::op)
NGRAPH_OP(ConvolutionBackpropFilters, ngraph::op)
NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP(Cosh, ngraph::op)
NGRAPH_OP(CropAndResize, ngraph::op)
NGRAPH_OP(Dequantize, ngraph::op)
NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(Dot, ngraph::op)
......
......@@ -707,6 +707,11 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::CropAndResize:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::Dequantize:
{
const op::Dequantize* dequantize = static_cast<const op::Dequantize*>(&node);
......
......@@ -45,6 +45,7 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
......@@ -1117,6 +1118,15 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Cosh>(args[0]);
break;
}
case OP_TYPEID::CropAndResize:
{
auto resize_method =
as_type<op::CropAndResize::ResizeMethod>(node_js.at("resize_method").get<string>());
auto extrapolation_value = node_js.at("extrapolation_value").get<float>();
node = make_shared<op::CropAndResize>(
args[0], args[1], args[2], args[3], resize_method, extrapolation_value);
break;
}
case OP_TYPEID::DepthToSpace:
{
auto block_size = node_js.at("block_size").get<size_t>();
......@@ -2363,6 +2373,13 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Cosh: { break;
}
case OP_TYPEID::CropAndResize:
{
auto tmp = static_cast<const op::CropAndResize*>(&n);
node["resize_method"] = as_string(tmp->get_resize_method());
node["extrapolation_value"] = tmp->get_extrapolation_value();
break;
}
case OP_TYPEID::Dequantize:
{
auto tmp = dynamic_cast<const op::Dequantize*>(&n);
......
......@@ -89,6 +89,7 @@ set(SRC
type_prop/convert.cpp
type_prop/convolution.cpp
type_prop/convolution_bias.cpp
type_prop/crop_and_resize.cpp
type_prop/depth_to_space.cpp
type_prop/dequantize.cpp
type_prop/dot.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// A fully static, well-formed CropAndResize must infer an output of shape
// [num_boxes, crop_height, crop_width, channels] with the image's element type.
TEST(type_prop, crop_and_resize_valid)
{
    // Image batch is NHWC.
    Dimension batch = 4;
    Dimension image_height = 300;
    Dimension image_width = 400;
    Dimension channels = 3;
    Dimension box_count = 20;
    int32_t crop_height = 40;
    int32_t crop_width = 30;

    auto image_param = make_shared<op::Parameter>(
        element::f32, PartialShape{batch, image_height, image_width, channels});
    auto boxes_param = make_shared<op::Parameter>(element::f32, PartialShape{box_count, 4});
    auto indices_param = make_shared<op::Parameter>(element::i32, PartialShape{box_count});
    auto crop_const = op::Constant::create(element::i32, Shape{2}, {crop_height, crop_width});

    auto node = make_shared<op::CropAndResize>(image_param,
                                               boxes_param,
                                               indices_param,
                                               crop_const,
                                               op::CropAndResize::ResizeMethod::bilinear,
                                               0);
    auto inferred = node->output(0);

    PartialShape expected_shape{box_count, crop_height, crop_width, channels};
    ASSERT_EQ(inferred.get_shape(), expected_shape.to_shape());
    ASSERT_EQ(inferred.get_element_type(), image_param->output(0).get_element_type());
}
// crop_size must be a Constant: feeding it from a Parameter (even one with a static
// shape of [2]) has to be rejected during construction with a NodeValidationFailure.
TEST(type_prop, crop_and_resize_not_constant)
{
    Dimension N = 4;
    Dimension W_image = 400;
    Dimension H_image = 300;
    Dimension C_image = 3;
    Dimension num_boxes = 20;

    auto image =
        make_shared<op::Parameter>(element::f32, PartialShape{N, H_image, W_image, C_image});
    auto boxes = make_shared<op::Parameter>(element::f32, PartialShape{num_boxes, 4});
    auto box_indices = make_shared<op::Parameter>(element::i32, PartialShape{num_boxes});
    // Right shape and element type, but not a constant.
    auto crop_shape = make_shared<op::Parameter>(element::i32, PartialShape{2});
    try
    {
        auto crop_and_resize = make_shared<op::CropAndResize>(
            image, boxes, box_indices, crop_shape, op::CropAndResize::ResizeMethod::bilinear, 0);
        FAIL() << "CropAndResize without constant crop shape should fail";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("crop_size must be a constant"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment