Commit 9b635830 authored by Dmitry Kurtaev's avatar Dmitry Kurtaev

Run Reshape layer for const input from ONNX models

parent 4f764b81
...@@ -140,9 +140,10 @@ Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto) ...@@ -140,9 +140,10 @@ Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
return blob; return blob;
} }
void runLayer(Ptr<Layer> layer, const std::vector<Mat>& inputs, void runLayer(LayerParams& params, const std::vector<Mat>& inputs,
std::vector<Mat>& outputs) std::vector<Mat>& outputs)
{ {
Ptr<Layer> layer = LayerFactory::createLayerInstance(params.type, params);
std::vector<MatShape> inpShapes(inputs.size()); std::vector<MatShape> inpShapes(inputs.size());
int ddepth = CV_32F; int ddepth = CV_32F;
for (size_t i = 0; i < inputs.size(); ++i) for (size_t i = 0; i < inputs.size(); ++i)
...@@ -669,14 +670,15 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -669,14 +670,15 @@ void ONNXImporter::populateNet(Net dstNet)
Mat blob = getBlob(node_proto, constBlobs, 1); Mat blob = getBlob(node_proto, constBlobs, 1);
CV_Assert(blob.type() == CV_32SC1); CV_Assert(blob.type() == CV_32SC1);
layerParams.set("dim", DictValue::arrayInt<int*>(
blob.ptr<int>(), blob.total() ));
if (layer_id.find(node_proto.input(0)) == layer_id.end()) { if (layer_id.find(node_proto.input(0)) == layer_id.end()) {
Mat input = getBlob(node_proto, constBlobs, 0); std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), outputs;
Mat out = input.reshape(0, static_cast<std::vector<int> >(blob)); runLayer(layerParams, inputs, outputs);
constBlobs.insert(std::make_pair(layerParams.name, out)); constBlobs.insert(std::make_pair(layerParams.name, outputs[0]));
continue; continue;
} }
layerParams.set("dim", DictValue::arrayInt<int*>(
blob.ptr<int>(), blob.total() ));
} }
else { else {
DictValue shape = layerParams.get("shape"); DictValue shape = layerParams.get("shape");
...@@ -749,8 +751,7 @@ void ONNXImporter::populateNet(Net dstNet) ...@@ -749,8 +751,7 @@ void ONNXImporter::populateNet(Net dstNet)
{ {
inputs[i] = getBlob(node_proto, constBlobs, i); inputs[i] = getBlob(node_proto, constBlobs, i);
} }
Ptr<Layer> concat = ConcatLayer::create(layerParams); runLayer(layerParams, inputs, concatenated);
runLayer(concat, inputs, concatenated);
CV_Assert(concatenated.size() == 1); CV_Assert(concatenated.size() == 1);
constBlobs.insert(std::make_pair(layerParams.name, concatenated[0])); constBlobs.insert(std::make_pair(layerParams.name, concatenated[0]));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment