Commit a8f4f4a2 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Sang Ik Lee

[SPEC] Too restrictive data type (#4191)

* Resolved problems with too restrictive data type

* Apply suggestions from code review

Code review remarks introduced
Co-Authored-By: 's avatarTomasz Socha <tomasz.socha@intel.com>

* Code review remarks. Part.2
Co-authored-by: 's avatarTomasz Socha <tomasz.socha@intel.com>
Co-authored-by: 's avatarAdam Rogowiec <adam.osewski@intel.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent a7706e98
......@@ -115,10 +115,9 @@ void op::v1::Broadcast::validate_and_infer_types()
// shape node should have integer data type. For now we only allow i64
auto shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
shape_et.compatible(element::Type_t::i64),
"Broadcast shape must have element type i64, but has ",
shape_et.is_integral_number(),
"Broadcast shape must be an integral number, but is: ",
shape_et);
// shape node should produce a one dimensional shape.
auto broadcast_shape_rank = get_input_partial_shape(1).rank();
NODE_VALIDATION_CHECK(this,
......@@ -131,10 +130,9 @@ void op::v1::Broadcast::validate_and_infer_types()
// axes_mapping node should have integer data type. For now we only allow i64
auto axes_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
axes_et.compatible(element::Type_t::i64),
"Broadcast axes must have element type i64, but has ",
axes_et.is_integral_number(),
"Broadcast axes must be integral numbers, but are: ",
axes_et);
// axes_mapping node should produce a one dimensional shape.
auto axes_shape_rank = get_input_partial_shape(2).rank();
NODE_VALIDATION_CHECK(this,
......
......@@ -258,8 +258,8 @@ vector<string> op::Constant::get_value_strings() const
Shape op::Constant::get_shape_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_shape = get_vector<int64_t>();
NGRAPH_CHECK(m_element_type.is_integral_number());
std::vector<int64_t> out_shape = cast_vector<int64_t>();
Shape output_shape(shape_size(m_shape));
std::transform(out_shape.begin(), out_shape.end(), output_shape.begin(), [&](const int64_t& v) {
return (v > 0) ? v : 0;
......@@ -305,8 +305,8 @@ CoordinateDiff op::Constant::get_coordinate_diff_val() const
AxisVector op::Constant::get_axis_vector_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_axis_vector = get_vector<int64_t>();
NGRAPH_CHECK(m_element_type.is_integral_number());
std::vector<int64_t> out_axis_vector = cast_vector<int64_t>();
AxisVector output_axis_vector(shape_size(m_shape));
std::transform(out_axis_vector.begin(),
out_axis_vector.end(),
......@@ -317,10 +317,10 @@ AxisVector op::Constant::get_axis_vector_val() const
AxisSet op::Constant::get_axis_set_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_axis_set = get_vector<int64_t>();
NGRAPH_CHECK(m_element_type.is_integral_number());
std::vector<int64_t> out_axis_set = cast_vector<int64_t>();
AxisSet output_axis_set;
for (auto& axis : get_vector<int64_t>())
for (auto& axis : out_axis_set)
{
output_axis_set.insert(axis > 0 ? axis : 0);
}
......
......@@ -236,7 +236,8 @@ const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const
void op::v1::ConvolutionBackpropData::set_output_shape(const Shape& shape)
{
this->input(2).replace_source_output(
op::Constant::create(element::i64, Shape{shape.size()}, shape)->output(0));
op::Constant::create(this->get_input_element_type(2), Shape{shape.size()}, shape)
->output(0));
}
void op::v1::ConvolutionBackpropData::validate_and_infer_types()
......
......@@ -35,8 +35,8 @@ op::Interpolate::Interpolate(const Output<Node>& image,
void op::Interpolate::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).compatible(element::Type_t::i64),
"output shape must have element type i64.");
get_input_element_type(1).is_integral_number(),
"output shape must be an integral number.");
set_input_is_relevant_to_shape(1);
PartialShape output_shape = PartialShape(get_input_partial_shape(0));
......@@ -51,7 +51,7 @@ void op::Interpolate::validate_and_infer_types()
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
auto out_shape = static_cast<const int64_t*>(const_shape->get_data_ptr());
auto out_shape = const_shape->cast_vector<int64_t>();
size_t i = 0;
for (auto axis : m_attrs.axes)
{
......
......@@ -37,14 +37,14 @@ void op::PriorBox::validate_and_infer_types()
// shape node should have integer data type. For now we only allow i64
auto layer_shape_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
layer_shape_et.compatible(element::Type_t::i64),
"layer shape input must have element type i64, but has ",
layer_shape_et.is_integral_number(),
"layer shape input must be an integral number, but is: ",
layer_shape_et);
auto image_shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
image_shape_et.compatible(element::Type_t::i64),
"image shape input must have element type i64, but has ",
image_shape_et.is_integral_number(),
"image shape input must be an integral number, but is: ",
image_shape_et);
auto layer_shape_rank = get_input_partial_shape(0).rank();
......
......@@ -37,14 +37,14 @@ void op::PriorBoxClustered::validate_and_infer_types()
// shape node should have integer data type. For now we only allow i64
auto layer_shape_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
layer_shape_et.compatible(element::Type_t::i64),
"layer shape input must have element type i64, but has ",
layer_shape_et.is_integral_number(),
"layer shape input must be an integral number, but is: ",
layer_shape_et);
auto image_shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
image_shape_et.compatible(element::Type_t::i64),
"image shape input must have element type i64, but has ",
image_shape_et.is_integral_number(),
"image shape input must be an integral number, but is: ",
image_shape_et);
auto layer_shape_rank = get_input_partial_shape(0).rank();
......
......@@ -222,7 +222,8 @@ const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() cons
void op::v1::GroupConvolutionBackpropData::set_output_shape(const Shape& shape)
{
this->input(2).replace_source_output(
op::Constant::create(element::i64, Shape{shape.size()}, shape)->output(0));
op::Constant::create(this->get_input_element_type(2), Shape{shape.size()}, shape)
->output(0));
}
void op::v1::GroupConvolutionBackpropData::validate_and_infer_types()
......
......@@ -44,7 +44,7 @@ NodeVector op::Squeeze::decompose_op() const
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>();
auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size());
......
......@@ -49,7 +49,7 @@ NodeVector op::Unsqueeze::decompose_op() const
// Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>();
auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape();
......
......@@ -183,7 +183,7 @@ size_t op::v1::Gather::get_axis() const
auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
{
axis = const_op->get_vector<int64_t>()[0];
axis = const_op->cast_vector<int64_t>()[0];
}
if (axis < 0)
{
......
......@@ -105,8 +105,8 @@ void op::LRN::validate_and_infer_types()
const auto& axes_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
axes_type.compatible(element::Type_t::i64),
"Axes input must have element type i64 (axes type: ",
axes_type.is_integral_number(),
"Axes input must be integral numbers, but are: ",
axes_type,
").");
}
......
......@@ -153,38 +153,7 @@ int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const
const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(max_output_boxes_input->get_element_type()))
{
case element::Type_t::i8:
{
max_output_boxes = max_output_boxes_input->get_vector<int8_t>().at(0);
break;
}
case element::Type_t::i16:
{
max_output_boxes = max_output_boxes_input->get_vector<int16_t>().at(0);
break;
}
case element::Type_t::i32:
{
max_output_boxes = max_output_boxes_input->get_vector<int32_t>().at(0);
break;
}
case element::Type_t::i64:
{
max_output_boxes = max_output_boxes_input->get_vector<int64_t>().at(0);
break;
}
default: break;
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);
return max_output_boxes;
}
......@@ -206,7 +206,7 @@ CoordinateDiff op::v1::Pad::get_pads_begin() const
CoordinateDiff pads_begin_coord{};
if (auto pads_begin_const = as_type_ptr<op::Constant>(pads_begin_node))
{
pads_begin_coord = pads_begin_const->get_vector<ptrdiff_t>();
pads_begin_coord = pads_begin_const->cast_vector<ptrdiff_t>();
}
return pads_begin_coord;
}
......@@ -217,7 +217,7 @@ CoordinateDiff op::v1::Pad::get_pads_end() const
CoordinateDiff pads_end_coord{};
if (auto pads_end_const = as_type_ptr<op::Constant>(pads_end_node))
{
pads_end_coord = pads_end_const->get_vector<ptrdiff_t>();
pads_end_coord = pads_end_const->cast_vector<ptrdiff_t>();
}
return pads_end_coord;
}
......@@ -252,14 +252,14 @@ void op::v1::Pad::validate_and_infer_types()
}
NODE_VALIDATION_CHECK(this,
pads_begin_element_type.compatible(element::Type_t::i64),
"pads_begin must be type i64 (axes type: ",
pads_begin_element_type.is_integral_number(),
"pads_begin must be an integral number, but is: ",
pads_begin_element_type,
").");
NODE_VALIDATION_CHECK(this,
pads_end_element_type.compatible(element::Type_t::i64),
"pads_end must be type i64 (axes type: ",
pads_end_element_type.is_integral_number(),
"pads_end must be an integral number, but is: ",
pads_end_element_type,
").");
......
......@@ -161,7 +161,7 @@ void op::v1::Reshape::validate_and_infer_types()
auto pattern_et = get_input_element_type(1);
// check data types
NODE_VALIDATION_CHECK(
this, pattern_et.compatible(element::Type_t::i64), "Pattern must have element type i64.");
this, pattern_et.is_integral_number(), "Pattern must be an integral number.");
// check shapes
const PartialShape& pattern_shape = get_input_partial_shape(1);
......@@ -176,7 +176,7 @@ void op::v1::Reshape::validate_and_infer_types()
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{
std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>();
std::vector<int64_t> out_shape_val = const_shape->cast_vector<int64_t>();
NODE_VALIDATION_CHECK(this,
std::none_of(out_shape_val.begin(),
out_shape_val.end(),
......
......@@ -71,12 +71,12 @@ void op::v1::StridedSlice::validate_and_infer_types()
const auto& begin_mask_et = get_input_element_type(1);
const auto& end_mask_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
begin_mask_et.compatible(element::Type_t::i64),
"Begin mask must have element type i64, but has ",
begin_mask_et.is_integral_number(),
"Begin mask must be an integral number, but is: ",
begin_mask_et);
NODE_VALIDATION_CHECK(this,
end_mask_et.compatible(element::Type_t::i64),
"End mask must have element type i64, but has ",
end_mask_et.is_integral_number(),
"End mask must be an integral number, but is: ",
end_mask_et);
auto are_mask_elem_in_range = [](size_t e) { return e == 0 || e == 1; };
......@@ -136,9 +136,9 @@ void op::v1::StridedSlice::validate_and_infer_types()
get_input_element_type(0),
infer_slice_shape(this,
get_input_partial_shape(0),
begin_const->get_vector<int64_t>(),
end_const->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
begin_const->cast_vector<int64_t>(),
end_const->cast_vector<int64_t>(),
strides->cast_vector<int64_t>(),
convert_mask_to_axis_set(get_begin_mask()),
convert_mask_to_axis_set(get_end_mask()),
convert_mask_to_axis_set(get_new_axis_mask()),
......
......@@ -47,7 +47,7 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{
AxisSet reduction_axes;
auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>();
as_type<op::Constant>(input_value(1).get_node())->cast_vector<int64_t>();
for (auto axis : reduction_axes_val)
{
try
......
......@@ -68,7 +68,7 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank);
auto split_lengths =
as_type_ptr<op::Constant>(split_lengths_input)->get_vector<int64_t>();
as_type_ptr<op::Constant>(split_lengths_input)->cast_vector<int64_t>();
// Adjust split lengths in case of negatives
size_t sum_of_splits = 0;
int64_t negative_one = -1;
......
......@@ -282,6 +282,11 @@ bool element::Type::is_real() const
return get_type_info_map().at(m_type).m_is_real;
}
bool element::Type::is_integral_number() const
{
return is_integral() && (m_type != element::boolean);
}
bool element::Type::is_signed() const
{
return get_type_info_map().at(m_type).m_is_signed;
......
......@@ -89,6 +89,7 @@ namespace ngraph
// TODO: We may want to revisit this definition when we do a more general cleanup of
// element types:
bool is_integral() const { return !is_real(); }
bool is_integral_number() const;
bool is_signed() const;
bool is_quantized() const;
size_t bitwidth() const;
......
......@@ -381,12 +381,12 @@ TEST(type_prop, broadcast_v1_broadcast_shape_et_wrong)
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect shape element type not i64";
FAIL() << "Broadcast: did not detect shape element type not integral number";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast shape must have element type i64"));
std::string("Broadcast shape must be an integral number"));
}
catch (...)
{
......@@ -404,12 +404,12 @@ TEST(type_prop, broadcast_v1_axes_et_wrong)
try
{
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect axes element type not i64";
FAIL() << "Broadcast: did not detect axes element type not integral numbers";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast axes must have element type i64"));
std::string("Broadcast axes must be integral numbers, but are:"));
}
catch (...)
{
......
......@@ -571,7 +571,8 @@ TEST(type_prop, pad_v1_arg_pads_begin_incompatible_type)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("pads_begin must be type i64 (axes type:"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("pads_begin must be an integral number, but is:"));
}
catch (...)
{
......@@ -594,7 +595,8 @@ TEST(type_prop, pad_v1_arg_pads_end_incompatible_type)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("pads_end must be type i64 (axes type:"));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("pads_end must be an integral number, but is:"));
}
catch (...)
{
......
......@@ -26,7 +26,7 @@ using namespace ngraph;
TEST(type_prop, strided_slice_begin_incorrect_type)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto begin = make_shared<op::Parameter>(element::i32, Shape{4});
auto begin = make_shared<op::Parameter>(element::f16, Shape{4});
auto end = make_shared<op::Parameter>(element::i64, Shape{4});
try
{
......@@ -37,7 +37,7 @@ TEST(type_prop, strided_slice_begin_incorrect_type)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Begin mask must have element type i64"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Begin mask must be an integral number"));
}
catch (...)
{
......@@ -49,7 +49,7 @@ TEST(type_prop, strided_slice_end_incorrect_type)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto begin = make_shared<op::Parameter>(element::i64, Shape{4});
auto end = make_shared<op::Parameter>(element::i32, Shape{4});
auto end = make_shared<op::Parameter>(element::boolean, Shape{4});
try
{
auto strided_slice = make_shared<op::v1::StridedSlice>(
......@@ -59,7 +59,7 @@ TEST(type_prop, strided_slice_end_incorrect_type)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("End mask must have element type i64"));
EXPECT_HAS_SUBSTRING(error.what(), std::string("End mask must be an integral number"));
}
catch (...)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment