Commit 5485d6b6 authored by Adam Procter, committed by Scott Cyphers

Slice operator (#188)

* Formatting; fix erroneous "type_prop" classification for reduce and slice tests

* Clarify comment

* Add step parameter to slice, with type checking but not (yet) implementation in VM
parent 55ac615b
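For reference, a minimal sketch of the new step semantics (it mirrors the slice_deduce_matrix_strided test added below; type propagation only, since the VM still rejects non-unit steps):

// Rows [2,5) with step 3 and columns [1,7) with step 2 of a 6x8 matrix:
// each result axis has length ceil((upper - lower) / step), so the shape is {1, 3}.
auto param = make_shared<op::Parameter>(
    make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7}, Shape{3, 2});
sl->propagate_types(); // deduces TensorViewType(Float32, Shape{1, 3})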
......@@ -38,6 +38,7 @@ set (SRC
ops/reduce.cpp
ops/reshape.cpp
ops/select.cpp
ops/slice.cpp
ops/tuple.cpp
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise_builtin.cpp
......
......@@ -43,6 +43,9 @@ namespace ngraph
/// @brief A set of axes, for example, reduction axes
using AxisSet = std::set<size_t>;
/// @brief Coordinate in a tensor
using Coordinate = std::vector<size_t>;
/// @brief Shape for a tensor
using Shape = std::vector<size_t>;
......
......@@ -79,6 +79,7 @@
#include "ngraph/ops/remainder.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/runtime/backend.hpp"
......
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#include "ngraph/ops/slice.hpp"
using namespace std;
using namespace ngraph::op;
void Slice::propagate_types()
{
if (m_arguments.size() != 1)
{
throw ngraph_error("Wrong number of arguments.");
}
auto arg_type = m_arguments.at(0)->get_value_type();
if (nullptr == arg_type)
{
throw ngraph_error("Argument to slice is missing type.");
}
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
if (nullptr == arg_tensor_view_type)
{
throw ngraph_error("Argument to slice is not a tensor view");
}
auto& arg_shape = arg_tensor_view_type->get_shape();
if (m_lower_bounds.size() != arg_shape.size())
{
throw ngraph_error(
"Number of lower bounds provided for slice does not match number of input axes");
}
if (m_upper_bounds.size() != arg_shape.size())
{
throw ngraph_error(
"Number of upper bounds provided for slice does not match number of input axes");
}
if (m_step.size() != arg_shape.size())
{
throw ngraph_error(
"Number of step axes provided for slice does not match number of input axes");
}
Shape result_shape;
for (size_t i = 0; i < arg_shape.size(); i++)
{
if (m_upper_bounds[i] > arg_shape[i])
{
throw ngraph_error("Upper bound for slice is out of range");
}
if (m_lower_bounds[i] > m_upper_bounds[i])
{
throw ngraph_error("Lower bound for slice is greater than upper bound");
}
if (0 == m_step[i])
{
throw ngraph_error("Step distance for slice is zero");
}
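// Each result axis covers the half-open interval [lower, upper) in strides of
// step, so its length is ceil((upper - lower) / step), in integer arithmetic.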
size_t result_axis_size = m_upper_bounds[i] - m_lower_bounds[i];
result_axis_size =
result_axis_size / m_step[i] + ((result_axis_size % m_step[i] == 0) ? 0 : 1);
result_shape.push_back(result_axis_size);
}
set_value_type_checked(
make_shared<TensorViewType>(arg_tensor_view_type->get_element_type(), result_shape));
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
class Slice : public Builtin
{
public:
///
/// @param arg The tensor view to be sliced.
/// @param lower_bounds The axiswise lower bounds of the slice.
/// @param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// @param step The slicing step; for example, a step of {n,m} means to take
/// every nth row and every mth column of the input matrix.
///
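/// For example, on a 6x8 matrix, lower_bounds {2,1}, upper_bounds {5,7}, and
/// step {3,2} select row 2 and columns 1, 3, and 5, yielding a 1x3 result.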
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds,
const Shape& step)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(step)
{
}
Slice(const std::shared_ptr<Node>& arg,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
: Builtin({arg})
, m_lower_bounds(lower_bounds)
, m_upper_bounds(upper_bounds)
, m_step(Shape(lower_bounds.size(), 1))
{
}
virtual std::string description() const override { return "Slice"; }
virtual void propagate_types() override;
const Coordinate& get_lower_bounds() const { return m_lower_bounds; }
const Coordinate& get_upper_bounds() const { return m_upper_bounds; }
const Shape& get_step() const { return m_step; }
protected:
const Coordinate m_lower_bounds;
const Coordinate m_upper_bounds;
const Shape m_step;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class MatrixSliceInstruction : public Instruction
{
public:
MatrixSliceInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out,
size_t lower_row,
size_t lower_col,
size_t upper_row,
size_t upper_col)
: m_arg(arg)
, m_out(out)
, m_lower_row(lower_row)
, m_lower_col(lower_col)
, m_upper_row(upper_row)
, m_upper_col(upper_col)
{
}
virtual void execute(CallFrame& call_frame) const override
{
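// Eigen's block(i, j, p, q) is the p-by-q submatrix starting at (i, j); here
// it spans rows [m_lower_row, m_upper_row) and columns [m_lower_col, m_upper_col).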
EigenMatrix<ET>(call_frame, m_out) = EigenMatrix<ET>(call_frame, m_arg)
.block(m_lower_row,
m_lower_col,
m_upper_row - m_lower_row,
m_upper_col - m_lower_col);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
size_t m_lower_row;
size_t m_lower_col;
size_t m_upper_row;
size_t m_upper_col;
};
}
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/runtime/ngvm/call_frame.hpp"
#include "ngraph/runtime/ngvm/eigen/utils.hpp"
#include "ngraph/runtime/ngvm/instruction.hpp"
#include "ngraph/runtime/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
namespace ngvm
{
namespace eigen
{
template <typename ET>
class VectorSliceInstruction : public Instruction
{
public:
VectorSliceInstruction(const TensorViewInfo& arg,
const TensorViewInfo& out,
size_t lower,
size_t upper)
: m_arg(arg)
, m_out(out)
, m_lower(lower)
, m_upper(upper)
{
}
virtual void execute(CallFrame& call_frame) const override
{
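// Eigen's segment(pos, n) is the n coefficients starting at pos, i.e. the
// half-open range [m_lower, m_upper) of the argument vector.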
EigenVector<ET>(call_frame, m_out) =
EigenVector<ET>(call_frame, m_arg).segment(m_lower, m_upper - m_lower);
}
protected:
TensorViewInfo m_arg;
TensorViewInfo m_out;
size_t m_lower;
size_t m_upper;
};
}
}
}
}
......@@ -47,6 +47,7 @@
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_tensors.hpp"
......@@ -74,6 +75,7 @@
#include "ngraph/runtime/ngvm/eigen/less_than.hpp"
#include "ngraph/runtime/ngvm/eigen/log.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_mult.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_slice.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_transpose.hpp"
#include "ngraph/runtime/ngvm/eigen/matrix_vector_product.hpp"
#include "ngraph/runtime/ngvm/eigen/maximum.hpp"
......@@ -87,6 +89,7 @@
#include "ngraph/runtime/ngvm/eigen/scalar_tensor_product.hpp"
#include "ngraph/runtime/ngvm/eigen/select.hpp"
#include "ngraph/runtime/ngvm/eigen/subtract.hpp"
#include "ngraph/runtime/ngvm/eigen/vector_slice.hpp"
#include "ngraph/runtime/ngvm/external_function.hpp"
#include "ngraph/runtime/utils.hpp"
......@@ -835,6 +838,67 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
}
};
REGISTER_TO_OP_MAP(op::Slice)
{
auto slice = static_cast<const op::Slice*>(n);
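// The VM has no strided-slice instruction yet, so reject non-unit steps up front.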
for (auto d : slice->get_step())
{
if (1 != d)
{
throw ngraph_error("Slice does not support non-unit step yet in the VM");
}
}
auto arg_type = slice->get_arguments().at(0)->get_value_type();
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
assert(nullptr != arg_tensor_view_type);
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
auto& lower_bounds = slice->get_lower_bounds();
auto& upper_bounds = slice->get_upper_bounds();
// Scalar slice is necessarily just a copy.
if (arg_rank == 0)
{
PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type,
"Slice has unhandled element type",
runtime::ngvm::eigen::CopyInstruction,
in.at(0).get_index(),
out.at(0).get_index());
}
else if (arg_rank == 1)
{
PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type,
"Slice has unhandled element type",
runtime::ngvm::eigen::VectorSliceInstruction,
in[0],
out[0],
lower_bounds[0],
upper_bounds[0]);
}
else if (arg_rank == 2)
{
PUSH_POLYMORPHIC_INSTRUCTION(arg_element_type,
"Slice has unhandled element type",
runtime::ngvm::eigen::MatrixSliceInstruction,
in[0],
out[0],
lower_bounds[0],
lower_bounds[1],
upper_bounds[0],
upper_bounds[1]);
}
// Other cases (slicing of tensors with rank>2) are not handled yet.
else
{
throw ngraph_error("Slice is not implemented yet for tensors with rank>2 in VM");
}
};
initialized = true;
}
return op_map;
......
......@@ -1621,7 +1621,7 @@ TEST(execute, reduce_matrix_to_scalar_zero_by_zero)
ASSERT_EQ((vector<float>{99}), b->get_vector());
}
TEST(type_prop, reshape_t2v_012)
TEST(execute, reshape_t2v_012)
{
auto shape_a = Shape{2, 2, 3};
auto A = make_shared<op::Parameter>(
......@@ -1645,7 +1645,7 @@ TEST(type_prop, reshape_t2v_012)
ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), result->get_vector());
}
TEST(type_prop, reshape_t2s_012)
TEST(execute, reshape_t2s_012)
{
auto shape_a = Shape{1, 1, 1};
auto A = make_shared<op::Parameter>(
......@@ -1669,7 +1669,7 @@ TEST(type_prop, reshape_t2s_012)
ASSERT_EQ((vector<float>{6}), result->get_vector());
}
TEST(type_prop, reshape_t2s_120)
TEST(execute, reshape_t2s_120)
{
auto shape_a = Shape{1, 1, 1};
auto A = make_shared<op::Parameter>(
......@@ -1693,7 +1693,7 @@ TEST(type_prop, reshape_t2s_120)
ASSERT_EQ((vector<float>{6}), result->get_vector());
}
TEST(type_prop, reshape_s2t)
TEST(execute, reshape_s2t)
{
auto shape_a = Shape{};
auto A = make_shared<op::Parameter>(
......@@ -1717,7 +1717,7 @@ TEST(type_prop, reshape_s2t)
ASSERT_EQ((vector<float>{42}), result->get_vector());
}
TEST(type_prop, reshape_v2m_col)
TEST(execute, reshape_v2m_col)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(
......@@ -1741,7 +1741,7 @@ TEST(type_prop, reshape_v2m_col)
ASSERT_EQ((vector<float>{1, 2, 3}), result->get_vector());
}
TEST(type_prop, reshape_v2m_row)
TEST(execute, reshape_v2m_row)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(
......@@ -1765,7 +1765,7 @@ TEST(type_prop, reshape_v2m_row)
ASSERT_EQ((vector<float>{1, 2, 3}), result->get_vector());
}
TEST(type_prop, reshape_v2t_middle)
TEST(execute, reshape_v2t_middle)
{
auto shape_a = Shape{3};
auto A = make_shared<op::Parameter>(
......@@ -1789,7 +1789,7 @@ TEST(type_prop, reshape_v2t_middle)
ASSERT_EQ((vector<float>{1, 2, 3}), result->get_vector());
}
TEST(type_prop, reshape_m2m_same)
TEST(execute, reshape_m2m_same)
{
auto shape_a = Shape{3, 3};
auto A = make_shared<op::Parameter>(
......@@ -1813,7 +1813,7 @@ TEST(type_prop, reshape_m2m_same)
ASSERT_EQ((vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9}), result->get_vector());
}
TEST(type_prop, reshape_m2m_transpose)
TEST(execute, reshape_m2m_transpose)
{
auto shape_a = Shape{3, 3};
auto A = make_shared<op::Parameter>(
......@@ -1837,7 +1837,7 @@ TEST(type_prop, reshape_m2m_transpose)
ASSERT_EQ((vector<float>{1, 4, 7, 2, 5, 8, 3, 6, 9}), result->get_vector());
}
TEST(type_prop, reshape_m2m_dim_change_transpose)
TEST(execute, reshape_m2m_dim_change_transpose)
{
auto shape_a = Shape{3, 2};
auto A = make_shared<op::Parameter>(
......@@ -1883,3 +1883,75 @@ TEST(execute, exp)
(vector<float>{expf(-4), expf(-3), expf(-2), expf(-1), expf(0), expf(1), expf(2), expf(3)}),
result->get_vector());
}
TEST(execute, slice_scalar)
{
auto shape_a = Shape{};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Slice>(A, Coordinate{}, Coordinate{});
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{312};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{312}), result->get_vector());
}
TEST(execute, slice_matrix)
{
auto shape_a = Shape{4, 4};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{2, 3};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Slice>(A, Coordinate{0, 1}, Coordinate{3, 3});
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{2, 3, 6, 7, 10, 11}), result->get_vector());
}
TEST(execute, slice_vector)
{
auto shape_a = Shape{16};
auto A = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), shape_a));
auto shape_r = Shape{12};
auto rt = make_shared<TensorViewType>(element::Float32::element_type(), shape_r);
auto r = make_shared<op::Slice>(A, Coordinate{2}, Coordinate{14});
auto f = make_shared<Function>(r, rt, op::Parameters{A});
auto manager = runtime::Manager::get("NGVM");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
auto a = ngraph::runtime::make_tensor<element::Float32>(shape_a);
*a = vector<float>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
auto result = ngraph::runtime::make_tensor<element::Float32>(shape_r);
(*cf)({a}, {result});
ASSERT_EQ((vector<float>{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}), result->get_vector());
}
......@@ -1235,3 +1235,285 @@ TEST(type_prop, reshape_deduce_wrong_output_shape)
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_vector)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{2}, Coordinate{5});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{3}));
}
TEST(type_prop, slice_deduce_matrix)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{3, 6}));
}
TEST(type_prop, slice_deduce_matrix_strided)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7}, Shape{3, 2});
sl->propagate_types();
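// Expected shape: {ceil((5 - 2) / 3), ceil((7 - 1) / 2)} = {1, 3}.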
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 3}));
}
TEST(type_prop, slice_deduce_matrix_strided_uneven)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{2, 1}, Coordinate{5, 7}, Shape{3, 4});
sl->propagate_types();
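// Expected shape: {ceil((5 - 2) / 3), ceil((7 - 1) / 4)} = {1, 2}.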
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{1, 2}));
}
TEST(type_prop, slice_deduce_vector_edge)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{6});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()), TensorViewType(element::Float32::element_type(), Shape{6}));
}
TEST(type_prop, slice_deduce_matrix_edge)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 8});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{6, 8}));
}
TEST(type_prop, slice_deduce_matrix_zero_cols)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 0});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{6, 0}));
}
TEST(type_prop, slice_deduce_matrix_zero_zero)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{0, 0});
sl->propagate_types();
ASSERT_EQ(*(sl->get_value_type()),
TensorViewType(element::Float32::element_type(), Shape{0, 0}));
}
TEST(type_prop, slice_deduce_vector_invalid_step)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{7}, Shape{1, 2});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Invalid slice step not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string(
"Number of step axes provided for slice does not match number of input axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_vector_edge_upper_oob)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{7});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Upper bound out of range not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Upper bound for slice is out of range"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_matrix_edge_upper_oob)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{6, 9});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Upper bound out of range not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Upper bound for slice is out of range"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_vector_lower_above_upper)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6}));
auto sl = make_shared<op::Slice>(param, Coordinate{3}, Coordinate{2});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Lower bound above upper not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Lower bound for slice is greater than upper bound"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_matrix_lower_above_upper)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 5}, Coordinate{6, 4});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Lower bound above upper not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(error.what(), std::string("Lower bound for slice is greater than upper bound"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_matrix_lower_missing)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0}, Coordinate{5, 5});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Missing lower bound coordinate not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string(
"Number of lower bounds provided for slice does not match number of input axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_matrix_upper_missing)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{5});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Missing upper bound coordinate not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string(
"Number of upper bounds provided for slice does not match number of input axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_matrix_lower_extra)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0, 0}, Coordinate{5, 5});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Extra lower bound coordinate not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string(
"Number of lower bounds provided for slice does not match number of input axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, slice_deduce_matrix_upper_extra)
{
auto param = make_shared<op::Parameter>(
make_shared<TensorViewType>(element::Float32::element_type(), Shape{6, 8}));
auto sl = make_shared<op::Slice>(param, Coordinate{0, 0}, Coordinate{5, 5, 5});
try
{
sl->propagate_types();
// Should have thrown, so fail if it didn't
FAIL() << "Extra upper bound coordinate not detected";
}
catch (const ngraph_error& error)
{
EXPECT_EQ(
error.what(),
std::string(
"Number of upper bounds provided for slice does not match number of input axes"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}