Commit a8f4f4a2 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Sang Ik Lee

[SPEC] Too restrictive data type (#4191)

* Resolved problems with too restrictive data type

* Apply suggestions from code review

Code review remarks introduced
Co-Authored-By: 's avatarTomasz Socha <tomasz.socha@intel.com>

* Code review remarks. Part.2
Co-authored-by: 's avatarTomasz Socha <tomasz.socha@intel.com>
Co-authored-by: 's avatarAdam Rogowiec <adam.osewski@intel.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent a7706e98
...@@ -115,10 +115,9 @@ void op::v1::Broadcast::validate_and_infer_types() ...@@ -115,10 +115,9 @@ void op::v1::Broadcast::validate_and_infer_types()
// shape node should have integer data type. For now we only allow i64 // shape node should have integer data type. For now we only allow i64
auto shape_et = get_input_element_type(1); auto shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
shape_et.compatible(element::Type_t::i64), shape_et.is_integral_number(),
"Broadcast shape must have element type i64, but has ", "Broadcast shape must be an integral number, but is: ",
shape_et); shape_et);
// shape node should produce a one dimensional shape. // shape node should produce a one dimensional shape.
auto broadcast_shape_rank = get_input_partial_shape(1).rank(); auto broadcast_shape_rank = get_input_partial_shape(1).rank();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -131,10 +130,9 @@ void op::v1::Broadcast::validate_and_infer_types() ...@@ -131,10 +130,9 @@ void op::v1::Broadcast::validate_and_infer_types()
// axes_mapping node should have integer data type. For now we only allow i64 // axes_mapping node should have integer data type. For now we only allow i64
auto axes_et = get_input_element_type(2); auto axes_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axes_et.compatible(element::Type_t::i64), axes_et.is_integral_number(),
"Broadcast axes must have element type i64, but has ", "Broadcast axes must be integral numbers, but are: ",
axes_et); axes_et);
// axes_mapping node should produce a one dimensional shape. // axes_mapping node should produce a one dimensional shape.
auto axes_shape_rank = get_input_partial_shape(2).rank(); auto axes_shape_rank = get_input_partial_shape(2).rank();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
......
...@@ -258,8 +258,8 @@ vector<string> op::Constant::get_value_strings() const ...@@ -258,8 +258,8 @@ vector<string> op::Constant::get_value_strings() const
Shape op::Constant::get_shape_val() const Shape op::Constant::get_shape_val() const
{ {
NGRAPH_CHECK(m_element_type == element::i64); NGRAPH_CHECK(m_element_type.is_integral_number());
std::vector<int64_t> out_shape = get_vector<int64_t>(); std::vector<int64_t> out_shape = cast_vector<int64_t>();
Shape output_shape(shape_size(m_shape)); Shape output_shape(shape_size(m_shape));
std::transform(out_shape.begin(), out_shape.end(), output_shape.begin(), [&](const int64_t& v) { std::transform(out_shape.begin(), out_shape.end(), output_shape.begin(), [&](const int64_t& v) {
return (v > 0) ? v : 0; return (v > 0) ? v : 0;
...@@ -305,8 +305,8 @@ CoordinateDiff op::Constant::get_coordinate_diff_val() const ...@@ -305,8 +305,8 @@ CoordinateDiff op::Constant::get_coordinate_diff_val() const
AxisVector op::Constant::get_axis_vector_val() const AxisVector op::Constant::get_axis_vector_val() const
{ {
NGRAPH_CHECK(m_element_type == element::i64); NGRAPH_CHECK(m_element_type.is_integral_number());
std::vector<int64_t> out_axis_vector = get_vector<int64_t>(); std::vector<int64_t> out_axis_vector = cast_vector<int64_t>();
AxisVector output_axis_vector(shape_size(m_shape)); AxisVector output_axis_vector(shape_size(m_shape));
std::transform(out_axis_vector.begin(), std::transform(out_axis_vector.begin(),
out_axis_vector.end(), out_axis_vector.end(),
...@@ -317,10 +317,10 @@ AxisVector op::Constant::get_axis_vector_val() const ...@@ -317,10 +317,10 @@ AxisVector op::Constant::get_axis_vector_val() const
AxisSet op::Constant::get_axis_set_val() const AxisSet op::Constant::get_axis_set_val() const
{ {
NGRAPH_CHECK(m_element_type == element::i64); NGRAPH_CHECK(m_element_type.is_integral_number());
std::vector<int64_t> out_axis_set = get_vector<int64_t>(); std::vector<int64_t> out_axis_set = cast_vector<int64_t>();
AxisSet output_axis_set; AxisSet output_axis_set;
for (auto& axis : get_vector<int64_t>()) for (auto& axis : out_axis_set)
{ {
output_axis_set.insert(axis > 0 ? axis : 0); output_axis_set.insert(axis > 0 ? axis : 0);
} }
......
...@@ -236,7 +236,8 @@ const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const ...@@ -236,7 +236,8 @@ const PartialShape op::v1::ConvolutionBackpropData::get_output_shape() const
void op::v1::ConvolutionBackpropData::set_output_shape(const Shape& shape) void op::v1::ConvolutionBackpropData::set_output_shape(const Shape& shape)
{ {
this->input(2).replace_source_output( this->input(2).replace_source_output(
op::Constant::create(element::i64, Shape{shape.size()}, shape)->output(0)); op::Constant::create(this->get_input_element_type(2), Shape{shape.size()}, shape)
->output(0));
} }
void op::v1::ConvolutionBackpropData::validate_and_infer_types() void op::v1::ConvolutionBackpropData::validate_and_infer_types()
......
...@@ -35,8 +35,8 @@ op::Interpolate::Interpolate(const Output<Node>& image, ...@@ -35,8 +35,8 @@ op::Interpolate::Interpolate(const Output<Node>& image,
void op::Interpolate::validate_and_infer_types() void op::Interpolate::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
get_input_element_type(1).compatible(element::Type_t::i64), get_input_element_type(1).is_integral_number(),
"output shape must have element type i64."); "output shape must be an integral number.");
set_input_is_relevant_to_shape(1); set_input_is_relevant_to_shape(1);
PartialShape output_shape = PartialShape(get_input_partial_shape(0)); PartialShape output_shape = PartialShape(get_input_partial_shape(0));
...@@ -51,7 +51,7 @@ void op::Interpolate::validate_and_infer_types() ...@@ -51,7 +51,7 @@ void op::Interpolate::validate_and_infer_types()
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr())) if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{ {
auto out_shape = static_cast<const int64_t*>(const_shape->get_data_ptr()); auto out_shape = const_shape->cast_vector<int64_t>();
size_t i = 0; size_t i = 0;
for (auto axis : m_attrs.axes) for (auto axis : m_attrs.axes)
{ {
......
...@@ -37,14 +37,14 @@ void op::PriorBox::validate_and_infer_types() ...@@ -37,14 +37,14 @@ void op::PriorBox::validate_and_infer_types()
// shape node should have integer data type. For now we only allow i64 // shape node should have integer data type. For now we only allow i64
auto layer_shape_et = get_input_element_type(0); auto layer_shape_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
layer_shape_et.compatible(element::Type_t::i64), layer_shape_et.is_integral_number(),
"layer shape input must have element type i64, but has ", "layer shape input must be an integral number, but is: ",
layer_shape_et); layer_shape_et);
auto image_shape_et = get_input_element_type(1); auto image_shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
image_shape_et.compatible(element::Type_t::i64), image_shape_et.is_integral_number(),
"image shape input must have element type i64, but has ", "image shape input must be an integral number, but is: ",
image_shape_et); image_shape_et);
auto layer_shape_rank = get_input_partial_shape(0).rank(); auto layer_shape_rank = get_input_partial_shape(0).rank();
......
...@@ -37,14 +37,14 @@ void op::PriorBoxClustered::validate_and_infer_types() ...@@ -37,14 +37,14 @@ void op::PriorBoxClustered::validate_and_infer_types()
// shape node should have integer data type. For now we only allow i64 // shape node should have integer data type. For now we only allow i64
auto layer_shape_et = get_input_element_type(0); auto layer_shape_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
layer_shape_et.compatible(element::Type_t::i64), layer_shape_et.is_integral_number(),
"layer shape input must have element type i64, but has ", "layer shape input must be an integral number, but is: ",
layer_shape_et); layer_shape_et);
auto image_shape_et = get_input_element_type(1); auto image_shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
image_shape_et.compatible(element::Type_t::i64), image_shape_et.is_integral_number(),
"image shape input must have element type i64, but has ", "image shape input must be an integral number, but is: ",
image_shape_et); image_shape_et);
auto layer_shape_rank = get_input_partial_shape(0).rank(); auto layer_shape_rank = get_input_partial_shape(0).rank();
......
...@@ -222,7 +222,8 @@ const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() cons ...@@ -222,7 +222,8 @@ const PartialShape op::v1::GroupConvolutionBackpropData::get_output_shape() cons
void op::v1::GroupConvolutionBackpropData::set_output_shape(const Shape& shape) void op::v1::GroupConvolutionBackpropData::set_output_shape(const Shape& shape)
{ {
this->input(2).replace_source_output( this->input(2).replace_source_output(
op::Constant::create(element::i64, Shape{shape.size()}, shape)->output(0)); op::Constant::create(this->get_input_element_type(2), Shape{shape.size()}, shape)
->output(0));
} }
void op::v1::GroupConvolutionBackpropData::validate_and_infer_types() void op::v1::GroupConvolutionBackpropData::validate_and_infer_types()
......
...@@ -44,7 +44,7 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -44,7 +44,7 @@ NodeVector op::Squeeze::decompose_op() const
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>(); auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
std::vector<uint64_t> axes_to_squeeze(data_shape.size()); std::vector<uint64_t> axes_to_squeeze(data_shape.size());
......
...@@ -49,7 +49,7 @@ NodeVector op::Unsqueeze::decompose_op() const ...@@ -49,7 +49,7 @@ NodeVector op::Unsqueeze::decompose_op() const
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = as_type_ptr<op::Constant>(axes_node); auto axes_constant = as_type_ptr<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>(); auto axes = axes_constant->cast_vector<size_t>();
auto data_shape = data.get_shape(); auto data_shape = data.get_shape();
......
...@@ -183,7 +183,7 @@ size_t op::v1::Gather::get_axis() const ...@@ -183,7 +183,7 @@ size_t op::v1::Gather::get_axis() const
auto axes_input_node = input_value(AXIS).get_node_shared_ptr(); auto axes_input_node = input_value(AXIS).get_node_shared_ptr();
if (auto const_op = as_type_ptr<op::Constant>(axes_input_node)) if (auto const_op = as_type_ptr<op::Constant>(axes_input_node))
{ {
axis = const_op->get_vector<int64_t>()[0]; axis = const_op->cast_vector<int64_t>()[0];
} }
if (axis < 0) if (axis < 0)
{ {
......
...@@ -105,8 +105,8 @@ void op::LRN::validate_and_infer_types() ...@@ -105,8 +105,8 @@ void op::LRN::validate_and_infer_types()
const auto& axes_type = get_input_element_type(1); const auto& axes_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
axes_type.compatible(element::Type_t::i64), axes_type.is_integral_number(),
"Axes input must have element type i64 (axes type: ", "Axes input must be integral numbers, but are: ",
axes_type, axes_type,
")."); ").");
} }
......
...@@ -153,38 +153,7 @@ int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const ...@@ -153,38 +153,7 @@ int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const
const auto max_output_boxes_input = const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr()); as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(max_output_boxes_input->get_element_type()))
{
case element::Type_t::i8:
{
max_output_boxes = max_output_boxes_input->get_vector<int8_t>().at(0);
break;
}
case element::Type_t::i16:
{
max_output_boxes = max_output_boxes_input->get_vector<int16_t>().at(0);
break;
}
case element::Type_t::i32:
{
max_output_boxes = max_output_boxes_input->get_vector<int32_t>().at(0);
break;
}
case element::Type_t::i64:
{
max_output_boxes = max_output_boxes_input->get_vector<int64_t>().at(0);
break;
}
default: break;
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return max_output_boxes; return max_output_boxes;
} }
...@@ -206,7 +206,7 @@ CoordinateDiff op::v1::Pad::get_pads_begin() const ...@@ -206,7 +206,7 @@ CoordinateDiff op::v1::Pad::get_pads_begin() const
CoordinateDiff pads_begin_coord{}; CoordinateDiff pads_begin_coord{};
if (auto pads_begin_const = as_type_ptr<op::Constant>(pads_begin_node)) if (auto pads_begin_const = as_type_ptr<op::Constant>(pads_begin_node))
{ {
pads_begin_coord = pads_begin_const->get_vector<ptrdiff_t>(); pads_begin_coord = pads_begin_const->cast_vector<ptrdiff_t>();
} }
return pads_begin_coord; return pads_begin_coord;
} }
...@@ -217,7 +217,7 @@ CoordinateDiff op::v1::Pad::get_pads_end() const ...@@ -217,7 +217,7 @@ CoordinateDiff op::v1::Pad::get_pads_end() const
CoordinateDiff pads_end_coord{}; CoordinateDiff pads_end_coord{};
if (auto pads_end_const = as_type_ptr<op::Constant>(pads_end_node)) if (auto pads_end_const = as_type_ptr<op::Constant>(pads_end_node))
{ {
pads_end_coord = pads_end_const->get_vector<ptrdiff_t>(); pads_end_coord = pads_end_const->cast_vector<ptrdiff_t>();
} }
return pads_end_coord; return pads_end_coord;
} }
...@@ -252,14 +252,14 @@ void op::v1::Pad::validate_and_infer_types() ...@@ -252,14 +252,14 @@ void op::v1::Pad::validate_and_infer_types()
} }
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
pads_begin_element_type.compatible(element::Type_t::i64), pads_begin_element_type.is_integral_number(),
"pads_begin must be type i64 (axes type: ", "pads_begin must be an integral number, but is: ",
pads_begin_element_type, pads_begin_element_type,
")."); ").");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
pads_end_element_type.compatible(element::Type_t::i64), pads_end_element_type.is_integral_number(),
"pads_end must be type i64 (axes type: ", "pads_end must be an integral number, but is: ",
pads_end_element_type, pads_end_element_type,
")."); ").");
......
...@@ -161,7 +161,7 @@ void op::v1::Reshape::validate_and_infer_types() ...@@ -161,7 +161,7 @@ void op::v1::Reshape::validate_and_infer_types()
auto pattern_et = get_input_element_type(1); auto pattern_et = get_input_element_type(1);
// check data types // check data types
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, pattern_et.compatible(element::Type_t::i64), "Pattern must have element type i64."); this, pattern_et.is_integral_number(), "Pattern must be an integral number.");
// check shapes // check shapes
const PartialShape& pattern_shape = get_input_partial_shape(1); const PartialShape& pattern_shape = get_input_partial_shape(1);
...@@ -176,7 +176,7 @@ void op::v1::Reshape::validate_and_infer_types() ...@@ -176,7 +176,7 @@ void op::v1::Reshape::validate_and_infer_types()
if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr())) if (auto const_shape = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()))
{ {
std::vector<int64_t> out_shape_val = const_shape->get_vector<int64_t>(); std::vector<int64_t> out_shape_val = const_shape->cast_vector<int64_t>();
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
std::none_of(out_shape_val.begin(), std::none_of(out_shape_val.begin(),
out_shape_val.end(), out_shape_val.end(),
......
...@@ -71,12 +71,12 @@ void op::v1::StridedSlice::validate_and_infer_types() ...@@ -71,12 +71,12 @@ void op::v1::StridedSlice::validate_and_infer_types()
const auto& begin_mask_et = get_input_element_type(1); const auto& begin_mask_et = get_input_element_type(1);
const auto& end_mask_et = get_input_element_type(2); const auto& end_mask_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
begin_mask_et.compatible(element::Type_t::i64), begin_mask_et.is_integral_number(),
"Begin mask must have element type i64, but has ", "Begin mask must be an integral number, but is: ",
begin_mask_et); begin_mask_et);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
end_mask_et.compatible(element::Type_t::i64), end_mask_et.is_integral_number(),
"End mask must have element type i64, but has ", "End mask must be an integral number, but is: ",
end_mask_et); end_mask_et);
auto are_mask_elem_in_range = [](size_t e) { return e == 0 || e == 1; }; auto are_mask_elem_in_range = [](size_t e) { return e == 0 || e == 1; };
...@@ -136,9 +136,9 @@ void op::v1::StridedSlice::validate_and_infer_types() ...@@ -136,9 +136,9 @@ void op::v1::StridedSlice::validate_and_infer_types()
get_input_element_type(0), get_input_element_type(0),
infer_slice_shape(this, infer_slice_shape(this,
get_input_partial_shape(0), get_input_partial_shape(0),
begin_const->get_vector<int64_t>(), begin_const->cast_vector<int64_t>(),
end_const->get_vector<int64_t>(), end_const->cast_vector<int64_t>(),
strides->get_vector<int64_t>(), strides->cast_vector<int64_t>(),
convert_mask_to_axis_set(get_begin_mask()), convert_mask_to_axis_set(get_begin_mask()),
convert_mask_to_axis_set(get_end_mask()), convert_mask_to_axis_set(get_end_mask()),
convert_mask_to_axis_set(get_new_axis_mask()), convert_mask_to_axis_set(get_new_axis_mask()),
......
...@@ -47,7 +47,7 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types() ...@@ -47,7 +47,7 @@ void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{ {
AxisSet reduction_axes; AxisSet reduction_axes;
auto reduction_axes_val = auto reduction_axes_val =
as_type<op::Constant>(input_value(1).get_node())->get_vector<int64_t>(); as_type<op::Constant>(input_value(1).get_node())->cast_vector<int64_t>();
for (auto axis : reduction_axes_val) for (auto axis : reduction_axes_val)
{ {
try try
......
...@@ -68,7 +68,7 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types() ...@@ -68,7 +68,7 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank); int64_t axis = ngraph::normalize_axis(this, axis_val, data_rank);
auto split_lengths = auto split_lengths =
as_type_ptr<op::Constant>(split_lengths_input)->get_vector<int64_t>(); as_type_ptr<op::Constant>(split_lengths_input)->cast_vector<int64_t>();
// Adjust split lengths in case of negatives // Adjust split lengths in case of negatives
size_t sum_of_splits = 0; size_t sum_of_splits = 0;
int64_t negative_one = -1; int64_t negative_one = -1;
......
...@@ -282,6 +282,11 @@ bool element::Type::is_real() const ...@@ -282,6 +282,11 @@ bool element::Type::is_real() const
return get_type_info_map().at(m_type).m_is_real; return get_type_info_map().at(m_type).m_is_real;
} }
bool element::Type::is_integral_number() const
{
return is_integral() && (m_type != element::boolean);
}
bool element::Type::is_signed() const bool element::Type::is_signed() const
{ {
return get_type_info_map().at(m_type).m_is_signed; return get_type_info_map().at(m_type).m_is_signed;
......
...@@ -89,6 +89,7 @@ namespace ngraph ...@@ -89,6 +89,7 @@ namespace ngraph
// TODO: We may want to revisit this definition when we do a more general cleanup of // TODO: We may want to revisit this definition when we do a more general cleanup of
// element types: // element types:
bool is_integral() const { return !is_real(); } bool is_integral() const { return !is_real(); }
bool is_integral_number() const;
bool is_signed() const; bool is_signed() const;
bool is_quantized() const; bool is_quantized() const;
size_t bitwidth() const; size_t bitwidth() const;
......
...@@ -381,12 +381,12 @@ TEST(type_prop, broadcast_v1_broadcast_shape_et_wrong) ...@@ -381,12 +381,12 @@ TEST(type_prop, broadcast_v1_broadcast_shape_et_wrong)
try try
{ {
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes); auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect shape element type not i64"; FAIL() << "Broadcast: did not detect shape element type not integral number";
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast shape must have element type i64")); std::string("Broadcast shape must be an integral number"));
} }
catch (...) catch (...)
{ {
...@@ -404,12 +404,12 @@ TEST(type_prop, broadcast_v1_axes_et_wrong) ...@@ -404,12 +404,12 @@ TEST(type_prop, broadcast_v1_axes_et_wrong)
try try
{ {
auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes); auto bc = make_shared<op::v1::Broadcast>(arg, bc_shape, bc_axes);
FAIL() << "Broadcast: did not detect axes element type not i64"; FAIL() << "Broadcast: did not detect axes element type not integral numbers";
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast axes must have element type i64")); std::string("Broadcast axes must be integral numbers, but are:"));
} }
catch (...) catch (...)
{ {
......
...@@ -571,7 +571,8 @@ TEST(type_prop, pad_v1_arg_pads_begin_incompatible_type) ...@@ -571,7 +571,8 @@ TEST(type_prop, pad_v1_arg_pads_begin_incompatible_type)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("pads_begin must be type i64 (axes type:")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("pads_begin must be an integral number, but is:"));
} }
catch (...) catch (...)
{ {
...@@ -594,7 +595,8 @@ TEST(type_prop, pad_v1_arg_pads_end_incompatible_type) ...@@ -594,7 +595,8 @@ TEST(type_prop, pad_v1_arg_pads_end_incompatible_type)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("pads_end must be type i64 (axes type:")); EXPECT_HAS_SUBSTRING(error.what(),
std::string("pads_end must be an integral number, but is:"));
} }
catch (...) catch (...)
{ {
......
...@@ -26,7 +26,7 @@ using namespace ngraph; ...@@ -26,7 +26,7 @@ using namespace ngraph;
TEST(type_prop, strided_slice_begin_incorrect_type) TEST(type_prop, strided_slice_begin_incorrect_type)
{ {
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto begin = make_shared<op::Parameter>(element::i32, Shape{4}); auto begin = make_shared<op::Parameter>(element::f16, Shape{4});
auto end = make_shared<op::Parameter>(element::i64, Shape{4}); auto end = make_shared<op::Parameter>(element::i64, Shape{4});
try try
{ {
...@@ -37,7 +37,7 @@ TEST(type_prop, strided_slice_begin_incorrect_type) ...@@ -37,7 +37,7 @@ TEST(type_prop, strided_slice_begin_incorrect_type)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Begin mask must have element type i64")); EXPECT_HAS_SUBSTRING(error.what(), std::string("Begin mask must be an integral number"));
} }
catch (...) catch (...)
{ {
...@@ -49,7 +49,7 @@ TEST(type_prop, strided_slice_end_incorrect_type) ...@@ -49,7 +49,7 @@ TEST(type_prop, strided_slice_end_incorrect_type)
{ {
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8}); auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto begin = make_shared<op::Parameter>(element::i64, Shape{4}); auto begin = make_shared<op::Parameter>(element::i64, Shape{4});
auto end = make_shared<op::Parameter>(element::i32, Shape{4}); auto end = make_shared<op::Parameter>(element::boolean, Shape{4});
try try
{ {
auto strided_slice = make_shared<op::v1::StridedSlice>( auto strided_slice = make_shared<op::v1::StridedSlice>(
...@@ -59,7 +59,7 @@ TEST(type_prop, strided_slice_end_incorrect_type) ...@@ -59,7 +59,7 @@ TEST(type_prop, strided_slice_end_incorrect_type)
} }
catch (const NodeValidationFailure& error) catch (const NodeValidationFailure& error)
{ {
EXPECT_HAS_SUBSTRING(error.what(), std::string("End mask must have element type i64")); EXPECT_HAS_SUBSTRING(error.what(), std::string("End mask must be an integral number"));
} }
catch (...) catch (...)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment