Commit e07bc028 authored by Mateusz Bencer, committed by Scott Cyphers

[SPEC] Add using Reshape:v1 (#4076)

* Add using Reshape:v1

* Use Reshape:v0 in group_conv

* Use Reshape_V1 in builder:flatten

* builder:v1 introduced

* Revert old builders to use Reshape:v0

* removed unused Transpose test

* Update test/opset_pass/transpose_opset_pass.cpp

* Changed builders to opset1

* Use opset1 instead of v1
Co-authored-by: Robert Kimball <robert.kimball@intel.com>
Co-authored-by: Sang Ik Lee <sang.ik.lee@intel.com>
parent 7a3f6480
@@ -26,8 +26,10 @@
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
@@ -173,3 +175,56 @@ shared_ptr<Node> builder::expand_dims(const Output<Node>& value, size_t axis)
value, get_default_order(value.get_shape().size()), output_shape)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::reshape(const Output<Node>& value, const Shape& shape)
{
const auto out_pattern = op::Constant::create(
element::i64, Shape{shape.size()}, vector<int64_t>(shape.begin(), shape.end()));
const bool special_zero = false;
return make_shared<ngraph::opset1::Reshape>(value, out_pattern, special_zero)
->add_provenance_group_members_above({value});
}
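For illustration, a minimal usage sketch of the builder above (the shapes are hypothetical, not taken from this patch; it also assumes ngraph/op/parameter.hpp is included):
    // Reshape a static 2x3x4 parameter to 6x4 through the opset1 builder.
    // This emits a v1::Reshape with a constant target-shape input and
    // special_zero = false, so a zero in the pattern would be taken
    // literally rather than copied from the input (v1::Reshape semantics).
    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto reshaped = builder::opset1::reshape(param, Shape{6, 4});
    // reshaped->get_shape() == Shape{6, 4}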
shared_ptr<Node> builder::opset1::reorder_axes(const Output<Node>& value, vector<size_t> axes_order)
{
const auto axes_order_const =
op::Constant::create(element::i64,
Shape{axes_order.size()},
vector<int64_t>(axes_order.begin(), axes_order.end()));
return make_shared<ngraph::opset1::Transpose>(value, axes_order_const)
->add_provenance_group_members_above({value});
}
shared_ptr<Node> builder::opset1::transpose(const Output<Node>& value)
{
vector<size_t> axes_order(value.get_shape().size());
iota(begin(axes_order), end(axes_order), 0);
reverse(begin(axes_order), end(axes_order));
return builder::opset1::reorder_axes(value, axes_order);
}
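As a quick check of the axis bookkeeping (hypothetical rank-3 shape): the iota order {0, 1, 2} is reversed to {2, 1, 0}, so the default transpose flips all axes:
    // Sketch: builder::opset1::transpose on a {2, 3, 4} value produces a
    // v1::Transpose whose order constant is {2, 1, 0}; output shape {4, 3, 2}.
    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4});
    auto flipped = builder::opset1::transpose(param);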
shared_ptr<Node> builder::opset1::flatten(const Output<Node>& value, int axis)
{
auto data_shape = value.get_shape();
    // The first dimension of the output tensor is the product of the input
    // tensor's [d_0, ..., d_{axis-1}] dimensions. The last dimension is the
    // product of the remaining input tensor dimensions: [d_{axis}, ..., d_n].
size_t first_dim_size =
accumulate(begin(data_shape), next(begin(data_shape), axis), 1UL, multiplies<size_t>());
size_t last_dim_size =
accumulate(next(begin(data_shape), axis), end(data_shape), 1UL, multiplies<size_t>());
return builder::opset1::reshape(value, Shape{first_dim_size, last_dim_size});
}
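A worked example of the two accumulate calls (hypothetical shape): for a {2, 3, 4, 5} input and axis = 2, the leading product is 2 * 3 = 6 and the trailing product is 4 * 5 = 20:
    // Sketch: flatten at axis = 2 folds [d_0, d_1] and [d_2, d_3] separately.
    auto param = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
    auto flat = builder::opset1::flatten(param, 2);
    // flat->get_shape() == Shape{6, 20}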
shared_ptr<Node> builder::opset1::expand_dims(const Output<Node>& value, size_t axis)
{
Shape output_shape(value.get_shape());
// Add empty axis at specified position.
auto empty_axis_it = begin(output_shape);
advance(empty_axis_it, axis);
output_shape.insert(empty_axis_it, 1);
return builder::opset1::reshape(value, output_shape);
}
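For instance (hypothetical shape), expanding a {3, 4} value at axis = 1 inserts the unit dimension between the existing ones:
    // Sketch: expand_dims inserts a size-1 axis at the requested position.
    auto param = make_shared<op::Parameter>(element::f32, Shape{3, 4});
    auto expanded = builder::opset1::expand_dims(param, 1);
    // expanded->get_shape() == Shape{3, 1, 4}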
@@ -103,5 +103,51 @@ namespace ngraph
/// \return The node with added empty axis.
///
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
namespace opset1
{
/// \brief Change shape of a value
///
/// \param[in] value The value to be reshaped.
/// \param[in] shape The new shape.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> reshape(const Output<Node>& value, const Shape& shape);
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param value The value whose axes we want to permute.
/// \param axes_order The permutation of axes.
///
/// \return Transpose:v1 op.
std::shared_ptr<Node> reorder_axes(const Output<Node>& value,
std::vector<size_t> axes_order = {});
/// \brief Return transposed value (with axes in reversed order).
///
/// \param value Value to transpose.
///
/// \return Transpose:v1 op.
std::shared_ptr<Node> transpose(const Output<Node>& value);
/// \brief Flatten a value into a 2D matrix, with a static dividing axis.
///
/// \param value The tensor to be flattened.
/// \param axis The axis dividing the shape.
///
/// \return The new value will be a 2D matrix representing the flattened input
/// node.
std::shared_ptr<Node> flatten(const Output<Node>& value, int axis);
/// \brief Expands node tensor shape with empty axis at
/// specified position.
///
/// \param[in] value The value to be expanded.
/// \param[in] axis The position in the expanded axes where the
/// new axis is placed.
///
/// \return Reshape:v1 op.
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
}
} // namespace builder
} // namespace ngraph
@@ -56,7 +56,7 @@ namespace ngraph
filters_shape.insert(filters_shape.begin(), groups);
auto reshaped_filters =
ngraph::builder::reshape(filters, filters_shape);
ngraph::builder::opset1::reshape(filters, filters_shape);
return std::make_shared<default_opset::GroupConvolution>(
data,
@@ -39,7 +39,7 @@ namespace ngraph
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, data_rank, -data_rank, data_rank);
return {ngraph::builder::flatten(data, normalized_axis)};
return {ngraph::builder::opset1::flatten(data, normalized_axis)};
}
} // namespace set_1
@@ -62,16 +62,16 @@ namespace ngraph
if (trans_a)
{
input_a = ngraph::builder::transpose(input_a);
input_a = ngraph::builder::opset1::transpose(input_a);
}
if (trans_b)
{
input_b = ngraph::builder::transpose(input_b);
input_b = ngraph::builder::opset1::transpose(input_b);
}
input_a = ngraph::builder::flatten(input_a, 1);
input_b = ngraph::builder::flatten(input_b, 1);
input_a = ngraph::builder::opset1::flatten(input_a, 1);
input_b = ngraph::builder::opset1::flatten(input_b, 1);
auto matmul_node = std::make_shared<ngraph::op::MatMul>(input_a, input_b);
@@ -39,7 +39,8 @@ namespace ngraph
ngraph::normalize_axis(node.get_description(), axis, input_shape.size());
// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::builder::flatten(input, normalized_axis);
const auto coerced_tensor =
ngraph::builder::opset1::flatten(input, normalized_axis);
const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d =
@@ -53,7 +54,7 @@ namespace ngraph
auto results =
std::make_shared<ngraph::opset0::EmbeddingLookup>(argmax_2d, eye_matrix);
return {ngraph::builder::reshape(results, input_shape)};
return {ngraph::builder::opset1::reshape(results, input_shape)};
}
} // namespace set_1
@@ -37,8 +37,8 @@ namespace ngraph
node.get_attribute_value<std::vector<std::size_t>>("perm", {});
return {(permute_axes.empty())
? ngraph::builder::transpose(data)
: ngraph::builder::reorder_axes(data, permute_axes)};
? ngraph::builder::opset1::transpose(data)
: ngraph::builder::opset1::reorder_axes(data, permute_axes)};
}
} // namespace set_1
@@ -76,10 +76,7 @@ namespace ngraph
{
output_shape.at(idx) = 1;
}
return std::make_shared<ngraph::op::Reshape>(
op_node,
ngraph::get_default_order(op_node->get_shape().size()),
Shape{output_shape});
return builder::opset1::reshape(op_node, output_shape);
}
std::shared_ptr<ngraph::Node>
@@ -22,8 +22,8 @@
#include "core/node.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp"
@@ -100,10 +100,7 @@ namespace ngraph
auto output_shape = input_node->get_shape();
output_shape.at(normalized_axis) = 1;
auto reshape_node = std::make_shared<ngraph::op::Reshape>(
convert_node,
ngraph::get_default_order(op_node->get_shape().size()),
Shape{output_shape});
auto reshape_node = builder::opset1::reshape(op_node, output_shape);
// WORKAROUND FOR PROBLEMS WITH RESHAPE ON i64 @TODO: remove
auto reconvert_node =
@@ -109,7 +109,7 @@ namespace ngraph
node->get_element_type(), ngraph::Shape{}, value);
}
return ngraph::builder::reshape(node, Shape{});
return ngraph::builder::opset1::reshape(node, Shape{});
}
} // namespace reshape
@@ -22,15 +22,15 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Transpose::type_info;
constexpr NodeTypeInfo op::v1::Transpose::type_info;
op::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_order)
op::v1::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_order)
: Op({arg, input_order})
{
constructor_validate_and_infer_types();
}
void op::Transpose::validate_and_infer_types()
void op::v1::Transpose::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
get_input_element_type(1).compatible(element::i64),
@@ -65,16 +65,16 @@ void op::Transpose::validate_and_infer_types()
}
}
shared_ptr<Node> op::Transpose::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Transpose::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Transpose>(new_args.at(0), new_args.at(1));
return make_shared<v1::Transpose>(new_args.at(0), new_args.at(1));
}
// TODO(amprocte): This will require some way of inverting the permutation in-graph. (TensorFlow,
// for example, has an InvertPermutation op, but that doesn't feel very nGraph-y somehow.)
void op::Transpose::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const OutputVector& /* deltas */)
void op::v1::Transpose::generate_adjoints(autodiff::Adjoints& /* adjoints */,
const OutputVector& /* deltas */)
{
throw ngraph_error("generate_adjoints not implemented for Transpose");
}
@@ -24,7 +24,7 @@ namespace ngraph
{
namespace op
{
namespace v0
namespace v1
{
/// \brief Tensor transpose operation.
class NGRAPH_API Transpose : public Op
@@ -52,6 +52,6 @@ namespace ngraph
const OutputVector& deltas) override;
};
}
using v0::Transpose;
using v1::Transpose;
}
}
@@ -59,7 +59,7 @@ NodeVector op::DepthToSpace::decompose_op() const
{
// Insert batch axis
data_shape.insert(data_shape.begin(), 1);
data = builder::reshape(data, data_shape);
data = builder::opset1::reshape(data, data_shape);
}
const size_t n_dim = data_shape.at(0);
const size_t c_dim = data_shape.at(1);
@@ -102,7 +102,7 @@ NodeVector op::DepthToSpace::decompose_op() const
case DepthToSpaceMode::DEPTH_FIRST:
{
dispersed_shape.insert(dispersed_shape.begin() + 1, c_flat);
flat_node = builder::reshape(data, dispersed_shape);
flat_node = builder::opset1::reshape(data, dispersed_shape);
axes_order.push_back(1);
for (int i = spatial_dim_index; i < data_shape.size(); ++i)
@@ -111,7 +111,7 @@ NodeVector op::DepthToSpace::decompose_op() const
axes_order.push_back(i);
}
flat_node = builder::reorder_axes(flat_node, axes_order);
flat_node = builder::opset1::reorder_axes(flat_node, axes_order);
break;
}
// x' = reshape(data, [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2,
@@ -123,7 +123,7 @@ NodeVector op::DepthToSpace::decompose_op() const
default:
{
dispersed_shape.insert(dispersed_shape.begin() + spatial_dims + 1, c_flat);
flat_node = builder::reshape(data, dispersed_shape);
flat_node = builder::opset1::reshape(data, dispersed_shape);
axes_order.push_back(spatial_dims + 1);
for (int i = 2; i < data_shape.size(); ++i)
@@ -131,7 +131,7 @@ NodeVector op::DepthToSpace::decompose_op() const
axes_order.push_back(spatial_dims + i);
axes_order.push_back(i - 1);
}
flat_node = builder::reorder_axes(flat_node, axes_order);
flat_node = builder::opset1::reorder_axes(flat_node, axes_order);
}
}
Shape squeezed_shape{n_dim, c_flat};
@@ -139,7 +139,7 @@ NodeVector op::DepthToSpace::decompose_op() const
{
squeezed_shape.push_back(data_shape.at(i) * bs);
}
flat_node = builder::reshape(flat_node, squeezed_shape);
flat_node = builder::opset1::reshape(flat_node, squeezed_shape);
return NodeVector{flat_node};
}
@@ -18,10 +18,10 @@
#include "group_conv.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/validation_util.hpp"
@@ -516,8 +516,9 @@ NodeVector op::v0::GroupConvolution::decompose_op() const
if (m_groups_in_filters)
{
// Remove group dimension after slicing
sliced_filter = builder::reshape(
sliced_filter = make_shared<op::Reshape>(
sliced_filters[group],
get_default_order(sliced_filters[group]->get_shape().size()),
Shape(std::next(std::begin(filters_shape), 1), std::end(filters_shape)));
}
convolution_nodes.push_back(
@@ -58,7 +58,7 @@ NodeVector op::SpaceToDepth::decompose_op() const
{
// Insert batch axis
data_shape.insert(data_shape.begin(), 1);
data = builder::reshape(data, data_shape);
data = builder::opset1::reshape(data, data_shape);
}
const size_t n_dim = data_shape.at(0);
@@ -87,7 +87,7 @@ NodeVector op::SpaceToDepth::decompose_op() const
dispersed_shape.push_back(data_shape.at(i + spatial_dim_index) / m_blocksize);
dispersed_shape.push_back(m_blocksize);
}
auto flat_node = builder::reshape(data, dispersed_shape);
auto flat_node = builder::opset1::reshape(data, dispersed_shape);
// calculate axes to transpose
// [0, 3, 5, ..., spatial_dims + (spatial_dims + 1), 2, 4, ..., K + K])
vector<size_t> axes_order{0};
@@ -121,14 +121,14 @@ NodeVector op::SpaceToDepth::decompose_op() const
default: { axes_order.insert(axes_order.begin() + spatial_dims + 1, 1);
}
}
flat_node = builder::reorder_axes(flat_node, axes_order);
flat_node = builder::opset1::reorder_axes(flat_node, axes_order);
Shape squeezed_shape{n_dim};
for (int i = 0; i < spatial_dims; ++i)
{
squeezed_shape.push_back(data_shape.at(spatial_dim_index + i) / m_blocksize);
}
squeezed_shape.insert(squeezed_shape.begin() + 1, c_dim * std::pow(m_blocksize, spatial_dims));
flat_node = builder::reshape(flat_node, squeezed_shape);
flat_node = builder::opset1::reshape(flat_node, squeezed_shape);
return NodeVector{flat_node};
}
@@ -244,7 +244,7 @@ NGRAPH_OP(TensorIterator, ngraph::op::v0, 0)
NGRAPH_OP(Tile, ngraph::op::v0, 0)
NGRAPH_OP(TopK, ngraph::op::v0, 0)
NGRAPH_OP(TopK, ngraph::op::v1, 1)
NGRAPH_OP(Transpose, ngraph::op::v0, 0)
NGRAPH_OP(Transpose, ngraph::op::v1, 1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0, 0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1, 1)
NGRAPH_OP(Xor, ngraph::op::v0, 0)
@@ -206,6 +206,5 @@ NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Tile, ngraph::op)
NGRAPH_OP(TopK, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op)
NGRAPH_OP(Xor, ngraph::op)
@@ -156,7 +156,7 @@ NGRAPH_OP(Tanh, ngraph::op::v0)
NGRAPH_OP(TensorIterator, ngraph::op::v0)
NGRAPH_OP(Tile, ngraph::op::v0)
NGRAPH_OP(TopK, ngraph::op::v1)
NGRAPH_OP(Transpose, ngraph::op::v0)
NGRAPH_OP(Transpose, ngraph::op::v1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1)
NGRAPH_OP(Xor, ngraph::op::v0)
@@ -265,9 +265,15 @@ namespace
shared_ptr<Node> replacement_node;
const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static())
const auto input_rank = node->get_input_partial_shape(0).rank();
if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static() &&
input_rank.is_static())
{
replacement_node = builder::reshape(node->input_value(0), node->get_output_shape(0));
const auto output_shape = node->get_output_shape(0);
replacement_node =
make_shared<op::Reshape>(node->input_value(0),
get_default_order(static_cast<size_t>(input_rank)),
output_shape);
}
else
{
@@ -396,7 +402,8 @@ namespace
filters_shape[1] *= groups;
filters_shape.erase(filters_shape.begin());
auto reshaped_filters = builder::reshape(node->input_value(1), filters_shape);
auto reshaped_filters = make_shared<op::v0::Reshape>(
filters_arg, get_default_order(filters_arg.get_shape().size()), filters_shape);
auto pads_begin = node->get_pads_begin();
auto pads_end = node->get_pads_end();
@@ -774,6 +781,44 @@ namespace
return true;
}
bool op_cast(shared_ptr<op::v1::Transpose> node)
{
const auto data = node->input_value(0);
const auto data_pshape = data.get_partial_shape();
NGRAPH_CHECK(data_pshape.is_static(),
"Unable to convert Transpose:v1 to Reshape:v0 "
"if data shape is dynamic. Node: ",
*node);
const auto data_shape = data_pshape.to_shape();
const auto order_node = node->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(order_node->is_constant(),
"Unable to convert Transpose:v1 to Reshape:v0 "
"if order node is not constant. Node: ",
*node);
const auto order_const = as_type_ptr<op::Constant>(order_node);
auto order = order_const->get_axis_vector_val();
Shape out_shape = data_shape;
if (order.empty())
{
order.resize(out_shape.size());
iota(begin(order), end(order), 0);
}
else
{
for (size_t i = 0; i < order.size(); ++i)
{
out_shape[i] = data_shape.at(order.at(i));
}
}
auto replacement_node = make_shared<op::v0::Reshape>(data, order, out_shape);
replace_node(node, replacement_node);
return true;
}
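The permutation-to-shape mapping above can be checked by hand; this sketch mirrors the values used by the downgrade unit test added later in this patch:
    // out_shape[i] = data_shape[order[i]], so a {4, 5, 6, 7} input with
    // order {2, 1, 3, 0} becomes {6, 5, 7, 4}; the order vector itself is
    // reused as the Reshape:v0 input_order.
    Shape data_shape{4, 5, 6, 7};
    AxisVector order{2, 1, 3, 0};
    Shape out_shape = data_shape;
    for (size_t i = 0; i < order.size(); ++i)
    {
        out_shape[i] = data_shape.at(order.at(i));
    }
    // out_shape == Shape{6, 5, 7, 4}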
bool op_cast(shared_ptr<op::v1::VariadicSplit> node)
{
const auto split_lengths = node->input_value(2).get_node_shared_ptr();
@@ -1832,7 +1832,6 @@ protected:
case OP_TYPEID::MatMul:
case OP_TYPEID::Split:
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad:
case OP_TYPEID::Tile:
case OP_TYPEID::DynReplaceSlice:
@@ -2926,9 +2926,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = move(topk);
break;
}
case OP_TYPEID::Transpose:
case OP_TYPEID::Transpose_v1:
{
node = make_shared<op::Transpose>(args[0], args[1]);
node = make_shared<op::v1::Transpose>(args[0], args[1]);
break;
}
case OP_TYPEID::StopGradient:
@@ -4579,7 +4579,7 @@ json JSONSerializer::serialize_node(const Node& n)
node["index_element_type"] = write_element_type(tmp->get_index_element_type());
break;
}
case OP_TYPEID::Transpose: { break;
case OP_TYPEID::Transpose_v1: { break;
}
case OP_TYPEID::Unsqueeze: { break;
}
@@ -92,6 +92,7 @@ set(SRC
opset_pass/slice_opset_pass.cpp
opset_pass/softmax_opset_pass.cpp
opset_pass/topk_opset_pass.cpp
opset_pass/transpose_opset_pass.cpp
partial_shape.cpp
pass.cpp
pass_liveness.cpp
@@ -1428,15 +1428,6 @@ namespace
EXPECT_FALSE(node.is_binary_elementwise_logical());
}
void op_is_Transpose()
{
op::Transpose node;
EXPECT_FALSE(node.is_unary_elementwise_arithmetic());
EXPECT_FALSE(node.is_binary_elementwise_arithmetic());
EXPECT_FALSE(node.is_binary_elementwise_comparison());
EXPECT_FALSE(node.is_binary_elementwise_logical());
}
void op_is_Unsqueeze()
{
op::Unsqueeze node;
@@ -145,7 +145,7 @@ TEST(opset, check_opset1)
CHECK_OPSET(op::v0::TensorIterator, opset1::TensorIterator)
CHECK_OPSET(op::v0::Tile, opset1::Tile)
CHECK_OPSET(op::v1::TopK, opset1::TopK)
CHECK_OPSET(op::v0::Transpose, opset1::Transpose)
CHECK_OPSET(op::v1::Transpose, opset1::Transpose)
CHECK_OPSET(op::v0::Unsqueeze, opset1::Unsqueeze)
CHECK_OPSET(op::v1::VariadicSplit, opset1::VariadicSplit)
CHECK_OPSET(op::v0::Xor, opset1::Xor)
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_transpose_downgrade_pass)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{4, 5, 6, 7});
AxisVector order{2, 1, 3, 0};
const auto order_node = op::Constant::create(element::i64, Shape{order.size()}, order);
auto transpose = make_shared<op::v1::Transpose>(data, order_node);
auto result = make_shared<op::Result>(transpose);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto reshape_result = f->get_results().at(0);
auto reshape_node = as_type_ptr<op::v0::Reshape>(
reshape_result->input(0).get_source_output().get_node_shared_ptr());
ASSERT_TRUE(reshape_node);
EXPECT_EQ(reshape_node->get_input_order(), order);
EXPECT_EQ(reshape_node->get_output_shape(), Shape({6, 5, 7, 4}));
}
TEST(opset_transform, opset1_transpose_downgrade_pass_data_shape_not_static)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
AxisVector order{2, 1, 3, 0};
const auto order_node = op::Constant::create(element::i64, Shape{order.size()}, order);
auto transpose = make_shared<op::v1::Transpose>(data, order_node);
auto result = make_shared<op::Result>(transpose);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Transpose Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Unable to convert Transpose:v1 to Reshape:v0 "
"if data shape is dynamic. Node:"));
}
catch (...)
{
FAIL() << "Transpose pass failed for unexpected reason";
}
}
TEST(opset_transform, opset1_transpose_downgrade_pass_order_not_constant)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{4, 5, 6, 7});
const auto order_node = make_shared<op::Parameter>(element::i64, Shape{4});
auto transpose = make_shared<op::v1::Transpose>(data, order_node);
auto result = make_shared<op::Result>(transpose);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{data, order_node});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
try
{
pass_manager.run_passes(f);
FAIL() << "Exception after Transpose Opset0Downgrade pass was not thrown.";
}
catch (const ngraph_error& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Unable to convert Transpose:v1 to Reshape:v0 "
"if order node is not constant. Node:"));
}
catch (...)
{
FAIL() << "Transpose pass failed for unexpected reason";
}
}