Commit 9ed372b2 authored by Liubov Batanina's avatar Liubov Batanina

Update get memory shapes

parent 46253371
...@@ -149,11 +149,12 @@ public: ...@@ -149,11 +149,12 @@ public:
out.push_back(outputs[0].size[i]); out.push_back(outputs[0].size[i]);
} }
kernel_size.resize(out.size()); kernel_size.resize(out.size());
int diff_size = isGlobalPooling.size() - kernel_size.size();
for (int i = 0; i < kernel_size.size(); i++) for (int i = 0; i < kernel_size.size(); i++)
{ {
if (isGlobalPooling[i + diff_size]) int pool_idx = isGlobalPooling.size() - 1 - i;
kernel_size[i] = inp[i]; int kernel_idx = kernel_size.size() - 1 - i;
if (isGlobalPooling[pool_idx])
kernel_size[kernel_idx] = inp[kernel_idx];
} }
kernel = Size(kernel_size[1], kernel_size[0]); kernel = Size(kernel_size[1], kernel_size[0]);
...@@ -1001,20 +1002,27 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp ...@@ -1001,20 +1002,27 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
std::vector<int> inpShape(inputs[0].begin() + 2, inputs[0].end()); std::vector<int> inpShape(inputs[0].begin() + 2, inputs[0].end());
std::vector<int> outShape(inputs[0].begin(), inputs[0].begin() + 2); std::vector<int> outShape(inputs[0].begin(), inputs[0].begin() + 2);
if (globalPooling) std::vector<size_t> local_kernel = kernel_size.empty() ?
std::vector<size_t>(inpShape.begin(), inpShape.end()) : kernel_size;
for (int i = 0; i < local_kernel.size(); i++)
{ {
outShape.push_back(1); int pool_idx = isGlobalPooling.size() - 1 - i;
outShape.push_back(1); int kernel_idx = local_kernel.size() - 1 - i;
if (isGlobalPooling[pool_idx])
local_kernel[kernel_idx] = inpShape[kernel_idx];
} }
else if (type == ROI || type == PSROI)
if (type == ROI || type == PSROI)
{ {
outShape.push_back(pooledSize.height); outShape.push_back(pooledSize.height);
outShape.push_back(pooledSize.width); outShape.push_back(pooledSize.width);
} }
else if (padMode.empty()) else if (padMode.empty())
{ {
for (int i = 0; i < kernel_size.size(); i++) { for (int i = 0; i < local_kernel.size(); i++) {
float dst = (float)(inpShape[i] + pads_begin[i] + pads_end[i] - kernel_size[i]) / strides[i]; float dst = (float)(inpShape[i] + pads_begin[i] + pads_end[i] - local_kernel[i]) / strides[i];
outShape.push_back(1 + (ceilMode ? ceil(dst) : floor(dst))); outShape.push_back(1 + (ceilMode ? ceil(dst) : floor(dst)));
} }
...@@ -1029,7 +1037,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp ...@@ -1029,7 +1037,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
} }
else else
{ {
getConvPoolOutParams(inpShape, kernel_size, strides, padMode, std::vector<size_t>(kernel_size.size(), 1), outShape); getConvPoolOutParams(inpShape, local_kernel, strides, padMode, std::vector<size_t>(local_kernel.size(), 1), outShape);
} }
if (type == ROI) if (type == ROI)
{ {
...@@ -1044,13 +1052,6 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp ...@@ -1044,13 +1052,6 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
outShape[1] = psRoiOutChannels; outShape[1] = psRoiOutChannels;
} }
int diff_size = isGlobalPooling.size() - (outShape.size() - 2);
for (int i = 2; i < outShape.size(); i++)
{
if (isGlobalPooling[i - 2 + diff_size])
outShape[i] = 1;
}
int numOutputs = requiredOutputs ? requiredOutputs : (type == MAX ? 2 : 1); int numOutputs = requiredOutputs ? requiredOutputs : (type == MAX ? 2 : 1);
CV_Assert(numOutputs == 1 || (numOutputs == 2 && type == MAX)); CV_Assert(numOutputs == 1 || (numOutputs == 2 && type == MAX));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment