Commit 2bb5bd50 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

(Dynamic) Reshape and Slice (#2611)

* added dyn_reshape and dyn_slice.

* style fix.

* some fixes.

* added dyn reshape type prop.

* fixed gpu build.

* added headers to gpu emitter.
parent 81f33056
......@@ -132,6 +132,10 @@ set (SRC
op/experimental/dyn_broadcast.hpp
op/experimental/dyn_pad.cpp
op/experimental/dyn_pad.hpp
op/experimental/dyn_reshape.cpp
op/experimental/dyn_reshape.hpp
op/experimental/dyn_slice.cpp
op/experimental/dyn_slice.hpp
op/experimental/generate_mask.cpp
op/experimental/generate_mask.hpp
op/experimental/quantized_avg_pool.cpp
......
......@@ -87,8 +87,9 @@
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <iostream>
#include "ngraph/function.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
using namespace std;
using namespace ngraph;
op::DynReshape::DynReshape(const shared_ptr<Node>& arg, const shared_ptr<Node>& pattern)
: Op("DynReshape", check_single_output_args({arg, pattern}))
{
constructor_validate_and_infer_types();
}
void op::DynReshape::validate_and_infer_types()
{
auto pattern_et = get_input_element_type(1);
// check data types
NODE_VALIDATION_CHECK(
this, pattern_et.compatible(element::Type_t::i64), "Pattern must have element type i64.");
// check shapes
const PartialShape& pattern_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
pattern_shape.rank().compatible(1),
"Pattern shape must have rank 1, got ",
pattern_shape.rank(),
".");
Rank output_rank = pattern_shape.rank().is_dynamic() ? Rank::dynamic() : pattern_shape[0];
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
}
shared_ptr<Node> op::DynReshape::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<DynReshape>(new_args.at(0), new_args.at(1));
}
void op::DynReshape::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("generate_adjoints not implemented for DynReshape");
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Tensor dynamic reshape operation.
///
/// "Converts" an input tensor into a new shape with the same number of elements.
/// This op does not touch the actual data. If needed, use Transpose for that purpose.
///
class DynReshape : public Op
{
public:
/// \brief Constructs a dynamic reshape operation. This operation does not perform transpose.
///
/// \param arg The tensor to be reshaped.
/// \param pattern The node that defines output shape pattern.
/// If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
DynReshape(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& pattern);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/experimental/dyn_slice.hpp"
#include <memory>
using namespace std;
using namespace ngraph;
op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds,
const shared_ptr<Node>& strides)
: Op("DynSlice", check_single_output_args({arg, lower_bounds, upper_bounds, strides}))
{
constructor_validate_and_infer_types();
}
void op::DynSlice::validate_and_infer_types()
{
auto lower_bounds_et = get_input_element_type(1);
auto upper_bounds_et = get_input_element_type(2);
auto strides_et = get_input_element_type(3);
// check data types
NODE_VALIDATION_CHECK(this,
lower_bounds_et.compatible(element::Type_t::i64),
"Lower bounds must have element type i64.");
NODE_VALIDATION_CHECK(this,
upper_bounds_et.compatible(element::Type_t::i64),
"Upper bounds must have element type i64.");
NODE_VALIDATION_CHECK(
this, strides_et.compatible(element::Type_t::i64), "Strides must have element type i64");
// check shapes
auto arg_shape = get_input_partial_shape(0);
auto lower_bounds_shape = get_input_partial_shape(1);
auto upper_bounds_shape = get_input_partial_shape(2);
auto strides_shape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
lower_bounds_shape.rank().compatible(1),
"Lower bounds shape must have rank 1, got ",
lower_bounds_shape.rank(),
".");
NODE_VALIDATION_CHECK(this,
upper_bounds_shape.rank().compatible(1),
"Upper bounds shape must have rank 1, got ",
upper_bounds_shape.rank(),
".");
NODE_VALIDATION_CHECK(this,
strides_shape.rank().compatible(1),
"Strides shape must have rank 1, got ",
strides_shape.rank(),
".");
NODE_VALIDATION_CHECK(this,
lower_bounds_shape.compatible(PartialShape{arg_shape.rank()}),
"Lower bounds must have shape [n], where n is the rank of arg.");
NODE_VALIDATION_CHECK(this,
upper_bounds_shape.compatible(PartialShape{arg_shape.rank()}),
"Upper bounds must have shape [n], where n is the rank of arg.");
NODE_VALIDATION_CHECK(this,
strides_shape.compatible(PartialShape{arg_shape.rank()}),
"Strides shape must have shape [n], where n is the rank of arg.");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
shared_ptr<Node> op::DynSlice::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<DynSlice>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}
void op::DynSlice::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
throw ngraph_error("generate_adjoints not implemented for DynSlice");
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node_vector.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Takes a slice of an input tensor, i.e., the sub-tensor that resides within a bounding box, optionally with stride.
class DynSlice : public Op
{
public:
/// \brief Constructs a dynamic tensor slice operation.
///
/// \param arg The tensor to be sliced.
/// \param lower_bounds The axiswise lower bounds of the slice (inclusive).
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix.
DynSlice(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds,
const std::shared_ptr<Node>& strides);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
};
}
}
......@@ -82,6 +82,8 @@ NGRAPH_OP(Divide, ngraph::op)
NGRAPH_OP(DynBroadcast, ngraph::op)
NGRAPH_OP(Dot, ngraph::op)
NGRAPH_OP(DynPad, ngraph::op)
NGRAPH_OP(DynReshape, ngraph::op)
NGRAPH_OP(DynSlice, ngraph::op)
NGRAPH_OP(Equal, ngraph::op)
NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
......
......@@ -217,6 +217,8 @@ bool runtime::gpu::GPU_Backend::is_supported(const Node& op) const
{
set<string> unsupported_ops = {"Quantize",
"Dequantize",
"DynReshape",
"DynSlice",
"ShapeOf",
"All",
"Any",
......
......@@ -59,6 +59,8 @@
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
......@@ -589,6 +591,16 @@ std::string runtime::gpu::GPU_Emitter::emit_Dot(EMIT_ARGS)
return compiled_function->add_to_runtime(index, function_name, args, out);
}
std::string runtime::gpu::GPU_Emitter::emit_DynReshape(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_DynSlice(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
std::string runtime::gpu::GPU_Emitter::emit_EmbeddingLookup(EMIT_ARGS)
{
throw ngraph_error("EmbeddingLookup is not yet implemented for NVIDIA GPU");
......
......@@ -1995,6 +1995,8 @@ shared_ptr<runtime::Executable>
}
case OP_TYPEID::AllReduce:
case OP_TYPEID::BroadcastLike:
case OP_TYPEID::DynReshape:
case OP_TYPEID::DynSlice:
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd:
......
......@@ -664,6 +664,16 @@ private:
dot->get_reduction_axes_count());
break;
}
case OP_TYPEID::DynReshape:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::DynSlice:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::EmbeddingLookup:
{
const op::EmbeddingLookup* embed = static_cast<const op::EmbeddingLookup*>(&node);
......
......@@ -49,6 +49,8 @@
#include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_conv.hpp"
......@@ -759,6 +761,16 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::DynPad>(args[0], args[1], args[2], args[3]);
break;
}
case OP_TYPEID::DynReshape:
{
node = make_shared<op::DynReshape>(args[0], args[1]);
break;
}
case OP_TYPEID::DynSlice:
{
node = make_shared<op::DynSlice>(args[0], args[1], args[2], args[3]);
break;
}
case OP_TYPEID::EmbeddingLookup:
{
node = make_shared<op::EmbeddingLookup>(args[0], args[1]);
......@@ -1516,6 +1528,10 @@ static json write(const Node& n, bool binary_constant_data)
}
case OP_TYPEID::DynPad: { break;
}
case OP_TYPEID::DynReshape: { break;
}
case OP_TYPEID::DynSlice: { break;
}
case OP_TYPEID::EmbeddingLookup: { break;
}
case OP_TYPEID::Equal: { break;
......
......@@ -12652,3 +12652,405 @@ TEST(type_prop, dyn_pad_output_ranks_pad_static_ok)
EXPECT_EQ(dyn_pad->get_output_element_type(0), element::f32);
EXPECT_TRUE(dyn_pad->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(3)));
}
TEST(type_prop, dynreshape_arg_static_pattern_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto pattern = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynreshape_arg_rank_static_dynamic_pattern_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto pattern = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynreshape_arg_static_pattern_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreshape_arg_rank_static_dynamic_pattern_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreshape_arg_rank_dynamic_pattern_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreshape_arg_rank_dynamic_pattern_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::Transpose>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynreshape_arg_rank_static_dynamic_pattern_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
void DynReshape_Test_Shape_Except(const shared_ptr<Node>& param_0, const shared_ptr<Node>& param_1)
{
try
{
auto r = make_shared<op::DynReshape>(param_0, param_1);
FAIL() << "Did not detect parameter shape not rank 1";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("shape must have rank 1"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dynreshape_arg_static_pattern_static_not_vector)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 6, 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynReshape_Test_Shape_Except(arg, pattern);
}
TEST(type_prop, dynreshape_arg_static_pattern_rank_static_dynamic_not_vector)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 6, 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynReshape_Test_Shape_Except(arg, pattern);
}
TEST(type_prop, dynreshape_arg_rank_static_dynamic_pattern_static_not_vector)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynReshape_Test_Shape_Except(arg, pattern);
}
TEST(type_prop, dynreshape_arg_rank_static_dynamic_pattern_rank_static_dynamic_not_vector)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynReshape_Test_Shape_Except(arg, pattern);
}
TEST(type_prop, dynreshape_arg_rank_dynamic_pattern_rank_static_dynamic_not_vector)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto pattern = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynReshape_Test_Shape_Except(arg, pattern);
}
TEST(type_prop, dynreshape_pattern_et_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto pattern = make_shared<op::Parameter>(element::dynamic, Shape{4});
auto r = make_shared<op::DynReshape>(arg, pattern);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynreshape_pattern_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto pattern = make_shared<op::Parameter>(element::boolean, Shape{4});
try
{
auto r = make_shared<op::DynReshape>(arg, pattern);
FAIL() << "Did not detect pattern elment type not i64";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Pattern must have element type i64."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dynslice_arg_static_params_static_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_static_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynslice_arg_rank_dynamic_params_rank_static_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto strides = make_shared<op::Parameter>(element::i64, PartialShape{Dimension::dynamic()});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynslice_arg_rank_dynamic_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3)
{
try
{
auto r = make_shared<op::DynSlice>(param_0, param_1, param_2, param_3);
FAIL() << "Did not detect input order not vector";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("shape must have rank 1"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dynslice_arg_static_params_rank_static_dynamic_not_vector)
{
auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, 4, 6, 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
auto strides = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
{
lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
lower_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
lower_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
upper_bounds = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
upper_bounds =
make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, 2});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
{
arg = make_shared<op::Parameter>(
element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
strides = make_shared<op::Parameter>(element::i64, PartialShape{2, Dimension::dynamic()});
DynSlice_Test_Shape_Except(arg, lower_bounds, upper_bounds, strides);
}
}
TEST(type_prop, dynslice_params_et_dynamic_ok)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
void DynSlice_Test_Type_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
const shared_ptr<Node>& param_3)
{
try
{
auto r = make_shared<op::DynSlice>(param_0, param_1, param_2, param_3);
FAIL() << "Did not detect parameter element type not i64";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("must have element type i64."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, dynslice_params_et_wrong)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
auto lower_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto upper_bounds = make_shared<op::Parameter>(element::i64, Shape{4});
auto strides = make_shared<op::Parameter>(element::i64, Shape{4});
{
lower_bounds = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
}
{
upper_bounds = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
}
{
strides = make_shared<op::Parameter>(element::boolean, Shape{4});
DynSlice_Test_Type_Except(arg, lower_bounds, upper_bounds, strides);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment