Commit 3f5b4655 authored by Vadim Pisarevsky, committed by GitHub

refactored DNN (#1102)

* the first commit in the merged dnn: convert part of the public API from Blob to Mat

* temporarily or permanently removed OpenCL optimizations, which are neither always stable nor particularly efficient; we'll likely use Halide instead

* got rid of Blob and BlobShape completely; use cv::Mat and std::vector<int> instead

* fixed a few compile errors

* got rid of separate .hpp files with layer declarations; instead, put everything into the respective .cpp files

* normalized all the layers' constructors; since we concentrate on loading deep network layers from files rather than constructing them from scratch, we retained only the SomeLayer::SomeLayer(const LayerParams& params) constructors

* fixed sample compilation

* suppress doxygen warnings

* trying to fix python bindings generation for DNN module

* temporarily disable python bindings while we refactor the module

* fix win32/win64 compile errors; remove trailing whitespaces

* fix win32/win64 compile errors; remove trailing whitespaces
parent 4317e27d
......@@ -9,7 +9,7 @@ endif()
set(the_description "Deep neural network module. It allows loading models from different frameworks and running a forward pass")
ocv_add_module(dnn opencv_core opencv_imgproc WRAP python matlab)
ocv_add_module(dnn opencv_core opencv_imgproc)
ocv_warnings_disable(CMAKE_CXX_FLAGS -Wno-shadow -Wno-parentheses -Wmaybe-uninitialized -Wsign-promo
-Wmissing-declarations -Wmissing-prototypes
)
......
This diff is collapsed.
This diff is collapsed.
......@@ -118,6 +118,9 @@ public:
//! If the @p key is in the dictionary, returns a pointer to its value; otherwise returns NULL.
DictValue *ptr(const String &key);
/** @overload */
const DictValue *ptr(const String &key) const;
//! If the @p key is in the dictionary, returns its value; otherwise an error is generated.
const DictValue &get(const String &key) const;
......
......@@ -45,7 +45,6 @@
#include <vector>
#include <opencv2/core.hpp>
#include <opencv2/dnn/dict.hpp>
#include <opencv2/dnn/blob.hpp>
namespace cv
{
......@@ -70,7 +69,7 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
{
public:
//TODO: Add ability to name blob params
std::vector<Blob> blobs; //!< List of learned parameters stored as blobs.
std::vector<Mat> blobs; //!< List of learned parameters stored as blobs.
String name; //!< Name of the layer instance (optional, can be used for internal purposes).
String type; //!< Type name which was used for creating layer by layer factory (optional).
......@@ -86,7 +85,7 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
public:
//! List of learned parameters must be stored here to allow read them by using Net::getParam().
CV_PROP_RW std::vector<Blob> blobs;
CV_PROP_RW std::vector<Mat> blobs;
/** @brief Allocates internal buffers and output blobs with respect to the shape of inputs.
* @param[in] input vector of already allocated input blobs
......@@ -96,25 +95,25 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
* If this method is called for the first time, the @p output vector consists of empty blobs and its size is determined by the number of output connections.
* This method can be called multiple times if size of any @p input blob was changed.
*/
virtual void allocate(const std::vector<Blob*> &input, std::vector<Blob> &output) = 0;
virtual void allocate(const std::vector<Mat*> &input, std::vector<Mat> &output) = 0;
/** @brief Given the @p input blobs, computes the output @p blobs.
* @param[in] input the input blobs.
* @param[out] output allocated output blobs, which will store results of the computation.
*/
virtual void forward(std::vector<Blob*> &input, std::vector<Blob> &output) = 0;
virtual void forward(std::vector<Mat*> &input, std::vector<Mat> &output) = 0;
/** @brief @overload */
CV_WRAP void allocate(const std::vector<Blob> &inputs, CV_OUT std::vector<Blob> &outputs);
CV_WRAP void allocate(const std::vector<Mat> &inputs, CV_OUT std::vector<Mat> &outputs);
/** @brief @overload */
CV_WRAP std::vector<Blob> allocate(const std::vector<Blob> &inputs);
CV_WRAP std::vector<Mat> allocate(const std::vector<Mat> &inputs);
/** @brief @overload */
CV_WRAP void forward(const std::vector<Blob> &inputs, CV_IN_OUT std::vector<Blob> &outputs);
CV_WRAP void forward(const std::vector<Mat> &inputs, CV_IN_OUT std::vector<Mat> &outputs);
/** @brief Allocates layer and computes output. */
CV_WRAP void run(const std::vector<Blob> &inputs, CV_OUT std::vector<Blob> &outputs);
CV_WRAP void run(const std::vector<Mat> &inputs, CV_OUT std::vector<Mat> &outputs);
/** @brief Returns index of input blob into the input array.
* @param inputName label of input blob
......@@ -248,13 +247,13 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
* @note If the blob being updated is not empty, then @p blob must have the same shape,
* because network reshaping is not implemented yet.
*/
CV_WRAP void setBlob(String outputName, const Blob &blob);
CV_WRAP void setBlob(String outputName, const Mat &blob);
/** @brief Returns the layer output blob.
* @param outputName the descriptor of the returning layer output blob.
* @see connect(String, String)
*/
CV_WRAP Blob getBlob(String outputName);
CV_WRAP Mat getBlob(String outputName);
/** @brief Sets the new value for the learned param of the layer.
* @param layer name or id of the layer.
......@@ -264,14 +263,14 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
* @note If shape of the new blob differs from the previous shape,
* then the following forward pass may fail.
*/
CV_WRAP void setParam(LayerId layer, int numParam, const Blob &blob);
CV_WRAP void setParam(LayerId layer, int numParam, const Mat &blob);
/** @brief Returns parameter blob of the layer.
* @param layer name or id of the layer.
* @param numParam index of the layer parameter in the Layer::blobs array.
* @see Layer::blobs
*/
CV_WRAP Blob getParam(LayerId layer, int numParam = 0);
CV_WRAP Mat getParam(LayerId layer, int numParam = 0);
/** @brief Returns indexes of layers with unconnected outputs.
*/
......@@ -341,7 +340,10 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
/** @brief Loads blob which was serialized as torch.Tensor object of Torch7 framework.
* @warning This function has the same limitations as createTorchImporter().
*/
CV_EXPORTS_W Blob readTorchBlob(const String &filename, bool isBinary = true);
CV_EXPORTS_W Mat readTorchBlob(const String &filename, bool isBinary = true);
CV_EXPORTS Mat blobFromImage(const Mat& image, double scalefactor=1.0, bool swapRB=true);
CV_EXPORTS Mat blobFromImages(const std::vector<Mat>& image, double scalefactor=1.0, bool swapRB=true);
//! @}
}
......
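For reference, a minimal sketch of the refactored Mat-based workflow declared above, pieced together from the samples updated later in this commit; the Caffe importer calls and the "prob" blob name follow the GoogLeNet sample, and the model/image file names are placeholders:

    #include <opencv2/dnn.hpp>
    #include <opencv2/imgproc.hpp>
    #include <opencv2/highgui.hpp>
    using namespace cv;
    using namespace cv::dnn;

    int main()
    {
        Net net;
        Ptr<Importer> importer = createCaffeImporter("bvlc_googlenet.prototxt", "bvlc_googlenet.caffemodel");
        importer->populateNet(net);

        Mat img = imread("space_shuttle.jpg");
        resize(img, img, Size(224, 224));
        Mat inputBlob = blobFromImage(img);   // 1 x 3 x 224 x 224 CV_32F blob stored in a plain cv::Mat
        net.setBlob("", inputBlob);           // set the network input
        net.forward();                        // run the forward pass
        Mat prob = net.getBlob("prob");       // the output is now a cv::Mat as well
        return 0;
    }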
......@@ -298,6 +298,12 @@ inline DictValue *Dict::ptr(const String &key)
return (i == dict.end()) ? NULL : &i->second;
}
inline const DictValue *Dict::ptr(const String &key) const
{
_Dict::const_iterator i = dict.find(key);
return (i == dict.end()) ? NULL : &i->second;
}
inline const DictValue &Dict::get(const String &key) const
{
_Dict::const_iterator i = dict.find(key);
......
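A small, hypothetical usage sketch of the const accessors added above; the key names "axis" and "type" and the inspect() helper are invented for illustration, and LayerParams derives from Dict, so the same calls work on layer parameters:

    void inspect(const LayerParams &params)             // hypothetical helper
    {
        const DictValue *axisVal = params.ptr("axis");   // NULL if the key is absent
        int axis = axisVal ? axisVal->get<int>() : 1;

        const DictValue &type = params.get("type");      // generates an error if the key is missing
        (void)axis; (void)type;
    }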
......@@ -122,7 +122,7 @@ static _LayerStaticRegisterer __LayerStaticRegisterer_##type(#type, __LayerStati
template<typename LayerClass>
Ptr<Layer> _layerDynamicRegisterer(LayerParams &params)
{
return Ptr<Layer>(new LayerClass(params));
return Ptr<Layer>(LayerClass::create(params));
}
//allows automatically register created layer on module load time
......
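With this change, a class registered through REG_RUNTIME_LAYER_CLASS must expose a static create(LayerParams&) factory instead of a public constructor being called directly. A hypothetical pass-through layer written against the new Mat-based Layer interface might look like this (sketch only; the class and type name are invented, and it mirrors the BlankLayer implementation later in this commit):

    class MyPassThroughLayer : public Layer
    {
    public:
        MyPassThroughLayer(const LayerParams &params) { setParamsFrom(params); }

        static Ptr<Layer> create(LayerParams &params)
        {
            return Ptr<Layer>(new MyPassThroughLayer(params));
        }

        void allocate(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
        {
            outputs.resize(inputs.size());
            for (size_t i = 0; i < inputs.size(); i++)
                outputs[i] = *inputs[i];      // share the input data
        }

        void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
        {
            for (size_t i = 0; i < inputs.size(); i++)
                outputs[i] = *inputs[i];
        }
    };
    // registration then reduces to: REG_RUNTIME_LAYER_CLASS(MyPassThrough, MyPassThroughLayer);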
......@@ -43,14 +43,13 @@
#define __OPENCV_DNN_DNN_SHAPE_UTILS_HPP__
#include <opencv2/core.hpp>
#include <opencv2/core/types_c.h>
#include <ostream>
namespace cv {
namespace dnn {
//Useful shortcut
typedef BlobShape Shape;
inline std::ostream &operator<< (std::ostream &s, cv::Range &r)
{
return s << "[" << r.start << ", " << r.end << ")";
......@@ -59,7 +58,7 @@ inline std::ostream &operator<< (std::ostream &s, cv::Range &r)
//Reshaping
//TODO: add -1 specifier for automatic size inferring
template<typename Mat>
/*template<typename Mat>
void reshape(Mat &m, const BlobShape &shape)
{
m = m.reshape(1, shape.dims(), shape.ptr());
......@@ -69,7 +68,7 @@ template<typename Mat>
Mat reshaped(const Mat &m, const BlobShape &shape)
{
return m.reshape(1, shape.dims(), shape.ptr());
}
}*/
//Slicing
......@@ -80,22 +79,19 @@ struct _Range : public cv::Range
_Range(int start, int size = 1) : cv::Range(start, start + size) {}
};
template<typename Mat>
Mat slice(const Mat &m, const _Range &r0)
static inline Mat slice(const Mat &m, const _Range &r0)
{
//CV_Assert(m.dims >= 1);
cv::AutoBuffer<cv::Range, 4> ranges(m.dims);
Range ranges[CV_MAX_DIM];
for (int i = 1; i < m.dims; i++)
ranges[i] = Range::all();
ranges[0] = r0;
return m(&ranges[0]);
}
template<typename Mat>
Mat slice(const Mat &m, const _Range &r0, const _Range &r1)
static inline Mat slice(const Mat &m, const _Range &r0, const _Range &r1)
{
CV_Assert(m.dims >= 2);
cv::AutoBuffer<cv::Range, 4> ranges(m.dims);
Range ranges[CV_MAX_DIM];
for (int i = 2; i < m.dims; i++)
ranges[i] = Range::all();
ranges[0] = r0;
......@@ -103,11 +99,10 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1)
return m(&ranges[0]);
}
template<typename Mat>
Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2)
static inline Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2)
{
CV_Assert(m.dims <= 3);
cv::AutoBuffer<cv::Range, 4> ranges(m.dims);
CV_Assert(m.dims >= 3);
Range ranges[CV_MAX_DIM];
for (int i = 3; i < m.dims; i++)
ranges[i] = Range::all();
ranges[0] = r0;
......@@ -116,11 +111,10 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2)
return m(&ranges[0]);
}
template<typename Mat>
Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2, const _Range &r3)
static inline Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2, const _Range &r3)
{
CV_Assert(m.dims <= 4);
cv::AutoBuffer<cv::Range, 4> ranges(m.dims);
CV_Assert(m.dims >= 4);
Range ranges[CV_MAX_DIM];
for (int i = 4; i < m.dims; i++)
ranges[i] = Range::all();
ranges[0] = r0;
......@@ -130,7 +124,28 @@ Mat slice(const Mat &m, const _Range &r0, const _Range &r1, const _Range &r2, co
return m(&ranges[0]);
}
BlobShape computeShapeByReshapeMask(const BlobShape &srcShape, const BlobShape &maskShape, Range srcRange = Range::all());
static inline Mat getPlane(const Mat &m, int n, int cn)
{
CV_Assert(m.dims > 2);
Range range[CV_MAX_DIM];
int sz[CV_MAX_DIM];
for(int i = 2; i < m.dims; i++)
{
sz[i-2] = m.size.p[i];
range[i] = Range::all();
}
range[0] = Range(n, n+1);
range[1] = Range(cn, cn+1);
return m(range).reshape(1, m.dims-2, sz);
}
static inline size_t shapeTotal(const std::vector<int>& shape)
{
size_t i, n = shape.size(), p = 1;
for( i = 0; i < n; i++ ) p *= shape[i];
return p;
}
}
}
......
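A short usage sketch of the Mat-based helpers defined above, assuming a 4-D CV_32F blob in N x C x H x W layout and the usual cv / cv::dnn using-directives:

    int sz[] = { 2, 3, 4, 5 };                      // N x C x H x W
    Mat blob(4, sz, CV_32F, Scalar(0));

    Mat first = slice(blob, _Range(0));             // still 4-D: 1 x 3 x 4 x 5
    Mat plane = getPlane(blob, 1, 2);               // 2-D 4 x 5 plane: image 1, channel 2

    std::vector<int> shape(blob.size.p, blob.size.p + blob.dims);
    size_t total = shapeTotal(shape);               // 2*3*4*5, same as blob.total()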
#ifdef HAVE_OPENCV_DNN
typedef dnn::DictValue LayerId;
typedef std::vector<cv::dnn::Blob> vector_Blob;
template<>
bool pyopencv_to(PyObject *o, dnn::Blob &blob, const char *name);
template<> struct pyopencvVecConverter<dnn::Blob>
{
static bool to(PyObject* obj, std::vector<dnn::Blob>& value, const ArgInfo info)
{
if (PyArray_Check(obj))
{
value.resize(1);
return pyopencv_to(obj, value[0], info.name);
}
return pyopencv_to_generic_vec(obj, value, info);
}
static PyObject* from(const std::vector<dnn::Blob>& value)
{
return pyopencv_from_generic_vec(value);
}
};
template<>
bool pyopencv_to(PyObject *o, std::vector<dnn::Blob> &blobs, const char *name) //required for Layer::blobs RW
{
return pyopencvVecConverter<dnn::Blob>::to(o, blobs, ArgInfo(name, false));
}
template<>
bool pyopencv_to(PyObject *o, dnn::Blob &blob, const char *name)
{
Mat &dst = blob.matRef();
if (!pyopencv_to(o, dst, name))
return false;
if (PyArray_Check(o)) //try fix channels
{
PyArrayObject* oarr = (PyArrayObject*) o;
if (PyArray_NDIM(oarr) == dst.dims)
return true;
int ndims = PyArray_NDIM(oarr);
std::vector<int> shape(ndims);
const npy_intp* _sizes = PyArray_DIMS(oarr);
for (int i = 0; i < ndims; i++)
shape[i] = (int)_sizes[i];
dst = dst.reshape(1, ndims, &shape[0]);
}
return true;
}
template<>
PyObject *pyopencv_from(const dnn::Blob &blob)
{
return pyopencv_from(blob.matRefConst());
}
template<>
bool pyopencv_to(PyObject *o, dnn::DictValue &dv, const char *name)
......@@ -87,22 +26,4 @@ bool pyopencv_to(PyObject *o, dnn::DictValue &dv, const char *name)
return false;
}
template<>
bool pyopencv_to(PyObject *o, dnn::BlobShape &shape, const char *name)
{
std::vector<int> data;
if (!pyopencv_to_generic_vec(o, data, ArgInfo(name, false)))
return false;
shape = data.size() ? dnn::BlobShape((int)data.size(), &data[0]) : dnn::BlobShape::empty();
return true;
}
template<>
PyObject *pyopencv_from(const dnn::BlobShape &shape)
{
std::vector<int> data(shape.ptr(), shape.ptr() + shape.dims());
return pyopencv_from_generic_vec(data);
}
#endif
\ No newline at end of file
#endif
......@@ -21,15 +21,21 @@ CV_ENUM(GroupSize, GROUP_OFF, GROUP_2);
//Squared Size
#define SSZ(n) cv::Size(n, n)
typedef std::pair<BlobShape, int> InpShapeNumOut;
typedef std::pair<std::vector<int>, int> InpShapeNumOut;
typedef tuple<Size, InpShapeNumOut, GroupSize, StrideSize> ConvParam; //kernel_size, inp shape, groups, stride
typedef TestBaseWithParam<ConvParam> ConvolutionPerfTest;
static inline std::vector<int> blobShape(int count, int nplanes, int height, int width)
{
int data[] = {count, nplanes, height, width};
return std::vector<int>(data, data+4);
}
PERF_TEST_P( ConvolutionPerfTest, perf, Combine(
Values(Size(1, 1), Size(3, 3), Size(5, 5), Size(11, 11)),
Values(make_pair(BlobShape(1, 4, 224, 224), 64),
make_pair(BlobShape(1, 64, 112, 122), 128),
make_pair(BlobShape(1, 256, 28, 28), 512)),
Values(make_pair(blobShape(1, 4, 224, 224), 64),
make_pair(blobShape(1, 64, 112, 122), 128),
make_pair(blobShape(1, 256, 28, 28), 512)),
GroupSize::all(),
StrideSize::all())
)
......@@ -38,17 +44,20 @@ PERF_TEST_P( ConvolutionPerfTest, perf, Combine(
ConvParam params = GetParam();
int ksz = get<0>(params).width;
BlobShape inpShape = get<1>(params).first;
std::vector<int> inpShape = get<1>(params).first;
int outCn = get<1>(params).second;
int groups = get<2>(params);
int stride = (ksz >= 11) ? 4 : (int)get<3>(params);
int inpCn = inpShape[1];
Blob wgtBlob(BlobShape(outCn, inpCn/groups, ksz, ksz)), biasBlob(BlobShape(outCn, 1, 1, 1));
Blob inpBlob(inpShape);
rng.fill(biasBlob.matRef(), RNG::UNIFORM, -1, +1);
rng.fill(wgtBlob.matRef(), RNG::UNIFORM, -1, +1);
rng.fill(inpBlob.matRef(), RNG::UNIFORM, -1, +1);
int wgtSize[] = { outCn, inpCn/groups, ksz, ksz };
int biasSize[] = { outCn, 1, 1, 1 };
const int wtype = CV_32F;
Mat wgtBlob(4, wgtSize, wtype), biasBlob(4, biasSize, wtype);
Mat inpBlob(4, &inpShape[0], wtype);
rng.fill(biasBlob, RNG::UNIFORM, -1, +1);
rng.fill(wgtBlob, RNG::UNIFORM, -1, +1);
rng.fill(inpBlob, RNG::UNIFORM, -1, +1);
LayerParams lp;
lp.set("num_output", outCn);
......@@ -59,15 +68,18 @@ PERF_TEST_P( ConvolutionPerfTest, perf, Combine(
lp.blobs.push_back(wgtBlob);
lp.blobs.push_back(biasBlob);
std::vector<Blob*> inpBlobs(1, &inpBlob);
std::vector<Blob> outBlobs;
std::vector<Mat*> inpBlobs(1, &inpBlob);
std::vector<Mat> outBlobs;
cv::setNumThreads(cv::getNumberOfCPUs());
Ptr<Layer> layer = cv::dnn::LayerFactory::createLayerInstance("Convolution", lp);
layer->allocate(inpBlobs, outBlobs);
declare.in(inpBlob.matRef(), wgtBlob.matRef(), WARMUP_RNG).out(outBlobs[0].matRef()).tbb_threads(cv::getNumThreads());
Mat inpBlob2D = inpBlob.reshape(1, outCn);
Mat wgtBlob2D = wgtBlob.reshape(1, outCn*(inpCn/groups));
Mat outBlob2D = outBlobs[0].reshape(1, outBlobs[0].size[0]);
declare.in(inpBlob2D, wgtBlob2D, WARMUP_RNG).out(outBlob2D).tbb_threads(cv::getNumThreads());
TEST_CYCLE_N(10)
{
......@@ -77,4 +89,4 @@ PERF_TEST_P( ConvolutionPerfTest, perf, Combine(
SANITY_CHECK_NOTHING();
}
}
\ No newline at end of file
}
......@@ -50,9 +50,9 @@ using namespace cv::dnn;
using namespace std;
/* Find best class for the blob (i. e. class with maximal probability) */
void getMaxClass(dnn::Blob &probBlob, int *classId, double *classProb)
void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
{
Mat probMat = probBlob.matRefConst().reshape(1, 1); //reshape the blob to 1x1000 matrix
Mat probMat = probBlob.reshape(1, 1); //reshape the blob to 1x1000 matrix
Point classNumber;
minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
......@@ -115,8 +115,7 @@ int main(int argc, char **argv)
}
resize(img, img, Size(224, 224)); //GoogLeNet accepts only 224x224 RGB-images
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
dnn::Blob inputBlob = dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob batch of images
Mat inputBlob = blobFromImage(img); //Convert Mat to batch of images
//! [Prepare blob]
//! [Set input blob]
......@@ -128,7 +127,7 @@ int main(int argc, char **argv)
//! [Make forward pass]
//! [Gather output]
dnn::Blob prob = net.getBlob("prob"); //gather output of "prob" layer
Mat prob = net.getBlob("prob"); //gather output of "prob" layer
int classId;
double classProb;
......
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/core/ocl.hpp>
using namespace cv;
using namespace cv::dnn;
......@@ -45,11 +44,11 @@ static vector<cv::Vec3b> readColors(const string &filename = "pascal-classes.txt
return colors;
}
static void colorizeSegmentation(dnn::Blob &score, const vector<cv::Vec3b> &colors, cv::Mat &segm)
static void colorizeSegmentation(const Mat &score, const vector<cv::Vec3b> &colors, cv::Mat &segm)
{
const int rows = score.rows();
const int cols = score.cols();
const int chns = score.channels();
const int rows = score.size[2];
const int cols = score.size[3];
const int chns = score.size[1];
cv::Mat maxCl(rows, cols, CV_8UC1);
cv::Mat maxVal(rows, cols, CV_32FC1);
......@@ -57,7 +56,7 @@ static void colorizeSegmentation(dnn::Blob &score, const vector<cv::Vec3b> &colo
{
for (int row = 0; row < rows; row++)
{
const float *ptrScore = score.ptrf(0, ch, row);
const float *ptrScore = score.ptr<float>(0, ch, row);
uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
float *ptrMaxVal = maxVal.ptr<float>(row);
for (int col = 0; col < cols; col++)
......@@ -87,7 +86,6 @@ static void colorizeSegmentation(dnn::Blob &score, const vector<cv::Vec3b> &colo
int main(int argc, char **argv)
{
cv::dnn::initModule(); //Required if OpenCV is built as static libs
cv::ocl::setUseOpenCL(false); //OpenCL switcher
String modelTxt = fcnType + "-heavy-pascal.prototxt";
String modelBin = fcnType + "-heavy-pascal.caffemodel";
......@@ -132,7 +130,7 @@ int main(int argc, char **argv)
}
resize(img, img, Size(500, 500)); //FCN accepts 500x500 RGB-images
dnn::Blob inputBlob = dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob batch of images
Mat inputBlob = blobFromImage(img); //Convert Mat to batch of images
//! [Prepare blob]
//! [Set input blob]
......@@ -147,13 +145,13 @@ int main(int argc, char **argv)
//! [Make forward pass]
//! [Gather output]
dnn::Blob score = net.getBlob("score");
Mat score = net.getBlob("score");
cv::Mat colorize;
Mat colorize;
colorizeSegmentation(score, colors, colorize);
cv::Mat show;
cv::addWeighted(img, 0.4, colorize, 0.6, 0.0, show);
cv::imshow("show", show);
cv::waitKey(0);
Mat show;
addWeighted(img, 0.4, colorize, 0.6, 0.0, show);
imshow("show", show);
waitKey(0);
return 0;
} //main
......@@ -101,7 +101,7 @@ int main(int argc, char** argv)
//! [Prepare blob]
Mat preprocessedFrame = preprocess(frame);
dnn::Blob inputBlob = dnn::Blob::fromImages(preprocessedFrame); //Convert Mat to dnn::Blob image
Mat inputBlob = blobFromImage(preprocessedFrame); //Convert Mat to batch of images
//! [Prepare blob]
//! [Set input blob]
......@@ -113,8 +113,8 @@ int main(int argc, char** argv)
//! [Make forward pass]
//! [Gather output]
dnn::Blob detection = net.getBlob("detection_out");
Mat detectionMat(detection.rows(), detection.cols(), CV_32F, detection.ptrf());
Mat detection = net.getBlob("detection_out");
Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
float confidenceThreshold = parser.get<float>("min_confidence");
for(int i = 0; i < detectionMat.rows; i++)
......
......@@ -32,7 +32,7 @@ const String keys =
"{result r || path to save output blob (optional, binary format, NCHW order) }"
;
void getMaxClass(dnn::Blob &probBlob, int *classId, double *classProb);
void getMaxClass(const Mat &probBlob, int *classId, double *classProb);
std::vector<String> readClassNames(const char *filename);
int main(int argc, char **argv)
......@@ -97,9 +97,7 @@ int main(int argc, char **argv)
if (inputImgSize != img.size())
resize(img, img, inputImgSize); //Resize image to input size
cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
dnn::Blob inputBlob = dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob image batch
Mat inputBlob = blobFromImage(img); //Convert Mat to image batch
//! [Prepare blob]
//! [Set input blob]
......@@ -116,11 +114,7 @@ int main(int argc, char **argv)
tm.stop();
//! [Gather output]
dnn::Blob prob = net.getBlob(outBlobName); //gather output of "prob" layer
Mat& result = prob.matRef();
BlobShape shape = prob.shape();
Mat result = net.getBlob(outBlobName); //gather output of "prob" layer
if (!resultFile.empty()) {
CV_Assert(result.isContinuous());
......@@ -130,7 +124,7 @@ int main(int argc, char **argv)
fout.close();
}
std::cout << "Output blob shape " << shape << std::endl;
std::cout << "Output blob shape " << result.size[0] << " x " << result.size[1] << " x " << result.size[2] << " x " << result.size[3] << std::endl;
std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl;
if (!classNamesFile.empty()) {
......@@ -138,7 +132,7 @@ int main(int argc, char **argv)
int classId;
double classProb;
getMaxClass(prob, &classId, &classProb);//find the best class
getMaxClass(result, &classId, &classProb);//find the best class
//! [Print results]
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
......@@ -149,9 +143,9 @@ int main(int argc, char **argv)
/* Find best class for the blob (i. e. class with maximal probability) */
void getMaxClass(dnn::Blob &probBlob, int *classId, double *classProb)
void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
{
Mat probMat = probBlob.matRefConst().reshape(1, 1); //reshape the blob to 1x1000 matrix
Mat probMat = probBlob.reshape(1, 1); //reshape the blob to 1x1000 matrix
Point classNumber;
minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
......
......@@ -27,12 +27,12 @@ const String keys =
;
std::vector<String> readClassNames(const char *filename);
static void colorizeSegmentation(Blob &score, Mat &segm,
static void colorizeSegmentation(const Mat &score, Mat &segm,
Mat &legend, vector<String> &classNames);
int main(int argc, char **argv)
{
cv::CommandLineParser parser(argc, argv, keys);
CommandLineParser parser(argc, argv, keys);
if (parser.has("help"))
{
......@@ -78,31 +78,27 @@ int main(int argc, char **argv)
//! [Initialize network]
//! [Prepare blob]
Mat img = imread(imageFile), input;
Mat img = imread(imageFile, 1);
if (img.empty())
{
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
exit(-1);
}
cv::Size inputImgSize = cv::Size(512, 512);
Size inputImgSize(512, 512);
if (inputImgSize != img.size())
resize(img, img, inputImgSize); //Resize image to input size
if(img.channels() == 3)
cv::cvtColor(img, input, cv::COLOR_BGR2RGB);
input.convertTo(input, CV_32F, 1/255.0);
dnn::Blob inputBlob = dnn::Blob::fromImages(input); //Convert Mat to dnn::Blob image batch
Mat inputBlob = blobFromImage(img, 1./255, true); //Convert Mat to image batch
//! [Prepare blob]
//! [Set input blob]
net.setBlob("", inputBlob); //set the network input
//! [Set input blob]
cv::TickMeter tm;
TickMeter tm;
tm.start();
//! [Make forward pass]
......@@ -119,11 +115,7 @@ int main(int argc, char **argv)
oBlob = parser.get<String>("o_blob");
}
dnn::Blob prob = net.getBlob(oBlob); //gather output of "prob" layer
Mat& result = prob.matRef();
BlobShape shape = prob.shape();
Mat result = net.getBlob(oBlob); //gather output of "prob" layer
if (!resultFile.empty()) {
CV_Assert(result.isContinuous());
......@@ -133,20 +125,21 @@ int main(int argc, char **argv)
fout.close();
}
std::cout << "Output blob shape " << shape << std::endl;
std::cout << "Output blob: " << result.size[0] << " x " << result.size[1] << " x " << result.size[2] << " x " << result.size[3] << "\n";
std::cout << "Inference time, ms: " << tm.getTimeMilli() << std::endl;
if (parser.has("show"))
{
size_t nclasses = result.size[1];
std::vector<String> classNames;
if(!classNamesFile.empty()) {
classNames = readClassNames(classNamesFile.c_str());
if (classNames.size() > prob.channels())
classNames = std::vector<String>(classNames.begin() + classNames.size() - prob.channels(),
if (classNames.size() > nclasses)
classNames = std::vector<String>(classNames.begin() + classNames.size() - nclasses,
classNames.end());
}
Mat segm, legend;
colorizeSegmentation(prob, segm, legend, classNames);
colorizeSegmentation(result, segm, legend, classNames);
Mat show;
addWeighted(img, 0.2, segm, 0.8, 0.0, show);
......@@ -184,11 +177,11 @@ std::vector<String> readClassNames(const char *filename)
return classNames;
}
static void colorizeSegmentation(Blob &score, Mat &segm, Mat &legend, vector<String> &classNames)
static void colorizeSegmentation(const Mat &score, Mat &segm, Mat &legend, vector<String> &classNames)
{
const int rows = score.rows();
const int cols = score.cols();
const int chns = score.channels();
const int rows = score.size[2];
const int cols = score.size[3];
const int chns = score.size[1];
vector<Vec3i> colors;
RNG rng(12345678);
......@@ -200,7 +193,7 @@ static void colorizeSegmentation(Blob &score, Mat &segm, Mat &legend, vector<Str
colors.push_back(Vec3i(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256)));
for (int row = 0; row < rows; row++)
{
const float *ptrScore = score.ptrf(0, ch, row);
const float *ptrScore = score.ptr<float>(0, ch, row);
uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
float *ptrMaxVal = maxVal.ptr<float>(row);
for (int col = 0; col < cols; col++)
......
This diff is collapsed.
......@@ -192,38 +192,37 @@ public:
}
}
BlobShape blobShapeFromProto(const caffe::BlobProto &pbBlob)
void blobShapeFromProto(const caffe::BlobProto &pbBlob, std::vector<int>& shape)
{
shape.clear();
if (pbBlob.has_num() || pbBlob.has_channels() || pbBlob.has_height() || pbBlob.has_width())
{
return BlobShape(pbBlob.num(), pbBlob.channels(), pbBlob.height(), pbBlob.width());
shape.push_back(pbBlob.num());
shape.push_back(pbBlob.channels());
shape.push_back(pbBlob.height());
shape.push_back(pbBlob.width());
}
else if (pbBlob.has_shape())
{
const caffe::BlobShape &_shape = pbBlob.shape();
BlobShape shape = BlobShape::all(_shape.dim_size());
for (int i = 0; i < _shape.dim_size(); i++)
shape[i] = (int)_shape.dim(i);
return shape;
shape.push_back((int)_shape.dim(i));
}
else
{
CV_Error(Error::StsError, "Unknown shape of input blob");
return BlobShape();
}
}
void blobFromProto(const caffe::BlobProto &pbBlob, cv::dnn::Blob &dstBlob)
void blobFromProto(const caffe::BlobProto &pbBlob, cv::Mat &dstBlob)
{
BlobShape shape = blobShapeFromProto(pbBlob);
std::vector<int> shape;
blobShapeFromProto(pbBlob, shape);
dstBlob.create(shape, CV_32F);
CV_Assert(pbBlob.data_size() == (int)dstBlob.matRefConst().total());
dstBlob.create((int)shape.size(), &shape[0], CV_32F);
CV_Assert(pbBlob.data_size() == (int)dstBlob.total());
CV_DbgAssert(pbBlob.GetDescriptor()->FindFieldByLowercaseName("data")->cpp_type() == FieldDescriptor::CPPTYPE_FLOAT);
float *dstData = dstBlob.matRef().ptr<float>();
float *dstData = dstBlob.ptr<float>();
for (int i = 0; i < pbBlob.data_size(); i++)
dstData[i] = pbBlob.data(i);
......
This diff is collapsed.
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#ifndef __OPENCV_DNN_CAFFE_LAYER_LOADERS_HPP__
#define __OPENCV_DNN_CAFFE_LAYER_LOADERS_HPP__
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
//Common template for Caffe layer loaders
template <typename PublicLayer>
Ptr<Layer> createLayerFromCaffe(LayerParams&);
Ptr<Layer> createFlattenLayerFromCaffe(LayerParams&);
}
}
#endif
\ No newline at end of file
......@@ -67,6 +67,65 @@ static String toString(const T &v)
return ss.str();
}
Mat blobFromImage(const Mat& image_, double scalefactor, bool swapRB)
{
Mat image;
if(image_.depth() == CV_8U)
{
image_.convertTo(image, CV_32F, scalefactor);
}
else
image = image_;
CV_Assert(image.dims == 2 && image.depth() == CV_32F);
int nch = image.channels();
CV_Assert(nch == 3 || nch == 4);
int sz[] = { 1, 3, image.rows, image.cols };
Mat blob(4, sz, CV_32F);
Mat ch[4];
for( int j = 0; j < 3; j++ )
ch[j] = Mat(image.rows, image.cols, CV_32F, blob.ptr(0, j));
if(swapRB)
std::swap(ch[0], ch[2]);
split(image, ch);
return blob;
}
Mat blobFromImages(const std::vector<Mat>& images, double scalefactor, bool swapRB)
{
size_t i, nimages = images.size();
if(nimages == 0)
return Mat();
Mat image0 = images[0];
int nch = image0.channels();
CV_Assert(image0.dims == 2 && (nch == 3 || nch == 4));
int sz[] = { (int)nimages, 3, image0.rows, image0.cols };
Mat blob(4, sz, CV_32F), image;
Mat ch[4];
for( i = 0; i < nimages; i++ )
{
Mat image_ = images[i];
if(image_.depth() == CV_8U)
{
image_.convertTo(image, CV_32F, scalefactor);
}
else
image = image_;
CV_Assert(image.depth() == CV_32F);
nch = image.channels();
CV_Assert(image.dims == 2 && (nch == 3 || nch == 4));
CV_Assert(image.size() == image0.size());
for( int j = 0; j < 3; j++ )
ch[j] = Mat(image.rows, image.cols, CV_32F, blob.ptr((int)i, j));
if(swapRB)
std::swap(ch[0], ch[2]);
split(image, ch);
}
return blob;
}
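A brief usage snippet for blobFromImages above; the frame file names are placeholders, and both images must have the same size and 3 or 4 channels, as asserted in the implementation:

    std::vector<Mat> frames;
    frames.push_back(imread("frame0.png"));
    frames.push_back(imread("frame1.png"));

    Mat batch = blobFromImages(frames, 1.0/255);    // 2 x 3 x H x W CV_32F blob, BGR->RGB swap on by default
    CV_Assert(batch.size[0] == 2 && batch.size[1] == 3);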
struct LayerPin
{
int lid;
......@@ -107,8 +166,8 @@ struct LayerData
std::set<int> requiredOutputs;
Ptr<Layer> layerInstance;
std::vector<Blob> outputBlobs;
std::vector<Blob*> inputBlobs;
std::vector<Mat> outputBlobs;
std::vector<Mat*> inputBlobs;
int flag;
......@@ -130,8 +189,8 @@ struct LayerData
//fake layer containing network input blobs
struct DataLayer : public Layer
{
void allocate(const std::vector<Blob*>&, std::vector<Blob>&) {}
void forward(std::vector<Blob*>&, std::vector<Blob>&) {}
void allocate(const std::vector<Mat*>&, std::vector<Mat>&) {}
void forward(std::vector<Mat*>&, std::vector<Mat>&) {}
int outputNameToIndex(String tgtName)
{
......@@ -348,8 +407,27 @@ struct Net::Impl
if (ld.flag)
return;
size_t ninputs = ld.inputBlobsId.size();
#if 0
printf("layer %s:", ld.name.c_str());
for (size_t i = 0; i < ninputs; i++)
{
int inp_lid = ld.inputBlobsId[i].lid;
LayerData &inp_ld = layers[inp_lid];
int inp_outputs = (int)inp_ld.outputBlobs.size();
std::cout << " " << inp_ld.name << "(" << inp_outputs;
for( int j = 0; j < inp_outputs; j++ )
{
std::cout << (j == 0 ? ": " : ", ") << inp_ld.outputBlobs[j].size;
}
std::cout << ")";
}
printf("\n");
#endif
//determine parent layers
for (size_t i = 0; i < ld.inputBlobsId.size(); i++)
for (size_t i = 0; i < ninputs; i++)
ld.inputLayersId.insert(ld.inputBlobsId[i].lid);
//allocate parents
......@@ -357,8 +435,8 @@ struct Net::Impl
allocateLayer(*i);
//bind inputs
ld.inputBlobs.resize(ld.inputBlobsId.size());
for (size_t i = 0; i < ld.inputBlobsId.size(); i++)
ld.inputBlobs.resize(ninputs);
for (size_t i = 0; i < ninputs; i++)
{
LayerPin from = ld.inputBlobsId[i];
CV_Assert(from.valid());
......@@ -368,15 +446,24 @@ struct Net::Impl
//allocate layer
ld.outputBlobs.resize(std::max((size_t)1, ld.requiredOutputs.size())); //layer produce at least one output blob
try
//try
{
Ptr<Layer> layerPtr = ld.getLayerInstance();
layerPtr->allocate(ld.inputBlobs, ld.outputBlobs);
#if 0
std::cout << "\toutputs:";
size_t noutputs = ld.outputBlobs.size();
for (size_t j = 0; j < noutputs; j++)
{
std::cout << (j == 0 ? " " : ", ") << ld.outputBlobs[j].size;
}
std::cout << "\n";
#endif
}
catch (const cv::Exception &err)
/*catch (const cv::Exception &err)
{
CV_RETHROW_ERROR(err, format("The following error occured while making allocate() for layer \"%s\": %s", ld.name.c_str(), err.err.c_str()));
}
}*/
ld.flag = 1;
}
......@@ -414,14 +501,14 @@ struct Net::Impl
}
//forward itself
try
//try
{
ld.layerInstance->forward(ld.inputBlobs, ld.outputBlobs);
}
catch (const cv::Exception &err)
/*catch (const cv::Exception &err)
{
CV_RETHROW_ERROR(err, format("The following error occured while making forward() for layer \"%s\": %s", ld.name.c_str(), err.err.c_str()));
}
}*/
ld.flag = 1;
}
......@@ -509,7 +596,7 @@ void Net::setNetInputs(const std::vector<String> &inputBlobNames)
impl->netInputLayer->setNames(inputBlobNames);
}
void Net::setBlob(String outputName, const Blob &blob)
void Net::setBlob(String outputName, const Mat &blob_)
{
LayerPin pin = impl->getPinByAlias(outputName);
if (!pin.valid())
......@@ -517,10 +604,10 @@ void Net::setBlob(String outputName, const Blob &blob)
LayerData &ld = impl->layers[pin.lid];
ld.outputBlobs.resize( std::max(pin.oid+1, (int)ld.requiredOutputs.size()) );
ld.outputBlobs[pin.oid] = blob;
ld.outputBlobs[pin.oid] = blob_.clone();
}
Blob Net::getBlob(String outputName)
Mat Net::getBlob(String outputName)
{
LayerPin pin = impl->getPinByAlias(outputName);
if (!pin.valid())
......@@ -535,20 +622,20 @@ Blob Net::getBlob(String outputName)
return ld.outputBlobs[pin.oid];
}
Blob Net::getParam(LayerId layer, int numParam)
Mat Net::getParam(LayerId layer, int numParam)
{
LayerData &ld = impl->getLayerData(layer);
std::vector<Blob> &layerBlobs = ld.layerInstance->blobs;
std::vector<Mat> &layerBlobs = ld.layerInstance->blobs;
CV_Assert(numParam < (int)layerBlobs.size());
return layerBlobs[numParam];
}
void Net::setParam(LayerId layer, int numParam, const Blob &blob)
void Net::setParam(LayerId layer, int numParam, const Mat &blob)
{
LayerData &ld = impl->getLayerData(layer);
std::vector<Blob> &layerBlobs = ld.layerInstance->blobs;
std::vector<Mat> &layerBlobs = ld.layerInstance->blobs;
CV_Assert(numParam < (int)layerBlobs.size());
//we don't make strong checks, use this function carefully
layerBlobs[numParam] = blob;
......@@ -662,30 +749,30 @@ static void vecToPVec(const std::vector<T> &v, std::vector<T*> &pv)
pv[i] = const_cast<T*>(&v[i]);
}
void Layer::allocate(const std::vector<Blob> &inputs, std::vector<Blob> &outputs)
void Layer::allocate(const std::vector<Mat> &inputs, std::vector<Mat> &outputs)
{
std::vector<Blob*> inputsp;
std::vector<Mat*> inputsp;
vecToPVec(inputs, inputsp);
this->allocate(inputsp, outputs);
}
std::vector<Blob> Layer::allocate(const std::vector<Blob> &inputs)
std::vector<Mat> Layer::allocate(const std::vector<Mat> &inputs)
{
std::vector<Blob> outputs;
std::vector<Mat> outputs;
this->allocate(inputs, outputs);
return outputs;
}
void Layer::forward(const std::vector<Blob> &inputs, std::vector<Blob> &outputs)
void Layer::forward(const std::vector<Mat> &inputs, std::vector<Mat> &outputs)
{
std::vector<Blob*> inputsp;
std::vector<Mat*> inputsp;
vecToPVec(inputs, inputsp);
this->forward(inputsp, outputs);
}
void Layer::run(const std::vector<Blob> &inputs, std::vector<Blob> &outputs)
void Layer::run(const std::vector<Mat> &inputs, std::vector<Mat> &outputs)
{
std::vector<Blob*> inputsp;
std::vector<Mat*> inputsp;
vecToPVec(inputs, inputsp);
this->allocate(inputsp, outputs);
this->forward(inputsp, outputs);
......
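These wrappers make it possible to run a single layer outside of a Net. A minimal sketch using the pass-through "Identity" type, which init.cpp (below) registers to the BlankLayer implementation; initModule() is only strictly needed for static builds, as noted in the samples. The usual cv / cv::dnn using-directives are assumed:

    initModule();                                   // required if OpenCV is built as static libs

    LayerParams lp;                                 // "Identity" needs no parameters or weights
    Ptr<Layer> identity = LayerFactory::createLayerInstance("Identity", lp);

    int sz[] = { 1, 3, 8, 8 };
    Mat input(4, sz, CV_32F);
    theRNG().fill(input, RNG::UNIFORM, 0, 1);

    std::vector<Mat> inputs(1, input), outputs;
    identity->run(inputs, outputs);                 // allocate() followed by forward()
    CV_Assert(outputs.size() == 1 && outputs[0].size == input.size);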
......@@ -40,19 +40,6 @@
//M*/
#include "precomp.hpp"
#include "caffe/layer_loaders.hpp"
#include "layers/blank_layer.hpp"
#include "layers/crop_layer.hpp"
#include "layers/eltwise_layer.hpp"
#include "layers/flatten_layer.hpp"
#include "layers/permute_layer.hpp"
#include "layers/prior_box_layer.hpp"
#include "layers/detection_output_layer.hpp"
#include "layers/normalize_bbox_layer.hpp"
#include "layers/shift_layer.hpp"
#include "layers/padding_layer.hpp"
#include "layers/scale_layer.hpp"
namespace cv
{
......@@ -65,7 +52,7 @@ struct AutoInitializer
AutoInitializer() : status(false)
{
cv::dnn::initModule();
initModule();
}
};
......@@ -76,41 +63,41 @@ void initModule()
if (init.status)
return;
REG_RUNTIME_LAYER_FUNC(Slice, createLayerFromCaffe<SliceLayer>);
REG_RUNTIME_LAYER_FUNC(Split, createLayerFromCaffe<SplitLayer>);
REG_RUNTIME_LAYER_FUNC(Concat, createLayerFromCaffe<ConcatLayer>);
REG_RUNTIME_LAYER_FUNC(Reshape, createLayerFromCaffe<ReshapeLayer>);
REG_RUNTIME_LAYER_CLASS(Slice, SliceLayer);
REG_RUNTIME_LAYER_CLASS(Split, SplitLayer);
REG_RUNTIME_LAYER_CLASS(Concat, ConcatLayer);
REG_RUNTIME_LAYER_CLASS(Reshape, ReshapeLayer);
REG_RUNTIME_LAYER_CLASS(Flatten, FlattenLayer);
REG_RUNTIME_LAYER_FUNC(Convolution, createLayerFromCaffe<ConvolutionLayer>);
REG_RUNTIME_LAYER_FUNC(Deconvolution, createLayerFromCaffe<DeconvolutionLayer>);
REG_RUNTIME_LAYER_FUNC(Pooling, createLayerFromCaffe<PoolingLayer>);
REG_RUNTIME_LAYER_FUNC(LRN, createLayerFromCaffe<LRNLayer>);
REG_RUNTIME_LAYER_FUNC(InnerProduct, createLayerFromCaffe<InnerProductLayer>);
REG_RUNTIME_LAYER_FUNC(Softmax, createLayerFromCaffe<SoftmaxLayer>);
REG_RUNTIME_LAYER_FUNC(MVN, createLayerFromCaffe<MVNLayer>);
REG_RUNTIME_LAYER_CLASS(Convolution, ConvolutionLayer);
REG_RUNTIME_LAYER_CLASS(Deconvolution, DeconvolutionLayer);
REG_RUNTIME_LAYER_CLASS(Pooling, PoolingLayer);
REG_RUNTIME_LAYER_CLASS(LRN, LRNLayer);
REG_RUNTIME_LAYER_CLASS(InnerProduct, InnerProductLayer);
REG_RUNTIME_LAYER_CLASS(Softmax, SoftmaxLayer);
REG_RUNTIME_LAYER_CLASS(MVN, MVNLayer);
REG_RUNTIME_LAYER_FUNC(ReLU, createLayerFromCaffe<ReLULayer>);
REG_RUNTIME_LAYER_FUNC(ChannelsPReLU, createLayerFromCaffe<ChannelsPReLULayer>);
REG_RUNTIME_LAYER_FUNC(Sigmoid, createLayerFromCaffe<SigmoidLayer>);
REG_RUNTIME_LAYER_FUNC(TanH, createLayerFromCaffe<TanHLayer>);
REG_RUNTIME_LAYER_FUNC(BNLL, createLayerFromCaffe<BNLLLayer>);
REG_RUNTIME_LAYER_FUNC(AbsVal, createLayerFromCaffe<AbsLayer>);
REG_RUNTIME_LAYER_FUNC(Power, createLayerFromCaffe<PowerLayer>);
REG_RUNTIME_LAYER_FUNC(BatchNorm, createLayerFromCaffe<BatchNormLayer>);
REG_RUNTIME_LAYER_FUNC(MaxUnpool, createLayerFromCaffe<MaxUnpoolLayer>);
REG_RUNTIME_LAYER_CLASS(ReLU, ReLULayer);
REG_RUNTIME_LAYER_CLASS(ChannelsPReLU, ChannelsPReLULayer);
REG_RUNTIME_LAYER_CLASS(Sigmoid, SigmoidLayer);
REG_RUNTIME_LAYER_CLASS(TanH, TanHLayer);
REG_RUNTIME_LAYER_CLASS(BNLL, BNLLLayer);
REG_RUNTIME_LAYER_CLASS(AbsVal, AbsLayer);
REG_RUNTIME_LAYER_CLASS(Power, PowerLayer);
REG_RUNTIME_LAYER_CLASS(BatchNorm, BatchNormLayer);
REG_RUNTIME_LAYER_CLASS(MaxUnpool, MaxUnpoolLayer);
REG_RUNTIME_LAYER_CLASS(Dropout, BlankLayer);
REG_RUNTIME_LAYER_CLASS(Identity, BlankLayer);
REG_RUNTIME_LAYER_FUNC(Crop, createLayerFromCaffe<CropLayer>);
REG_RUNTIME_LAYER_FUNC(Eltwise, createLayerFromCaffe<EltwiseLayer>);
REG_RUNTIME_LAYER_CLASS(Crop, CropLayer);
REG_RUNTIME_LAYER_CLASS(Eltwise, EltwiseLayer);
REG_RUNTIME_LAYER_CLASS(Permute, PermuteLayer);
REG_RUNTIME_LAYER_CLASS(PriorBox, PriorBoxLayer);
REG_RUNTIME_LAYER_CLASS(DetectionOutput, DetectionOutputLayer);
REG_RUNTIME_LAYER_CLASS(NormalizeBBox, NormalizeBBoxLayer);
REG_RUNTIME_LAYER_CLASS(Shift, ShiftLayer);
REG_RUNTIME_LAYER_CLASS(Padding, PaddingLayer);
REG_RUNTIME_LAYER_FUNC(Scale, createLayerFromCaffe<ScaleLayer>);
REG_RUNTIME_LAYER_CLASS(Scale, ScaleLayer);
init.status = true;
}
......
......@@ -9,78 +9,95 @@
Implementation of Batch Normalization layer.
*/
#include "batch_norm_layer.hpp"
#include "../precomp.hpp"
namespace cv
{
namespace dnn
{
BatchNormLayerImpl::BatchNormLayerImpl(bool hasWeights_, bool hasBias_, float epsilon_):
hasWeights(hasWeights_),
hasBias(hasBias_),
epsilon(epsilon_)
{}
void BatchNormLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
class BatchNormLayerImpl : public BatchNormLayer
{
CV_Assert(blobs.size() >= 2);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
public:
BatchNormLayerImpl(const LayerParams& params)
{
CV_Assert(blobs[0].total() == inputs[i]->channels());
CV_Assert(blobs[1].total() == inputs[i]->channels());
outputs[i].create(inputs[i]->shape());
}
}
void BatchNormLayerImpl::forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() == 1);
setParamsFrom(params);
CV_Assert(blobs.size() >= 3);
Blob &inpBlob = *inputs[0];
int weightsBlobIndex = 2;
int biasBlobIndex = weightsBlobIndex + hasWeights;
float varMeanScale = 1;
if (!hasWeights && !hasBias) {
varMeanScale = *blobs[2].ptrf();
if (varMeanScale != 0)
varMeanScale = 1/varMeanScale;
hasWeights = params.get<bool>("has_weight", false);
hasBias = params.get<bool>("has_bias", false);
epsilon = params.get<float>("eps", 1E-5);
}
Mat invStdMat;
cv::pow(blobs[1].matRefConst()*varMeanScale + epsilon, -0.5, invStdMat);
void allocate(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
CV_Assert(blobs.size() >= 2);
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
{
CV_Assert(blobs[0].total() == inputs[i]->size[1]);
CV_Assert(blobs[1].total() == inputs[i]->size[1]);
Mat* inp = inputs[i];
outputs[i].create(inp->dims, &inp->size.p[0], inp->type());
}
}
for (size_t ii = 0; ii < outputs.size(); ii++)
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
Blob &outBlob = outputs[ii];
if (hasWeights)
CV_Assert(inpBlob.channels() == blobs[weightsBlobIndex].total());
if (hasBias)
CV_Assert(inpBlob.channels() == blobs[biasBlobIndex].total());
for(int num = 0; num < outBlob.num(); num++)
{
for (int n = 0; n < outBlob.channels(); n++)
{
float mean = blobs[0].matRefConst().at<float>(n)*varMeanScale;
double invstd = invStdMat.at<float>(n);
float w = hasWeights ? blobs[weightsBlobIndex].matRefConst().at<float>(n) : 1;
float b = hasBias ? blobs[biasBlobIndex].matRefConst().at<float>(n) : 0;
outBlob.getPlane(num, n) = (inpBlob.getPlane(num, n) - mean)*w*invstd + b;
}
}
CV_Assert(inputs.size() == 1);
Mat &inpBlob = *inputs[0];
int weightsBlobIndex = 2;
int biasBlobIndex = weightsBlobIndex + hasWeights;
float varMeanScale = 1;
if (!hasWeights && !hasBias) {
varMeanScale = *blobs[2].ptr<float>();
if (varMeanScale != 0)
varMeanScale = 1/varMeanScale;
}
Mat invStdMat;
cv::pow(blobs[1]*varMeanScale + epsilon, -0.5, invStdMat);
int rows = inpBlob.size[2];
int cols = inpBlob.size[3];
for (size_t ii = 0; ii < outputs.size(); ii++)
{
Mat &outBlob = outputs[ii];
if (hasWeights)
CV_Assert(inpBlob.size[1] == blobs[weightsBlobIndex].total());
if (hasBias)
CV_Assert(inpBlob.size[1] == blobs[biasBlobIndex].total());
for(int num = 0; num < outBlob.size[0]; num++)
{
for (int n = 0; n < outBlob.size[1]; n++)
{
float mean = blobs[0].at<float>(n)*varMeanScale;
double invstd = invStdMat.at<float>(n);
float w = hasWeights ? blobs[weightsBlobIndex].at<float>(n) : 1;
float b = hasBias ? blobs[biasBlobIndex].at<float>(n) : 0;
Mat inpBlobPlane(rows, cols, CV_32F, inpBlob.ptr<float>(num, n));
Mat outBlobPlane(rows, cols, CV_32F, outBlob.ptr<float>(num, n));
inpBlobPlane.convertTo(outBlobPlane, CV_32F, w*invstd, b - mean*w*invstd);
}
}
}
}
}
Ptr<BatchNormLayer> BatchNormLayer::create(bool hasWeights, bool hasBias, float epsilon)
bool hasWeights, hasBias;
float epsilon;
};
Ptr<BatchNormLayer> BatchNormLayer::create(const LayerParams& params)
{
return Ptr<BatchNormLayer>(new BatchNormLayerImpl(hasWeights, hasBias, epsilon));
return Ptr<BatchNormLayer>(new BatchNormLayerImpl(params));
}
} // namespace dnn
......
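A sketch of driving the rewritten layer directly. The blob layout follows the implementation above: blobs[0] = mean, blobs[1] = variance, and blobs[2] = Caffe's scale factor when has_weight/has_bias are left at their false defaults; the concrete numbers are arbitrary:

    LayerParams lp;
    lp.set("eps", 1e-5);                            // optional, 1e-5 is also the default

    Mat mean(1, 3, CV_32F, Scalar(0.5));
    Mat var(1, 3, CV_32F, Scalar(2.0));
    Mat scaleFactor(1, 1, CV_32F, Scalar(1.0));
    lp.blobs.push_back(mean);
    lp.blobs.push_back(var);
    lp.blobs.push_back(scaleFactor);

    Ptr<BatchNormLayer> bn = BatchNormLayer::create(lp);

    int sz[] = { 1, 3, 4, 4 };
    Mat input(4, sz, CV_32F);
    theRNG().fill(input, RNG::UNIFORM, 0, 1);

    std::vector<Mat> inputs(1, input), outputs;
    bn->run(inputs, outputs);   // per channel: out = (in - mean/scale) / sqrt(var/scale + eps)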
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
// Copyright (C) 2016, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
/*
Declaration of Batch Normalization layer.
*/
#ifndef __OPENCV_DNN_LAYERS_BATCH_NORM_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_BATCH_NORM_LAYER_HPP__
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
class BatchNormLayerImpl : public BatchNormLayer
{
public:
BatchNormLayerImpl(bool hasWeights_, bool hasBias_, float epsilon_);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
private:
bool hasWeights, hasBias;
float epsilon;
};
}
}
#endif // __OPENCV_DNN_LAYERS_BATCH_NORM_LAYER_HPP__
......@@ -38,30 +38,35 @@
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#ifndef __OPENCV_DNN_LAYERS_FLATTEN_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_FLATTEN_LAYER_HPP__
#include "../precomp.hpp"
namespace cv
{
namespace dnn
{
class FlattenLayer : public Layer
class BlankLayerImpl : public BlankLayer
{
int _startAxis;
int _endAxis;
size_t _numAxes;
BlobShape resultShape;
public:
FlattenLayer(LayerParams &params);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
BlankLayerImpl(const LayerParams&) {}
void allocate(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
outputs[i] = *inputs[i];
}
void checkInputs(const std::vector<Blob*> &inputs);
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
for (size_t i = 0; i < inputs.size(); i++)
outputs[i] = *inputs[i];
}
};
Ptr<BlankLayer> BlankLayer::create(const LayerParams& params)
{
return Ptr<BlankLayer>(new BlankLayerImpl(params));
}
}
}
#endif
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#ifndef __OPENCV_DNN_LAYERS_BLANK_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_BLANK_LAYER_HPP__
#include "../precomp.hpp"
namespace cv
{
namespace dnn
{
class BlankLayer : public Layer
{
public:
BlankLayer(LayerParams&)
{
}
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
outputs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); i++)
outputs[i].shareFrom(*inputs[i]);
}
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
for (size_t i = 0; i < inputs.size(); i++)
outputs[i] = *inputs[i];
}
};
}
}
#endif
......@@ -41,80 +41,69 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "concat_layer.hpp"
#include <opencv2/core/ocl.hpp>
namespace cv
{
namespace dnn
{
ConcatLayerImpl::ConcatLayerImpl(int axis_ /*= 1*/)
class ConcatLayerImpl : public ConcatLayer
{
axis = axis_;
}
void ConcatLayerImpl::allocate(const std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
CV_Assert(inputs.size() > 0);
BlobShape refShape = inputs[0]->shape();
axisIdx = inputs[0]->canonicalAxis(axis);
public:
ConcatLayerImpl(const LayerParams& params)
{
setParamsFrom(params);
axis = params.get<int>("axis", 1);
}
int axisSum = 0;
useOpenCL = false;
for (size_t i = 0; i < inputs.size(); i++)
void allocate(const std::vector<Mat *> &inputs, std::vector<Mat> &outputs)
{
BlobShape curShape = inputs[i]->shape();
CV_Assert(inputs.size() > 0);
CV_Assert(curShape.dims() == refShape.dims() && inputs[i]->type() == inputs[0]->type());
for (int curAxis = 0; curAxis < refShape.dims(); curAxis++)
int dims = inputs[0]->dims, dtype = inputs[0]->type();
std::vector<int> refShape(inputs[0]->size.p, inputs[0]->size.p + dims);
axisIdx = axis < 0 ? axis + dims : axis;
int axisSum = 0;
for (size_t i = 0; i < inputs.size(); i++)
{
if (curAxis != axisIdx && refShape[curAxis] != curShape[curAxis])
CV_Error(Error::StsBadSize, "Inconsistent shape for ConcatLayer");
CV_Assert(inputs[i]->type() == dtype);
for (int curAxis = 0; curAxis < dims; curAxis++)
{
if (curAxis != axisIdx && inputs[0]->size[curAxis] != inputs[i]->size[curAxis])
CV_Error(Error::StsBadSize, "Inconsistent shape for ConcatLayer");
}
axisSum += inputs[i]->size[axisIdx];
}
axisSum += curShape[axisIdx];
useOpenCL |= inputs[i]->getState() == Blob::HEAD_AT_MAT;
}
refShape[axisIdx] = axisSum;
useOpenCL &= ocl::useOpenCL();
int allocFlags = (useOpenCL) ? Blob::ALLOC_UMAT : Blob::ALLOC_MAT;
outputs.resize(1);
outputs[0].create(refShape, inputs[0]->type(), allocFlags);
}
refShape[axisIdx] = axisSum;
outputs.resize(1);
outputs[0].create(dims, &refShape[0], dtype);
}
void ConcatLayerImpl::forward(std::vector<Blob *> &inputs, std::vector<Blob> &outputs)
{
#ifdef HAVE_OPENCL
if (useOpenCL)
forward_<UMat>(inputs, outputs);
else
#endif
forward_<Mat>(inputs, outputs);
}
template<typename XMat>
void ConcatLayerImpl::forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs)
{
XMat& outMat = outputs[0].getRef<XMat>();
std::vector<Range> ranges(outputs[0].dims(), Range::all());
ranges[axisIdx].start = 0;
for (size_t i = 0; i < inputs.size(); i++)
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
ranges[axisIdx].end = ranges[axisIdx].start + inputs[i]->size(axisIdx);
inputs[i]->getRefConst<XMat>().copyTo(outMat(&ranges[0]));
ranges[axisIdx].start = ranges[axisIdx].end;
Mat& outMat = outputs[0];
std::vector<Range> ranges(outputs[0].dims, Range::all());
ranges[axisIdx].start = 0;
for (size_t i = 0; i < inputs.size(); i++)
{
ranges[axisIdx].end = ranges[axisIdx].start + inputs[i]->size[axisIdx];
inputs[i]->copyTo(outMat(&ranges[0]));
ranges[axisIdx].start = ranges[axisIdx].end;
}
}
}
Ptr<ConcatLayer> ConcatLayer::create(int axis)
int axisIdx;
};
Ptr<ConcatLayer> ConcatLayer::create(const LayerParams& params)
{
return Ptr<ConcatLayer>(new ConcatLayerImpl(axis));
return Ptr<ConcatLayer>(new ConcatLayerImpl(params));
}
}
......
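A usage sketch for the rewritten layer: concatenating two 4-D blobs along the channel axis (axis 1, which is also the default); the shapes and fill values are arbitrary:

    LayerParams lp;
    lp.set("axis", 1);
    Ptr<ConcatLayer> concat = ConcatLayer::create(lp);

    int szA[] = { 1, 3, 8, 8 }, szB[] = { 1, 5, 8, 8 };
    Mat a(4, szA, CV_32F, Scalar(1)), b(4, szB, CV_32F, Scalar(2));

    std::vector<Mat> inputs, outputs;
    inputs.push_back(a);
    inputs.push_back(b);
    concat->run(inputs, outputs);                   // output shape: 1 x 8 x 8 x 8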
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#ifndef __OPENCV_DNN_LAYERS_CONCAT_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_CONCAT_LAYER_HPP__
#include "../precomp.hpp"
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
class ConcatLayerImpl : public ConcatLayer
{
bool useOpenCL;
int axisIdx;
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
public:
ConcatLayerImpl(int axis_ = 1);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
}
}
#endif
#ifndef __OPENCV_DNN_LAYERS_CONVOLUTION_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_CONVOLUTION_LAYER_HPP__
#include "../precomp.hpp"
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
class BaseConvolutionLayerImpl : public ConvolutionLayer
{
public:
BaseConvolutionLayerImpl();
virtual void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
protected:
void init();
virtual void computeInpOutShape(const Blob &inpBlob) = 0;
bool is1x1() const;
int numOutput, group;
int inpH, inpW, inpCn;
int outH, outW, outCn;
int inpGroupCn, outGroupCn;
int ksize;
BlobShape colRowBlobShape;
bool bias;
bool tryUseOpenCL, useOpenCL;
Blob colRowBlob, biasOnesBlob;
};
//TODO: simultaneously convolution and bias addition for cache optimization
class ConvolutionLayerImpl : public BaseConvolutionLayerImpl
{
public:
virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
protected:
virtual void computeInpOutShape(const Blob &inpBlob);
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void im2col(const Mat &srcImg, Mat &dstCol);
void im2row(const Mat &srcImg, Mat &dstRow);
void im2col(const UMat &srcImg, UMat &dstCol);
void im2row(const UMat &srcImg, UMat &dstCol);
};
class DeConvolutionLayerImpl : public BaseConvolutionLayerImpl
{
public:
virtual void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
protected:
virtual void computeInpOutShape(const Blob &inpBlob);
template<typename XMat>
void forward_(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void col2im(const Mat &colMat, Mat &dstImg);
void col2im(const UMat &colMat, UMat &dstImg);
};
//Importers
Ptr<Layer> createConvolutionLayerFromCaffe(LayerParams &params);
Ptr<Layer> createDeconvolutionLayerFromCaffe(LayerParams &params);
}
}
#endif
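The im2col/im2row helpers declared above unroll every receptive field of the input into a column (or row) of an auxiliary matrix, so the convolution itself reduces to a single matrix product with the weights blob. A minimal sketch of the column variant, assuming a contiguous single-image C x H x W float buffer; the function name and signature are illustrative, not the module's implementation:
#include <opencv2/core.hpp>
static cv::Mat im2colSketch(const cv::Mat &img, int channels, int height, int width,
                            int kH, int kW, int sH, int sW, int pH, int pW)
{
    const int outH = (height + 2*pH - kH)/sH + 1;
    const int outW = (width  + 2*pW - kW)/sW + 1;
    cv::Mat col(channels*kH*kW, outH*outW, CV_32F, cv::Scalar(0));
    const float* src = img.ptr<float>();   // assumes continuous CHW storage
    for (int c = 0; c < channels; c++)
        for (int ky = 0; ky < kH; ky++)
            for (int kx = 0; kx < kW; kx++)
            {
                // one row of col per (channel, kernel offset) pair
                float* dst = col.ptr<float>((c*kH + ky)*kW + kx);
                for (int oy = 0; oy < outH; oy++)
                    for (int ox = 0; ox < outW; ox++)
                    {
                        int y = oy*sH - pH + ky, x = ox*sW - pW + kx;
                        dst[oy*outW + ox] = (y >= 0 && y < height && x >= 0 && x < width)
                            ? src[(c*height + y)*width + x] : 0.f;
                    }
            }
    return col; // convolution then becomes: weights (outCn x C*kH*kW) * col
}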
......@@ -41,87 +41,97 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "crop_layer.hpp"
namespace cv
{
namespace dnn
{
class CropLayerImpl : public CropLayer
{
public:
    CropLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        startAxis = params.get<int>("axis", 2);
        const DictValue *paramOffset = params.ptr("offset");
        if (paramOffset)
        {
            for (int i = 0; i < paramOffset->size(); i++)
                offset.push_back(paramOffset->get<int>(i));
        }
    }
    void allocate(const std::vector<Mat *> &inputs, std::vector<Mat> &outputs)
    {
        CV_Assert(2 == inputs.size());
        const Mat &inpBlob = *inputs[0];
        const Mat &inpSzBlob = *inputs[1];
        int dims = inpBlob.dims;
        int start_axis = startAxis < 0 ? startAxis + dims : startAxis;
        std::vector<int> offset_final(dims, 0);
        if (offset.size() == 1)
        {
            for (int i = start_axis; i < dims; i++)
                offset_final[i] = offset[0];
        }
        else if (offset.size() > 1)
        {
            if ((int)offset.size() != dims - start_axis)
                CV_Error(Error::StsBadArg, "number of offset values specified must be equal to the number of dimensions following axis.");
            for (int i = start_axis; i < dims; i++)
                offset_final[i] = offset[i - start_axis];
        }
        std::vector<int> dstShape(dims);
        crop_ranges.resize(dims, Range::all());
        for (int i = 0; i < dims; i++)
        {
            dstShape[i] = inpSzBlob.size[i];
            if (i < start_axis)
                continue;
            if (!offset.empty()) //normal case
            {
                if (offset_final[i] < 0 || offset_final[i] + inpSzBlob.size[i] > inpBlob.size[i])
                    CV_Error(Error::StsBadArg, "invalid crop parameters");
                crop_ranges[i] = Range(offset_final[i], offset_final[i] + inpSzBlob.size[i]);
            }
            else //detect offset automatically so that cropped image is center of original one
            {
                if (inpSzBlob.size[i] > inpBlob.size[i])
                    CV_Error(Error::StsBadArg, "invalid output blob size");
                int cur_crop = (inpBlob.size[i] - inpSzBlob.size[i]) / 2;
                crop_ranges[i] = Range(cur_crop, cur_crop + inpSzBlob.size[i]);
            }
        }
        outputs.resize(1);
        outputs[0].create(dims, &dstShape[0], inpBlob.type());
    }
    void forward(std::vector<Mat *> &inputs, std::vector<Mat> &outputs)
    {
        Mat &input = *inputs[0];
        Mat &output = outputs[0];
        input(&crop_ranges[0]).copyTo(output);
    }
    std::vector<Range> crop_ranges;
};
Ptr<CropLayer> CropLayer::create(const LayerParams& params)
{
    return Ptr<CropLayer>(new CropLayerImpl(params));
}
}
......
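The constructor above reads its settings ("axis" and "offset") from LayerParams instead of taking them as arguments. A minimal sketch of filling such a parameter set by hand, assuming the Dict::set accessor; the helper name and the concrete values are illustrative, in practice the Caffe importer fills them:
#include <opencv2/dnn.hpp>
static cv::Ptr<cv::dnn::CropLayer> makeCropLayer()
{
    cv::dnn::LayerParams lp;
    lp.name = "crop1";        // optional instance name
    lp.type = "Crop";
    lp.set("axis", 2);        // start cropping from the height axis
    lp.set("offset", 4);      // a single offset is broadcast to all cropped axes
    return cv::dnn::CropLayer::create(lp);
}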
#ifndef __OPENCV_DNN_LAYERS_CROP_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_CROP_LAYER_HPP__
#include "../precomp.hpp"
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
class CropLayerImpl : public CropLayer
{
std::vector<Range> crop_ranges;
public:
CropLayerImpl(int start_axis, const std::vector<int> &offset);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
}
}
#endif
......@@ -41,88 +41,117 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "eltwise_layer.hpp"
namespace cv
{
namespace dnn
{
class EltwiseLayerImpl : public EltwiseLayer
{
public:
    EltwiseOp op;
    std::vector<int> coeffs;
    EltwiseLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        op = EltwiseLayer::SUM;
        if (params.has("operation"))
        {
            String operation = params.get<String>("operation").toLowerCase();
            if (operation == "prod")
                op = EltwiseLayer::PROD;
            else if (operation == "sum")
                op = EltwiseLayer::SUM;
            else if (operation == "max")
                op = EltwiseLayer::MAX;
            else
                CV_Error(cv::Error::StsBadArg, "Unknown operation type \"" + operation + "\"");
        }
        if (params.has("coeff"))
        {
            DictValue paramCoeff = params.get("coeff");
            int i, n = paramCoeff.size();
            coeffs.resize(n);
            for (i = 0; i < n; i++)
            {
                coeffs[i] = paramCoeff.get<int>(i);
            }
        }
    }
    void allocate(const std::vector<Mat *> &inputs, std::vector<Mat> &outputs)
    {
        CV_Assert(2 <= inputs.size());
        CV_Assert(coeffs.size() == 0 || coeffs.size() == inputs.size());
        CV_Assert(op == SUM || coeffs.size() == 0);
        for (size_t i = 1; i < inputs.size(); ++i)
        {
            CV_Assert(inputs[i]->size == inputs[0]->size);
        }
        outputs.resize(1);
        outputs[0].create(inputs[0]->dims, inputs[0]->size.p, inputs[0]->type());
    }
    void forward(std::vector<Mat *> &inputs, std::vector<Mat> &outputs)
    {
        switch (op)
        {
        case SUM:
            {
                CV_Assert(coeffs.size() == 0 || coeffs.size() == inputs.size());
                Mat& output = outputs[0];
                output.setTo(0.);
                if (0 < coeffs.size())
                {
                    for (size_t i = 0; i < inputs.size(); i++)
                    {
                        output += *inputs[i] * coeffs[i];
                    }
                }
                else
                {
                    for (size_t i = 0; i < inputs.size(); i++)
                    {
                        output += *inputs[i];
                    }
                }
            }
            break;
        case PROD:
            {
                Mat& output = outputs[0];
                output.setTo(1.);
                for (size_t i = 0; i < inputs.size(); i++)
                {
                    output = output.mul(*inputs[i]);
                }
            }
            break;
        case MAX:
            {
                Mat& output = outputs[0];
                cv::max(*inputs[0], *inputs[1], output);
                for (size_t i = 2; i < inputs.size(); i++)
                {
                    cv::max(output, *inputs[i], output);
                }
            }
            break;
        default:
            CV_Assert(0);
            break;
        }
    }
};
Ptr<EltwiseLayer> EltwiseLayer::create(const LayerParams& params)
{
    return Ptr<EltwiseLayer>(new EltwiseLayerImpl(params));
}
}
}
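The implementation above recognizes the Caffe-style "operation" and "coeff" parameters. A minimal sketch of configuring a two-input weighted sum, assuming the Dict::set accessor and the DictValue::arrayInt helper; names and values are illustrative:
#include <opencv2/dnn.hpp>
static cv::Ptr<cv::dnn::EltwiseLayer> makeWeightedSum()
{
    cv::dnn::LayerParams lp;
    lp.type = "Eltwise";
    lp.set("operation", cv::String("sum"));   // "sum", "prod" or "max"
    // the layer reads the coefficients as integers (see the constructor above)
    int coeff[] = { 1, -1 };
    lp.set("coeff", cv::dnn::DictValue::arrayInt(coeff, 2));
    return cv::dnn::EltwiseLayer::create(lp);
}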
#ifndef __OPENCV_DNN_LAYERS_ELTWISE_LAYER_HPP__
#define __OPENCV_DNN_LAYERS_ELTWISE_LAYER_HPP__
#include "../precomp.hpp"
#include <opencv2/dnn/all_layers.hpp>
namespace cv
{
namespace dnn
{
class EltwiseLayerImpl : public EltwiseLayer
{
EltwiseOp op;
std::vector<int> coeffs;
public:
EltwiseLayerImpl(EltwiseOp op, const std::vector<int> &coeffs);
void allocate(const std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
void forward(std::vector<Blob*> &inputs, std::vector<Blob> &outputs);
};
}
}
#endif
......@@ -54,7 +54,8 @@ std::string makeName(const std::string& str1, const std::string& str2)
return str1 + str2;
}
bool getParameter(const LayerParams &params, const std::string& nameBase, const std::string& nameAll,
                  int &parameterH, int &parameterW, bool hasDefault = false, const int& defaultValue = 0)
{
std::string nameH = makeName(nameBase, std::string("_h"));
std::string nameW = makeName(nameBase, std::string("_w"));
......@@ -92,7 +93,7 @@ bool getParameter(LayerParams &params, const std::string& nameBase, const std::s
}
}
void getKernelSize(const LayerParams &params, int &kernelH, int &kernelW)
{
if(!util::getParameter(params, "kernel", "kernel_size", kernelH, kernelW))
{
......@@ -102,7 +103,7 @@ void getKernelSize(LayerParams &params, int &kernelH, int &kernelW)
CV_Assert(kernelH > 0 && kernelW > 0);
}
void getStrideAndPadding(const LayerParams &params, int &padH, int &padW, int &strideH, int &strideW, cv::String& padMode)
{
util::getParameter(params, "pad", "pad", padH, padW, true, 0);
util::getParameter(params, "stride", "stride", strideH, strideW, true, 1);
......@@ -118,7 +119,7 @@ void getStrideAndPadding(LayerParams &params, int &padH, int &padW, int &strideH
}
void getPoolingKernelParams(const LayerParams &params, int &kernelH, int &kernelW, bool &globalPooling,
int &padH, int &padW, int &strideH, int &strideW, cv::String &padMode)
{
util::getStrideAndPadding(params, padH, padW, strideH, strideW, padMode);
......@@ -142,7 +143,7 @@ void getPoolingKernelParams(LayerParams &params, int &kernelH, int &kernelW, boo
}
}
void getConvolutionKernelParams(const LayerParams &params, int &kernelH, int &kernelW, int &padH, int &padW,
int &strideH, int &strideW, int &dilationH, int &dilationW, cv::String &padMode)
{
util::getKernelSize(params, kernelH, kernelW);
......
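These helpers let a layer specify its geometry either with a single square value ("kernel_size", "pad", "stride") or with separate height/width keys built from the _h/_w suffixes above. A minimal sketch of both spellings, assuming the Dict::set accessor; the function name and values are illustrative:
#include <opencv2/dnn.hpp>
static void fillConvGeometry(cv::dnn::LayerParams &lp)
{
    lp.set("kernel_size", 3);   // square kernel: one value for both dimensions
    // equivalently: lp.set("kernel_h", 3); lp.set("kernel_w", 3);
    lp.set("stride", 1);
    lp.set("pad", 1);
}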