Commit 57cf1201 authored by Alexander Alekhin's avatar Alexander Alekhin

Merge pull request #16709 from ashishkrshrivastava:cvonnx

parents 0b85d0ec e18d5e94
...@@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net, ...@@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
{ {
int numNodes = net->getNumNodes(); int numNodes = net->getNumNodes();
std::vector<int> matchedNodesIds, targetNodesIds; std::vector<int> matchedNodesIds, targetNodesIds;
for (int i = 0; i < numNodes; ++i) for (int j = 0; j < patterns.size(); ++j)
{ {
for (int j = 0; j < patterns.size(); ++j) for (int i = 0; i < numNodes; ++i)
{ {
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
{ {
patterns[j]->replace(net, matchedNodesIds, targetNodesIds); patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added. numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added.
break;
} }
} }
} }
......
...@@ -154,6 +154,32 @@ private: ...@@ -154,6 +154,32 @@ private:
int axis; int axis;
}; };
class GatherCastSubgraph : public Subgraph
{
public:
GatherCastSubgraph()
{
int input = addNodeToMatch("");
int index = addNodeToMatch("Constant");
int gather = addNodeToMatch("Gather", input, index);
addNodeToMatch("Cast", gather);
setFusedNode("Gather", input, index);
}
};
class MulCastSubgraph : public Subgraph
{
public:
MulCastSubgraph()
{
int input = addNodeToMatch("");
int scaleNode = addNodeToMatch("Constant");
int mul = addNodeToMatch("Mul", input, scaleNode);
addNodeToMatch("Cast", mul);
setFusedNode("Mul", input, scaleNode);
}
};
class ExtractScalesSubgraph : public Subgraph class ExtractScalesSubgraph : public Subgraph
{ {
public: public:
...@@ -164,20 +190,16 @@ public: ...@@ -164,20 +190,16 @@ public:
int indexH = addNodeToMatch("Constant"); int indexH = addNodeToMatch("Constant");
int shape1 = addNodeToMatch("Shape", input); int shape1 = addNodeToMatch("Shape", input);
int gather1 = addNodeToMatch("Gather", shape1, indexH); int gather1 = addNodeToMatch("Gather", shape1, indexH);
int castG1 = addNodeToMatch("Cast", gather1);
scaleHNode = addNodeToMatch("Constant"); scaleHNode = addNodeToMatch("Constant");
int mul1 = addNodeToMatch("Mul", castG1, scaleHNode); int mul1 = addNodeToMatch("Mul", gather1, scaleHNode);
int castM1 = addNodeToMatch("Cast", mul1); int floor1 = addNodeToMatch("Floor", mul1);
int floor1 = addNodeToMatch("Floor", castM1);
int indexW = addNodeToMatch("Constant"); int indexW = addNodeToMatch("Constant");
int shape2 = addNodeToMatch("Shape", input); int shape2 = addNodeToMatch("Shape", input);
int gather2 = addNodeToMatch("Gather", shape2, indexW); int gather2 = addNodeToMatch("Gather", shape2, indexW);
int castG2 = addNodeToMatch("Cast", gather2);
scaleWNode = addNodeToMatch("Constant"); scaleWNode = addNodeToMatch("Constant");
int mul2 = addNodeToMatch("Mul", castG2, scaleWNode); int mul2 = addNodeToMatch("Mul", gather2, scaleWNode);
int castM2 = addNodeToMatch("Cast", mul2); int floor2 = addNodeToMatch("Floor", mul2);
int floor2 = addNodeToMatch("Floor", castM2);
int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1); int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1);
int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2); int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2);
...@@ -190,19 +212,23 @@ public: ...@@ -190,19 +212,23 @@ public:
{ {
opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t(); opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t();
float scaleW = getMatFromTensor(tensor_proto).at<float>(0); Mat scaleW = getMatFromTensor(tensor_proto);
CV_Assert(scaleW.total() == 1);
scaleW.convertTo(scaleW, CV_32F);
constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node; constant_node = inputs[2].dynamicCast<ONNXNodeWrapper>()->node;
tensor_proto = constant_node->attribute(0).t(); tensor_proto = constant_node->attribute(0).t();
float scaleH = getMatFromTensor(tensor_proto).at<float>(0); Mat scaleH = getMatFromTensor(tensor_proto);
CV_Assert(scaleH.total() == 1);
scaleH.convertTo(scaleH, CV_32F);
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node; opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::AttributeProto* attrH = node->add_attribute(); opencv_onnx::AttributeProto* attrH = node->add_attribute();
attrH->set_name("height_scale"); attrH->set_name("height_scale");
attrH->set_i(scaleH); attrH->set_i(scaleH.at<float>(0));
opencv_onnx::AttributeProto* attrW = node->add_attribute(); opencv_onnx::AttributeProto* attrW = node->add_attribute();
attrW->set_name("width_scale"); attrW->set_name("width_scale");
attrW->set_i(scaleW); attrW->set_i(scaleW.at<float>(0));
node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs
} }
...@@ -267,6 +293,8 @@ public: ...@@ -267,6 +293,8 @@ public:
void simplifySubgraphs(opencv_onnx::GraphProto& net) void simplifySubgraphs(opencv_onnx::GraphProto& net)
{ {
std::vector<Ptr<Subgraph> > subgraphs; std::vector<Ptr<Subgraph> > subgraphs;
subgraphs.push_back(makePtr<GatherCastSubgraph>());
subgraphs.push_back(makePtr<MulCastSubgraph>());
subgraphs.push_back(makePtr<UpsampleSubgraph>()); subgraphs.push_back(makePtr<UpsampleSubgraph>());
subgraphs.push_back(makePtr<ResizeSubgraph1>()); subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>()); subgraphs.push_back(makePtr<ResizeSubgraph2>());
......
...@@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused) ...@@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
{ {
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("upsample_unfused_torch1.2");
testONNXModels("upsample_unfused_opset9_torch1.4"); testONNXModels("upsample_unfused_opset9_torch1.4");
testONNXModels("resize_nearest_unfused_opset11_torch1.4"); testONNXModels("resize_nearest_unfused_opset11_torch1.4");
testONNXModels("resize_nearest_unfused_opset11_torch1.3"); testONNXModels("resize_nearest_unfused_opset11_torch1.3");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment