//***************************************************************************** // Copyright 2017-2019 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" #include "util/type_prop.hpp" using namespace std; using namespace ngraph; TEST(type_prop, range_nonconst_ok) { auto start = make_shared<op::Parameter>(element::i32, Shape{}); auto stop = make_shared<op::Parameter>(element::i32, Shape{}); auto step = make_shared<op::Parameter>(element::i32, Shape{}); auto range = make_shared<op::Range>(start, stop, step); EXPECT_EQ(range->get_element_type(), element::i32); EXPECT_TRUE(range->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1))); } TEST(type_prop, range_nonconst_some_dyn_et_ok) { auto start = make_shared<op::Parameter>(element::i32, Shape{}); auto stop = make_shared<op::Parameter>(element::dynamic, Shape{}); auto step = make_shared<op::Parameter>(element::i32, Shape{}); auto range = make_shared<op::Range>(start, stop, step); EXPECT_EQ(range->get_element_type(), element::i32); EXPECT_TRUE(range->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1))); } TEST(type_prop, range_nonconst_all_dyn_et_ok) { auto start = make_shared<op::Parameter>(element::dynamic, Shape{}); auto stop = make_shared<op::Parameter>(element::dynamic, Shape{}); auto step = make_shared<op::Parameter>(element::dynamic, Shape{}); auto range = make_shared<op::Range>(start, stop, step); EXPECT_EQ(range->get_element_type(), element::dynamic); EXPECT_TRUE(range->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1))); } TEST(type_prop, range_nonconst_f32_ok) { auto start = make_shared<op::Parameter>(element::dynamic, Shape{}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Parameter>(element::dynamic, Shape{}); auto range = make_shared<op::Range>(start, stop, step); EXPECT_EQ(range->get_element_type(), element::f32); EXPECT_TRUE(range->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1))); } TEST(type_prop, range_nonconst_boolean_fails) { auto start = make_shared<op::Parameter>(element::dynamic, Shape{}); auto stop = make_shared<op::Parameter>(element::boolean, Shape{}); auto step = make_shared<op::Parameter>(element::dynamic, Shape{}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "Boolean element type not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "Element type for start, stop, and step, must not be boolean."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_ok) { auto start = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{3}); auto stop = make_shared<op::Parameter>(element::i32, Shape{}); auto step = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{2}); auto range = make_shared<op::Range>(start, stop, step); EXPECT_EQ(range->get_element_type(), element::i32); EXPECT_TRUE(range->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(1))); } TEST(type_prop, range_some_const_zero_stride_fails) { auto start = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{3}); auto stop = make_shared<op::Parameter>(element::i32, Shape{}); auto step = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{0}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "Zero stride not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'step' cannot be zero."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_plus_inf_start_fails) { auto start = make_shared<op::Constant>( element::f32, Shape{}, std::vector<float>{std::numeric_limits<float>::infinity()}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{1}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "+Infinity start not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'start' cannot be nan or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_minus_inf_start_fails) { auto start = make_shared<op::Constant>( element::f32, Shape{}, std::vector<float>{-std::numeric_limits<float>::infinity()}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{1}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "-Infinity start not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'start' cannot be nan or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_nan_start_fails) { auto start = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{std::nanf("")}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{1}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "NaN start not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'start' cannot be nan or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_plus_inf_stop_fails) { auto start = make_shared<op::Parameter>(element::f32, Shape{}); auto stop = make_shared<op::Constant>( element::f32, Shape{}, std::vector<float>{std::numeric_limits<float>::infinity()}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{1}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "+Infinity stop not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'stop' cannot be nan or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_minus_inf_stop_fails) { auto start = make_shared<op::Parameter>(element::f32, Shape{}); auto stop = make_shared<op::Constant>( element::f32, Shape{}, std::vector<float>{-std::numeric_limits<float>::infinity()}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{1}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "-Infinity stop not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'stop' cannot be nan or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_nan_stio_fails) { auto start = make_shared<op::Parameter>(element::f32, Shape{}); auto stop = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{std::nanf("")}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{1}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "NaN stop not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'stop' cannot be nan or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_plus_inf_stride_fails) { auto start = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{3}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Constant>( element::f32, Shape{}, std::vector<float>{std::numeric_limits<float>::infinity()}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "+Infinity stride not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'step' cannot be zero, nan, or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_minus_inf_stride_fails) { auto start = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{3}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Constant>( element::f32, Shape{}, std::vector<float>{-std::numeric_limits<float>::infinity()}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "-Infinity stride not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'step' cannot be zero, nan, or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_some_const_nan_stride_fails) { auto start = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{3}); auto stop = make_shared<op::Parameter>(element::f32, Shape{}); auto step = make_shared<op::Constant>(element::f32, Shape{}, std::vector<float>{std::nanf("")}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "NaN stride not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'step' cannot be zero, nan, or infinite."); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } TEST(type_prop, range_all_const_zero_stride_fails) { auto start = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{3}); auto stop = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{5}); auto step = make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{0}); try { auto range = make_shared<op::Range>(start, stop, step); FAIL() << "Zero stride not detected"; } catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), "'step' cannot be zero"); } catch (...) { FAIL() << "Test failed for unexpected reason"; } } struct RangeParams { double start; double stop; double step; PartialShape expected_shape; }; template <typename T> void run_range_test(const element::Type& et, const RangeParams& params) { auto start = make_shared<op::Constant>(et, Shape{}, std::vector<T>{static_cast<T>(params.start)}); auto stop = make_shared<op::Constant>(et, Shape{}, std::vector<T>{static_cast<T>(params.stop)}); auto step = make_shared<op::Constant>(et, Shape{}, std::vector<T>{static_cast<T>(params.step)}); auto range = make_shared<op::Range>(start, stop, step); EXPECT_EQ(range->get_element_type(), et); EXPECT_TRUE(range->get_output_partial_shape(0).same_scheme(params.expected_shape)) << "Expected shape " << params.expected_shape << " but got " << range->get_output_partial_shape(0); } struct RangeTest : ::testing::TestWithParam<RangeParams> { }; TEST_P(RangeTest, deduce_shape_i8) { run_range_test<int8_t>(element::i8, GetParam()); } TEST_P(RangeTest, deduce_shape_i16) { run_range_test<int16_t>(element::i16, GetParam()); } TEST_P(RangeTest, deduce_shape_i32) { run_range_test<int32_t>(element::i32, GetParam()); } TEST_P(RangeTest, deduce_shape_i64) { run_range_test<int64_t>(element::i64, GetParam()); } TEST_P(RangeTest, deduce_shape_u8) { run_range_test<uint8_t>(element::u8, GetParam()); } TEST_P(RangeTest, deduce_shape_u16) { run_range_test<uint16_t>(element::u16, GetParam()); } TEST_P(RangeTest, deduce_shape_u32) { run_range_test<uint32_t>(element::u32, GetParam()); } TEST_P(RangeTest, deduce_shape_u64) { run_range_test<uint64_t>(element::u64, GetParam()); } TEST_P(RangeTest, deduce_shape_bf16) { run_range_test<bfloat16>(element::bf16, GetParam()); } TEST_P(RangeTest, deduce_shape_f16) { run_range_test<float16>(element::f16, GetParam()); } TEST_P(RangeTest, deduce_shape_f32) { run_range_test<float>(element::f32, GetParam()); } TEST_P(RangeTest, deduce_shape_f64) { run_range_test<double>(element::f64, GetParam()); } INSTANTIATE_TEST_CASE_P(type_prop, RangeTest, ::testing::Values(RangeParams{0, 5, 1, PartialShape{5}}, RangeParams{0, 22, 2, PartialShape{11}}, RangeParams{1, 23, 2, PartialShape{11}}, RangeParams{1, 22, 2, PartialShape{11}}, RangeParams{0, 0, 1, PartialShape{0}}, RangeParams{1, 0, 2, PartialShape{0}}), PrintToDummyParamName()); struct RangeTestWithNegatives : ::testing::TestWithParam<RangeParams> { }; TEST_P(RangeTestWithNegatives, deduce_shape_i8) { run_range_test<int8_t>(element::i8, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_i16) { run_range_test<int16_t>(element::i16, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_i32) { run_range_test<int32_t>(element::i32, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_i64) { run_range_test<int64_t>(element::i64, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_bf16) { run_range_test<bfloat16>(element::bf16, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_f16) { run_range_test<float16>(element::f16, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_f32) { run_range_test<float>(element::f32, GetParam()); } TEST_P(RangeTestWithNegatives, deduce_shape_f64) { run_range_test<double>(element::f64, GetParam()); } INSTANTIATE_TEST_CASE_P(type_prop, RangeTestWithNegatives, ::testing::Values(RangeParams{2, 0, -2, PartialShape{1}}, RangeParams{2, 0, -1, PartialShape{2}}, RangeParams{-19, 19, 1, PartialShape{38}}, RangeParams{-19, 19, 3, PartialShape{13}}, RangeParams{20, -19, 1, PartialShape{0}}), PrintToDummyParamName()); struct RangeTestFloating : ::testing::TestWithParam<RangeParams> { }; TEST_P(RangeTestFloating, deduce_shape_bf16) { run_range_test<bfloat16>(element::bf16, GetParam()); } TEST_P(RangeTestFloating, deduce_shape_f16) { run_range_test<float16>(element::f16, GetParam()); } TEST_P(RangeTestFloating, deduce_shape_f32) { run_range_test<float>(element::f32, GetParam()); } TEST_P(RangeTestFloating, deduce_shape_f64) { run_range_test<double>(element::f64, GetParam()); } INSTANTIATE_TEST_CASE_P(type_prop, RangeTestFloating, ::testing::Values(RangeParams{0, 1, 0.25, PartialShape{4}}, RangeParams{-1, 1, 0.25, PartialShape{8}}, RangeParams{-1, 0.875, 0.25, PartialShape{8}}), PrintToDummyParamName());