Commit 5dc3a1bb authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

[SPEC] Make offset input of DeformablePSROIPooling optional (#3997)

parent eb0d866d
...@@ -45,6 +45,29 @@ op::v1::DeformablePSROIPooling::DeformablePSROIPooling(const Output<Node>& input ...@@ -45,6 +45,29 @@ op::v1::DeformablePSROIPooling::DeformablePSROIPooling(const Output<Node>& input
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::v1::DeformablePSROIPooling::DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size,
const std::string mode,
int64_t spatial_bins_x,
int64_t spatial_bins_y,
float trans_std,
int64_t part_size)
: Op({input, coords})
, m_output_dim(output_dim)
, m_spatial_scale(spatial_scale)
, m_group_size(group_size)
, m_mode(mode)
, m_spatial_bins_x(spatial_bins_x)
, m_spatial_bins_y(spatial_bins_y)
, m_trans_std(trans_std)
, m_part_size(part_size)
{
constructor_validate_and_infer_types();
}
void op::v1::DeformablePSROIPooling::validate_and_infer_types() void op::v1::DeformablePSROIPooling::validate_and_infer_types()
{ {
const auto& input_et = get_input_element_type(0); const auto& input_et = get_input_element_type(0);
...@@ -108,6 +131,19 @@ shared_ptr<Node> ...@@ -108,6 +131,19 @@ shared_ptr<Node>
m_trans_std, m_trans_std,
m_part_size); m_part_size);
} }
else if (new_args.size() == 2)
{
return make_shared<v1::DeformablePSROIPooling>(new_args.at(0),
new_args.at(1),
m_output_dim,
m_spatial_scale,
m_group_size,
m_mode,
m_spatial_bins_x,
m_spatial_bins_y,
m_trans_std,
m_part_size);
}
else else
{ {
throw ngraph_error("Not supported number of DeformablePSROIPooling args"); throw ngraph_error("Not supported number of DeformablePSROIPooling args");
......
...@@ -65,6 +65,17 @@ namespace ngraph ...@@ -65,6 +65,17 @@ namespace ngraph
float trans_std = 1, float trans_std = 1,
int64_t part_size = 1); int64_t part_size = 1);
DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size = 1,
const std::string mode = "bilinear_deformable",
int64_t spatial_bins_x = 1,
int64_t spatial_bins_y = 1,
float trans_std = 1,
int64_t part_size = 1);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -1404,17 +1404,33 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -1404,17 +1404,33 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
const auto trans_std = node_js.at("trans_std").get<float>(); const auto trans_std = node_js.at("trans_std").get<float>();
const auto part_size = node_js.at("part_size").get<int64_t>(); const auto part_size = node_js.at("part_size").get<int64_t>();
node = make_shared<op::v1::DeformablePSROIPooling>(args[0], if (args.size() == 2)
args[1], {
args[2], node = make_shared<op::v1::DeformablePSROIPooling>(args[0],
output_dim, args[1],
spatial_scale, output_dim,
group_size, spatial_scale,
mode, group_size,
spatial_bins_x, mode,
spatial_bins_y, spatial_bins_x,
trans_std, spatial_bins_y,
part_size); trans_std,
part_size);
}
else
{
node = make_shared<op::v1::DeformablePSROIPooling>(args[0],
args[1],
args[2],
output_dim,
spatial_scale,
group_size,
mode,
spatial_bins_x,
spatial_bins_y,
trans_std,
part_size);
}
break; break;
} }
case OP_TYPEID::DepthToSpace_v1: case OP_TYPEID::DepthToSpace_v1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment