Commit ffcb070d authored by Anna Petrovicheva's avatar Anna Petrovicheva

Rewrote Concat layer

parent 22a14ea7
......@@ -47,14 +47,14 @@ namespace cv
{
namespace dnn
{
ConcatLayer::ConcatLayer(LayerParams &params) : Layer(params)
{
ConcatLayer::ConcatLayer(LayerParams &params) : Layer(params)
{
axis = params.get<int>("axis", 1);
CV_Assert(axis >= 0);
}
}
void ConcatLayer::allocate(const std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
void ConcatLayer::allocate(const std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
int refType = inputs[0]->type();
......@@ -80,21 +80,39 @@ namespace dnn
refShape[axis] = axisSum;
outputs.resize(1);
outputs[0].create(refShape);
}
}
void ConcatLayer::forward(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
void ConcatLayer::forward(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
// In case when Blob shape used in allocation and inner matrix shape do not match, this layer did not work in previous implementation. This implementation is just a fix and needs to be rewritten more optimally.
if (inputs.size() == 1)
{
// In case when Blob shape used in allocation and inner matrix shape do not match, this layer did not work in previous implementation. This implementation is just a fix and needs to be rewritten.
return;
}
float* outputData = outputs[0].ptrf();
size_t usedSize = 0;
for (size_t i = 0; i < inputs.size(); i++)
size_t numConcats = inputs[0]->total(0, axis);
size_t outputStride = outputs[0].total(axis);
size_t offset = 0;
for (int i = 0; i < inputs.size(); ++i)
{
Mat inMat(1, inputs[i]->total(), CV_32F, inputs[i]->ptrf());
Mat outMat(1, inputs[i]->total(), CV_32F, outputs[0].ptrf() + usedSize);
size_t inputSliceSize = inputs[i]->total(axis);
const float* inputData = inputs[i]->ptrf();
inMat.copyTo(outMat);
usedSize += inputs[i]->total();
for (size_t n = 0; n < numConcats; ++n)
{
const float* src = inputData + n * inputSliceSize;
float* dst = outputData + n * outputStride + offset;
// memcpy(dst, src, inputSliceSize);
for(size_t k = 0; k < inputSliceSize; k++)
{
dst[k] = src[k];
}
}
offset += inputSliceSize;
}
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment