Commit e4a80aee authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Fix #15296

parent 7c96857c
...@@ -1129,15 +1129,14 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1129,15 +1129,14 @@ void TFImporter::populateNet(Net dstNet)
if (value_id.find(layer.input(1)) != value_id.end()) if (value_id.find(layer.input(1)) != value_id.end())
{ {
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1)); Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
if (newShape.total() == 4)
{
// NHWC->NCHW
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
}
if (inpLayout == DATA_LAYOUT_NHWC) if (inpLayout == DATA_LAYOUT_NHWC)
{ {
if (newShape.total() == 4)
{
// NHWC->NCHW
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
}
if (newShape.total() != 4 || newShape.at<int>(1) == 1) if (newShape.total() != 4 || newShape.at<int>(1) == 1)
{ {
LayerParams permLP; LayerParams permLP;
......
...@@ -279,7 +279,7 @@ TEST_P(Test_TensorFlow_layers, matmul) ...@@ -279,7 +279,7 @@ TEST_P(Test_TensorFlow_layers, matmul)
// Reference output values are in range [-5.688, 4.484] // Reference output values are in range [-5.688, 4.484]
double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1; double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1;
runTensorFlowNet("nhwc_reshape_matmul", false, l1); runTensorFlowNet("nhwc_reshape_matmul", false, l1);
runTensorFlowNet("matmul_layout");
} }
TEST_P(Test_TensorFlow_layers, reshape) TEST_P(Test_TensorFlow_layers, reshape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment