Commit d971678a authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Add a planar data layout tracking for TensorFlow importer

parent e4b51fa8
...@@ -51,7 +51,8 @@ enum DataLayout ...@@ -51,7 +51,8 @@ enum DataLayout
{ {
DATA_LAYOUT_NHWC, DATA_LAYOUT_NHWC,
DATA_LAYOUT_NCHW, DATA_LAYOUT_NCHW,
DATA_LAYOUT_UNKNOWN DATA_LAYOUT_UNKNOWN,
DATA_LAYOUT_PLANAR // 2-dimensional outputs (matmul, flatten, reshape to 2d)
}; };
typedef std::vector<std::pair<String, int> > StrIntVector; typedef std::vector<std::pair<String, int> > StrIntVector;
...@@ -948,7 +949,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -948,7 +949,7 @@ void TFImporter::populateNet(Net dstNet)
// one input only // one input only
int input_blob_index = kernel_blob_index == 0 ? 1 : 0; int input_blob_index = kernel_blob_index == 0 ? 1 : 0;
connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0); connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0);
data_layouts[name] = DATA_LAYOUT_UNKNOWN; data_layouts[name] = DATA_LAYOUT_PLANAR;
} }
else if (type == "Reshape") else if (type == "Reshape")
{ {
...@@ -981,7 +982,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -981,7 +982,7 @@ void TFImporter::populateNet(Net dstNet)
// one input only // one input only
connect(layer_id, dstNet, inpId, id, 0); connect(layer_id, dstNet, inpId, id, 0);
data_layouts[name] = DATA_LAYOUT_UNKNOWN; data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : DATA_LAYOUT_UNKNOWN;
} }
else if (type == "Flatten" || type == "Squeeze") else if (type == "Flatten" || type == "Squeeze")
{ {
...@@ -1020,7 +1021,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1020,7 +1021,7 @@ void TFImporter::populateNet(Net dstNet)
int id = dstNet.addLayer(name, "Flatten", layerParams); int id = dstNet.addLayer(name, "Flatten", layerParams);
layer_id[name] = id; layer_id[name] = id;
connect(layer_id, dstNet, inpId, id, 0); connect(layer_id, dstNet, inpId, id, 0);
data_layouts[name] = DATA_LAYOUT_UNKNOWN; data_layouts[name] = DATA_LAYOUT_PLANAR;
} }
else if (type == "Transpose") else if (type == "Transpose")
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment