Commit 5dc3a1bb authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

[SPEC] Make offset input of DeformablePSROIPooling optional (#3997)

parent eb0d866d
......@@ -45,6 +45,29 @@ op::v1::DeformablePSROIPooling::DeformablePSROIPooling(const Output<Node>& input
constructor_validate_and_infer_types();
}
op::v1::DeformablePSROIPooling::DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size,
const std::string mode,
int64_t spatial_bins_x,
int64_t spatial_bins_y,
float trans_std,
int64_t part_size)
: Op({input, coords})
, m_output_dim(output_dim)
, m_spatial_scale(spatial_scale)
, m_group_size(group_size)
, m_mode(mode)
, m_spatial_bins_x(spatial_bins_x)
, m_spatial_bins_y(spatial_bins_y)
, m_trans_std(trans_std)
, m_part_size(part_size)
{
constructor_validate_and_infer_types();
}
void op::v1::DeformablePSROIPooling::validate_and_infer_types()
{
const auto& input_et = get_input_element_type(0);
......@@ -108,6 +131,19 @@ shared_ptr<Node>
m_trans_std,
m_part_size);
}
else if (new_args.size() == 2)
{
return make_shared<v1::DeformablePSROIPooling>(new_args.at(0),
new_args.at(1),
m_output_dim,
m_spatial_scale,
m_group_size,
m_mode,
m_spatial_bins_x,
m_spatial_bins_y,
m_trans_std,
m_part_size);
}
else
{
throw ngraph_error("Not supported number of DeformablePSROIPooling args");
......
......@@ -65,6 +65,17 @@ namespace ngraph
float trans_std = 1,
int64_t part_size = 1);
DeformablePSROIPooling(const Output<Node>& input,
const Output<Node>& coords,
const int64_t output_dim,
const float spatial_scale,
const int64_t group_size = 1,
const std::string mode = "bilinear_deformable",
int64_t spatial_bins_x = 1,
int64_t spatial_bins_y = 1,
float trans_std = 1,
int64_t part_size = 1);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -1404,17 +1404,33 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
const auto trans_std = node_js.at("trans_std").get<float>();
const auto part_size = node_js.at("part_size").get<int64_t>();
node = make_shared<op::v1::DeformablePSROIPooling>(args[0],
args[1],
args[2],
output_dim,
spatial_scale,
group_size,
mode,
spatial_bins_x,
spatial_bins_y,
trans_std,
part_size);
if (args.size() == 2)
{
node = make_shared<op::v1::DeformablePSROIPooling>(args[0],
args[1],
output_dim,
spatial_scale,
group_size,
mode,
spatial_bins_x,
spatial_bins_y,
trans_std,
part_size);
}
else
{
node = make_shared<op::v1::DeformablePSROIPooling>(args[0],
args[1],
args[2],
output_dim,
spatial_scale,
group_size,
mode,
spatial_bins_x,
spatial_bins_y,
trans_std,
part_size);
}
break;
}
case OP_TYPEID::DepthToSpace_v1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment