Commit 04ea0671 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Michał Karzyński

v1::Reshape downgrade pass + onnx_importer adjustments (#4046)

parent 228570eb
...@@ -37,32 +37,28 @@ namespace ngraph ...@@ -37,32 +37,28 @@ namespace ngraph
NodeVector reshape(const Node& node) NodeVector reshape(const Node& node)
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; NodeVector ng_inputs{node.get_ng_inputs()};
auto data = ng_inputs.at(0); const auto data = ng_inputs.at(0);
auto data_shape = data->get_shape();
auto output_shape = std::shared_ptr<ngraph::Node> pattern;
node.get_attribute_value<std::vector<std::size_t>>("shape", {});
// If no shape argument (opset >= 5) and there is second input. // Since opset 5 the target shape is provided as input
if (output_shape.empty() && ng_inputs.size() == 2) if (ng_inputs.size() == 2)
{ {
// Currently only support Constant node. NGRAPH_CHECK(ng_inputs.at(1)->is_constant(),
ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant") "The target shape input has to be a Constant.");
<< "doesn't support shape input of other type than Constant.";
output_shape = ngraph::as_type_ptr<ngraph::op::Constant>(ng_inputs.at(1)) pattern = ng_inputs.at(1);
->get_vector<std::size_t>();
} }
// Do nothing if there is no shape argument nor second node input. else
else if (output_shape.empty())
{ {
return {data}; const auto output_shape =
node.get_attribute_value<std::vector<int64_t>>("shape", {});
pattern = ngraph::op::Constant::create(
element::i64, Shape{output_shape.size()}, output_shape);
} }
output_shape = return {std::make_shared<ngraph::op::v1::Reshape>(data, pattern, true)};
reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
return {std::make_shared<ngraph::op::Reshape>(
data, ngraph::get_default_order(data_shape.size()), Shape{output_shape})};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -221,8 +221,19 @@ namespace ...@@ -221,8 +221,19 @@ namespace
bool op_cast(shared_ptr<op::v1::Reshape> node) bool op_cast(shared_ptr<op::v1::Reshape> node)
{ {
auto replacement_node = make_shared<op::v0::DynReshape>( shared_ptr<Node> replacement_node;
node->input_value(0), node->input_value(1), node->get_special_zero());
const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static())
{
replacement_node = builder::reshape(node->input_value(0), node->get_output_shape(0));
}
else
{
replacement_node = make_shared<op::v0::DynReshape>(
node->input_value(0), node->input_value(1), node->get_special_zero());
}
replace_node(node, replacement_node); replace_node(node, replacement_node);
return true; return true;
} }
......
ir_version: 3 ir_version: 6
producer_name: "nGraph ONNX Importer" producer_name: "nGraph ONNX Importer"
graph { graph {
node { node {
input: "A" input: "data"
output: "B" input: "shape"
output: "reshaped"
op_type: "Reshape" op_type: "Reshape"
attribute {
name: "shape"
ints: 6
ints: 2
ints: 2
type: INTS
}
} }
name: "compute_graph" name: "test_reshape_negative_dim"
initializer {
dims: 3
data_type: 7
int64_data: 2
int64_data: -1
int64_data: 2
name: "shape"
}
input { input {
name: "A" name: "data"
type { type {
tensor_type { tensor_type {
elem_type: 1 elem_type: 1
...@@ -33,17 +35,30 @@ graph { ...@@ -33,17 +35,30 @@ graph {
} }
} }
} }
input {
name: "shape"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
output { output {
name: "B" name: "reshaped"
type { type {
tensor_type { tensor_type {
elem_type: 1 elem_type: 1
shape { shape {
dim { dim {
dim_value: 6 dim_value: 2
} }
dim { dim {
dim_value: 2 dim_value: 6
} }
dim { dim {
dim_value: 2 dim_value: 2
......
...@@ -124,21 +124,34 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reshape_single_dim) ...@@ -124,21 +124,34 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reshape_single_dim)
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reshape_negative_dim) NGRAPH_TEST(onnx_${BACKEND_NAME}, model_reshape_negative_dim)
{ {
auto function = onnx_import::import_onnx_model( // the model contains the target shape in the initializers: [2, -1, 2]
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.prototxt")); file_util::path_join(SERIALIZED_ZOO, "onnx/reshape_negative_dim.prototxt"));
// input data shape (2, 3, 4) // 2x3x4
Inputs inputs{test::NDArray<float, 3>({{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}}, Inputs inputs{test::NDArray<float, 3>({{{0.5488135, 0.71518934, 0.60276335, 0.5448832},
{{12, 13, 14, 15}, {16, 17, 18, 19}, {20, 21, 22, 23}}}) {0.4236548, 0.6458941, 0.4375872, 0.891773},
{0.96366274, 0.3834415, 0.79172504, 0.5288949}},
{{0.56804454, 0.92559665, 0.07103606, 0.0871293},
{0.0202184, 0.83261985, 0.77815676, 0.87001216},
{0.9786183, 0.7991586, 0.46147937, 0.7805292}}})
.get_vector()}; .get_vector()};
// output data shape (6, 2, 2) // 2x6x2
Outputs expected_outputs{test::NDArray<float, 3>({{{0, 1}, {2, 3}}, Outputs expected_outputs{test::NDArray<float, 3>({{{0.5488135, 0.71518934},
{{4, 5}, {6, 7}}, {0.60276335, 0.5448832},
{{8, 9}, {10, 11}}, {0.4236548, 0.6458941},
{{12, 13}, {14, 15}}, {0.4375872, 0.891773},
{{16, 17}, {18, 19}}, {0.96366274, 0.3834415},
{{20, 21}, {22, 23}}}) {0.79172504, 0.5288949}},
{{0.56804454, 0.92559665},
{0.07103606, 0.0871293},
{0.0202184, 0.83261985},
{0.77815676, 0.87001216},
{0.9786183, 0.7991586},
{0.46147937, 0.7805292}}})
.get_vector()}; .get_vector()};
Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")}; Outputs outputs{execute(function, inputs, "${BACKEND_NAME}")};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment