Commit 04ea0671 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Michał Karzyński

v1::Reshape downgrade pass + onnx_importer adjustments (#4046)

parent 228570eb
......@@ -37,32 +37,28 @@ namespace ngraph
// Converts an ONNX Reshape node into an nGraph v1::Reshape.
//
// Since ONNX opset 5 the target shape is provided as the second input
// of the node; for older opsets it comes from the "shape" attribute.
// In both cases the target shape is materialized as a `pattern` node
// (the shape input itself, or a Constant built from the attribute) and
// fed to v1::Reshape with special_zero = true, so a 0 in the target
// shape copies the corresponding input dimension.
//
// \param node  The ONNX Reshape node being imported.
// \return      A NodeVector holding the single v1::Reshape output.
NodeVector reshape(const Node& node)
{
    NodeVector ng_inputs{node.get_ng_inputs()};
    const auto data = ng_inputs.at(0);

    std::shared_ptr<ngraph::Node> pattern;
    // Since opset 5 the target shape is provided as input
    if (ng_inputs.size() == 2)
    {
        // Only a statically known (Constant) shape input is supported here.
        NGRAPH_CHECK(ng_inputs.at(1)->is_constant(),
                     "The target shape input has to be a Constant.");
        pattern = ng_inputs.at(1);
    }
    else
    {
        // Pre-opset-5: the target shape is carried by the "shape" attribute.
        const auto output_shape =
            node.get_attribute_value<std::vector<int64_t>>("shape", {});
        pattern = ngraph::op::Constant::create(
            element::i64, Shape{output_shape.size()}, output_shape);
    }

    // special_zero = true: dimensions equal to 0 are copied from the input.
    return {std::make_shared<ngraph::op::v1::Reshape>(data, pattern, true)};
}
} // namespace set_1
......
......@@ -221,8 +221,19 @@ namespace
// Downgrades a v1::Reshape to its opset-0 equivalent.
//
// When the target shape is a Constant and the node's output shape is
// fully static, the node can be replaced by a static builder::reshape.
// Otherwise a v0::DynReshape is used, preserving the special-zero
// semantics of the original node.
//
// \param node  The v1::Reshape node to downgrade.
// \return      Always true (the node is always replaced).
bool op_cast(shared_ptr<op::v1::Reshape> node)
{
    shared_ptr<Node> replacement_node;
    const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
    if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static())
    {
        // Static case: the output shape is known, so a plain reshape suffices.
        replacement_node = builder::reshape(node->input_value(0), node->get_output_shape(0));
    }
    else
    {
        // Dynamic case: defer shape resolution to DynReshape at runtime.
        replacement_node = make_shared<op::v0::DynReshape>(
            node->input_value(0), node->input_value(1), node->get_special_zero());
    }
    replace_node(node, replacement_node);
    return true;
}
......
ir_version: 6
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
input: "shape"
output: "reshaped"
op_type: "Reshape"
attribute {
name: "shape"
ints: 6
ints: 2
ints: 2
type: INTS
}
}
name: "test_reshape_negative_dim"
initializer {
dims: 3
data_type: 7
int64_data: 2
int64_data: -1
int64_data: 2
name: "shape"
}
input {
name: "data"
type {
tensor_type {
elem_type: 1
......@@ -33,17 +35,30 @@ graph {
}
}
}
input {
name: "shape"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "reshaped"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 6
}
dim {
dim_value: 2
......
......@@ -124,21 +124,34 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reshape_single_dim)
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reshape_negative_dim)
{
auto function = onnx_import::import_onnx_model(
// the model contains the target shape in the initializers: [2, -1, 2]
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.prototxt"));
// input data shape (2, 3, 4)
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}})
// 2x3x4
Inputs inputs{test::NDArray<float, 3>({{{0.5488135, 0.71518934, 0.60276335, 0.5448832},
{0.4236548, 0.6458941, 0.4375872, 0.891773},
{0.96366274, 0.3834415, 0.79172504, 0.5288949}},
{{0.56804454, 0.92559665, 0.07103606, 0.0871293},
{0.0202184, 0.83261985, 0.77815676, 0.87001216},
{0.9786183, 0.7991586, 0.46147937, 0.7805292}}})
.get_vector()};
// output data shape (6, 2, 2)
Outputs expected_outputs{test::NDArray<float, 3>({{{0, 1}, {2, 3}},
{{4, 5}, {6, 7}},
{{8, 9}, {10, 11}},
{{12, 13}, {14, 15}},
{{16, 17}, {18, 19}},
{{20, 21}, {22, 23}}})
// 2x6x2
Outputs expected_outputs{test::NDArray<float, 3>({{{0.5488135, 0.71518934},
{0.60276335, 0.5448832},
{0.4236548, 0.6458941},
{0.4375872, 0.891773},
{0.96366274, 0.3834415},
{0.79172504, 0.5288949}},
{{0.56804454, 0.92559665},
{0.07103606, 0.0871293},
{0.0202184, 0.83261985},
{0.77815676, 0.87001216},
{0.9786183, 0.7991586},
{0.46147937, 0.7805292}}})
.get_vector()};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment