Commit 35c24480 authored by Liubov Batanina's avatar Liubov Batanina

Fix axis

parent 832ca073
...@@ -1967,7 +1967,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1967,7 +1967,7 @@ void TFImporter::populateNet(Net dstNet)
LayerParams reshapeLp; LayerParams reshapeLp;
std::string reshapeName = name + "/reshape"; std::string reshapeName = name + "/reshape";
CV_Assert(layer_id.find(reshapeName) == layer_id.end()); CV_Assert(layer_id.find(reshapeName) == layer_id.end());
reshapeLp.set("axis", indices.at<int>(0)); reshapeLp.set("axis", 0);
reshapeLp.set("num_axes", 1); reshapeLp.set("num_axes", 1);
int newShape[] = {1, 1, -1}; int newShape[] = {1, 1, -1};
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 3)); reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 3));
...@@ -1990,7 +1990,7 @@ void TFImporter::populateNet(Net dstNet) ...@@ -1990,7 +1990,7 @@ void TFImporter::populateNet(Net dstNet)
LayerParams sliceLp; LayerParams sliceLp;
std::string layerShapeName = name + "/slice"; std::string layerShapeName = name + "/slice";
CV_Assert(layer_id.find(layerShapeName) == layer_id.end()); CV_Assert(layer_id.find(layerShapeName) == layer_id.end());
sliceLp.set("axis", indices.at<int>(0)); sliceLp.set("axis", 0);
int begin[] = {0}; int begin[] = {0};
int size[] = {1}; int size[] = {1};
sliceLp.set("begin", DictValue::arrayInt(&begin[0], 1)); sliceLp.set("begin", DictValue::arrayInt(&begin[0], 1));
...@@ -2004,8 +2004,8 @@ void TFImporter::populateNet(Net dstNet) ...@@ -2004,8 +2004,8 @@ void TFImporter::populateNet(Net dstNet)
LayerParams squeezeLp; LayerParams squeezeLp;
std::string squeezeName = name + "/squeeze"; std::string squeezeName = name + "/squeeze";
CV_Assert(layer_id.find(squeezeName) == layer_id.end()); CV_Assert(layer_id.find(squeezeName) == layer_id.end());
squeezeLp.set("axis", indices.at<int>(0)); squeezeLp.set("axis", 0);
squeezeLp.set("end_axis", indices.at<int>(0) + 1); squeezeLp.set("end_axis", 1);
int squeezeId = dstNet.addLayer(squeezeName, "Flatten", squeezeLp); int squeezeId = dstNet.addLayer(squeezeName, "Flatten", squeezeLp);
layer_id[squeezeName] = squeezeId; layer_id[squeezeName] = squeezeId;
connect(layer_id, dstNet, Pin(layerShapeName), squeezeId, 0); connect(layer_id, dstNet, Pin(layerShapeName), squeezeId, 0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment