Commit d95053f2 authored by Vadim Pisarevsky's avatar Vadim Pisarevsky

Merge pull request #1186 from dkurt:halide_support

parents 645260af dc93eede
......@@ -55,6 +55,23 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
typedef std::vector<int> MatShape;
/**
* @brief Enum of computation backends supported by layers.
*/
enum Backend
{
DNN_BACKEND_DEFAULT,
DNN_BACKEND_HALIDE
};
/**
* @brief Enum of target devices for computations.
*/
enum Target
{
DNN_TARGET_CPU
};
/** @brief Initialize dnn module and built-in layers.
*
* This function automatically called on most of OpenCV builds,
......@@ -77,6 +94,54 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
String type; //!< Type name which was used for creating layer by layer factory (optional).
};
/**
* @brief Derivatives of this class encapsulates functions of certain backends.
*/
class BackendNode
{
public:
BackendNode(int backendId);
virtual ~BackendNode(); //!< Virtual destructor to make polymorphism.
int backendId; //!< Backend identifier.
};
/**
* @brief Derivatives of this class wraps cv::Mat for different backends and targets.
*/
class BackendWrapper
{
public:
BackendWrapper(int backendId, int targetId);
/**
* @brief Wrap cv::Mat for specific backend and target.
* @param[in] targetId Target identifier.
* @param[in] m cv::Mat for wrapping.
*
* Make CPU->GPU data transfer if it's require for the target.
*/
BackendWrapper(int targetId, const cv::Mat& m);
/**
* @brief Make wrapper for reused cv::Mat.
* @param[in] base Wrapper of cv::Mat that will be reused.
* @param[in] shape Specific shape.
*
* Initialize wrapper from another one. It'll wrap the same host CPU
* memory and mustn't allocate memory on device(i.e. GPU). It might
* has different shape. Use in case of CPU memory reusing for reuse
* associented memory on device too.
*/
BackendWrapper(const Ptr<BackendWrapper>& base, const MatShape& shape);
virtual ~BackendWrapper(); //!< Virtual destructor to make polymorphism.
int backendId; //!< Backend identifier.
int targetId; //!< Target identifier.
};
/** @brief This interface class allows to build new Layers - are building blocks of networks.
*
* Each class, derived from Layer, must implement allocate() methods to declare own outputs and forward() to compute outputs.
......@@ -131,6 +196,50 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
*/
virtual int outputNameToIndex(String outputName);
/**
* @brief Ask layer if it support specific backend for doing computations.
* @param[in] backendId computation backend identifier.
* @see Backend
*/
virtual bool supportBackend(int backendId);
/**
* @brief Returns Halide backend node.
* @param[in] inputs Input Halide buffers.
* @see BackendNode, BackendWrapper
*
* Input buffers should be exactly the same that will be used in forward invocations.
* Despite we can use Halide::ImageParam based on input shape only,
* it helps prevent some memory management issues (if something wrong,
* Halide tests will be failed).
*/
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs);
/**
* @brief Automatic Halide scheduling based on layer hyper-parameters.
* @param[in] node Backend node with Halide functions.
* @param[in] inputs Blobs that will be used in forward invocations.
* @param[in] outputs Blobs that will be used in forward invocations.
* @see BackendNode
*
* Layer don't use own Halide::Func members because we can have applied
* layers fusing. In this way the fused function should be scheduled.
*/
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const;
/**
* @brief Implement layers fusing.
* @param[in] node Backend node of bottom layer.
* @see BackendNode
*
* Actual for graph-based backends. If layer attached successfully,
* returns non-empty cv::Ptr to node of the same backend.
* Fuse only over the last function.
*/
virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node);
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
......@@ -251,6 +360,24 @@ namespace dnn //! This namespace is used for dnn module functionlaity.
/** @overload */
void forwardOpt(const std::vector<LayerId> &toLayers);
/**
* @brief Compile Halide layers.
* @param[in] scheduler Path to YAML file with scheduling directives.
* @see setPreferableBackend
*
* Schedule layers that support Halide backend. Then compile them for
* specific target. For layers that not represented in scheduling file
* or if no manual scheduling used at all, automatic scheduling will be applied.
*/
void compileHalide(const std::string& scheduler = "");
/**
* @brief Ask network to use specific computation backend where it supported.
* @param[in] backendId backend identifier.
* @see Backend
*/
void setPreferableBackend(int backendId);
/** @brief Sets the new value for the layer output blob
* @param outputName descriptor of the updating layer output blob.
* @param blob new blob.
......
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
namespace cvtest
{
#ifdef HAVE_HALIDE
using namespace cv;
using namespace dnn;
static void loadNet(const std::string& weights, const std::string& proto,
const std::string& scheduler, int inWidth, int inHeight,
const std::string& outputLayer, const std::string& framework,
int targetId, Net* net, int* outputLayerId)
{
Mat input(inHeight, inWidth, CV_32FC3);
randu(input, 0.0f, 1.0f);
if (framework == "caffe")
{
*net = cv::dnn::readNetFromCaffe(proto, weights);
}
else if (framework == "torch")
{
*net = cv::dnn::readNetFromTorch(weights);
}
else if (framework == "tensorflow")
{
*net = cv::dnn::readNetFromTensorflow(weights);
}
else
CV_Error(Error::StsNotImplemented, "Unknown framework " + framework);
net->setBlob("", cv::dnn::blobFromImage(input, 1.0, false));
net->setPreferableBackend(DNN_BACKEND_HALIDE);
net->compileHalide(scheduler);
*outputLayerId = net->getLayerId(outputLayer);
net->forward(*outputLayerId);
}
PERF_TEST(GoogLeNet, HalidePerfTest)
{
Net net;
int outputLayerId;
loadNet(findDataFile("dnn/bvlc_googlenet.caffemodel"),
findDataFile("dnn/bvlc_googlenet.prototxt"),
"", 227, 227, "prob", "caffe", DNN_TARGET_CPU, &net, &outputLayerId);
TEST_CYCLE_N(10)
{
net.forward(outputLayerId);
}
SANITY_CHECK_NOTHING();
}
PERF_TEST(AlexNet, HalidePerfTest)
{
Net net;
int outputLayerId;
loadNet(findDataFile("dnn/bvlc_alexnet.caffemodel"),
findDataFile("dnn/bvlc_alexnet.prototxt"),
findDataFile("dnn/halide_scheduler_alexnet.yml"),
227, 227, "prob", "caffe", DNN_TARGET_CPU, &net, &outputLayerId);
TEST_CYCLE_N(10)
{
net.forward(outputLayerId);
}
SANITY_CHECK_NOTHING();
}
// PERF_TEST(ResNet50, HalidePerfTest)
// {
// Net net;
// int outputLayerId;
// loadNet(findDataFile("dnn/ResNet-50-model.caffemodel"),
// findDataFile("dnn/ResNet-50-deploy.prototxt"),
// findDataFile("dnn/halide_scheduler_resnet_50.yml"),
// 224, 224, "prob", "caffe", DNN_TARGET_CPU, &net, &outputLayerId);
//
// TEST_CYCLE_N(10)
// {
// net.forward(outputLayerId);
// }
// SANITY_CHECK_NOTHING();
// }
// PERF_TEST(SqueezeNet_v1_1, HalidePerfTest)
// {
// Net net;
// int outputLayerId;
// loadNet(findDataFile("dnn/squeezenet_v1_1.caffemodel"),
// findDataFile("dnn/squeezenet_v1_1.prototxt"),
// findDataFile("dnn/halide_scheduler_squeezenet_v1_1.yml"),
// 227, 227, "prob", "caffe", DNN_TARGET_CPU, &net, &outputLayerId);
//
// TEST_CYCLE_N(10)
// {
// net.forward(outputLayerId);
// }
// SANITY_CHECK_NOTHING();
// }
PERF_TEST(Inception_5h, HalidePerfTest)
{
Net net;
int outputLayerId;
loadNet(findDataFile("dnn/tensorflow_inception_graph.pb"), "",
findDataFile("dnn/halide_scheduler_inception_5h.yml"),
224, 224, "softmax2", "tensorflow", DNN_TARGET_CPU,
&net, &outputLayerId);
TEST_CYCLE_N(10)
{
net.forward(outputLayerId);
}
SANITY_CHECK_NOTHING();
}
PERF_TEST(ENet, HalidePerfTest)
{
Net net;
int outputLayerId;
loadNet(findDataFile("dnn/Enet-model-best.net"), "",
findDataFile("dnn/halide_scheduler_enet.yml"),
512, 256, "l367_Deconvolution", "torch", DNN_TARGET_CPU,
&net, &outputLayerId);
TEST_CYCLE_N(10)
{
net.forward(outputLayerId);
}
SANITY_CHECK_NOTHING();
}
#endif // HAVE_HALIDE
} // namespace cvtest
#include "perf_precomp.hpp"
CV_PERF_TEST_MAIN(dnn)
static const char* extraTestDataPath =
#ifdef WINRT
NULL;
#else
getenv("OPENCV_DNN_TEST_DATA_PATH");
#endif
CV_PERF_TEST_MAIN(dnn,
extraTestDataPath ? (void)cvtest::addDataSearchPath(extraTestDataPath) : (void)0
)
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
// Sample of using Halide backend in OpenCV deep learning module.
// Based on dnn/samples/caffe_googlenet.cpp.
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;
#include <fstream>
#include <iostream>
#include <cstdlib>
/* Find best class for the blob (i. e. class with maximal probability) */
void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
{
Mat probMat = probBlob.reshape(1, 1); //reshape the blob to 1x1000 matrix
Point classNumber;
minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
*classId = classNumber.x;
}
std::vector<std::string> readClassNames(const char *filename = "synset_words.txt")
{
std::vector<std::string> classNames;
std::ifstream fp(filename);
if (!fp.is_open())
{
std::cerr << "File with classes labels not found: " << filename << std::endl;
exit(-1);
}
std::string name;
while (!fp.eof())
{
std::getline(fp, name);
if (name.length())
classNames.push_back( name.substr(name.find(' ')+1) );
}
fp.close();
return classNames;
}
int main(int argc, char **argv)
{
initModule(); // Required if OpenCV is built as static libs.
std::string modelTxt = "train_val.prototxt";
std::string modelBin = "squeezenet_v1.1.caffemodel";
std::string imageFile = (argc > 1) ? argv[1] : "space_shuttle.jpg";
//! [Read and initialize network]
Net net = dnn::readNetFromCaffe(modelTxt, modelBin);
//! [Read and initialize network]
//! [Check that network was read successfully]
if (net.empty())
{
std::cerr << "Can't load network by using the following files: " << std::endl;
std::cerr << "prototxt: " << modelTxt << std::endl;
std::cerr << "caffemodel: " << modelBin << std::endl;
std::cerr << "SqueezeNet v1.1 can be downloaded from:" << std::endl;
std::cerr << "https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1" << std::endl;
exit(-1);
}
//! [Check that network was read successfully]
//! [Prepare blob]
Mat img = imread(imageFile);
if (img.empty())
{
std::cerr << "Can't read image from the file: " << imageFile << std::endl;
exit(-1);
}
if (img.channels() != 3)
{
std::cerr << "Image " << imageFile << " isn't 3-channel" << std::endl;
exit(-1);
}
resize(img, img, Size(227, 227)); // SqueezeNet v1.1 predict class by 3x227x227 input image.
Mat inputBlob = blobFromImage(img, 1.0, false); // Convert Mat to 4-dimensional batch.
//! [Prepare blob]
//! [Set input blob]
net.setBlob("", inputBlob); // Set the network input.
//! [Set input blob]
//! [Enable Halide backend]
net.setPreferableBackend(DNN_BACKEND_HALIDE); // Tell engine to use Halide where it possible.
//! [Enable Halide backend]
//! [Compile Halide pipeline]
net.compileHalide(); // Compile Halide pipeline.
//! [Compile Halide pipeline]
//! [Make forward pass]
net.forward(); // Compute output.
//! [Make forward pass]
//! [Gather output]
Mat prob = net.getBlob("prob"); // Gather output of "prob" layer.
int classId;
double classProb;
getMaxClass(prob, &classId, &classProb); // Find the best class.
//! [Gather output]
//! [Print results]
std::vector<std::string> classNames = readClassNames();
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
//! [Print results]
return 0;
} //main
This diff is collapsed.
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "halide_scheduler.hpp"
#include "op_halide.hpp"
namespace cv
{
namespace dnn
{
#ifdef HAVE_HALIDE
static void applySplit(const FileNode& directive, Halide::Func& func,
const FileNode& params)
{
for (const auto& varNode : directive)
{
const std::string varName = varNode.name();
const std::string factorName = (std::string)varNode;
Halide::Var var(varName);
Halide::Var outerVar(varName + "o");
Halide::Var innerVar(varName + "i");
// If split factor is integer or parameters map has parameter value.
CV_Assert(varNode.isString() && !params[factorName].empty() ||
varNode.isInt());
int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
func.split(var, outerVar, innerVar, factor);
}
}
static void applyReorder(const FileNode& directive, Halide::Func& func)
{
std::string varName;
const int numVars = directive.size();
std::vector<Halide::VarOrRVar> reorderedVars;
reorderedVars.reserve(numVars);
for (int i = 0; i < numVars; ++i)
{
directive[i] >> varName;
reorderedVars.push_back(Halide::Var(varName));
}
func.reorder(reorderedVars);
}
static void applyFuse(const FileNode& directive, Halide::Func& func)
{
CV_Assert(directive["src"].size() >= 2);
CV_Assert(directive["dst"].size() == 1);
std::string str;
directive["src"][0] >> str;
Halide::Var firstVar(str);
directive["src"][1] >> str;
Halide::Var secondVar(str);
directive["dst"] >> str;
Halide::Var dstVar(str);
func.fuse(firstVar, secondVar, dstVar);
for (int i = 2, n = directive["src"].size(); i < n; ++i)
{
directive["src"][i] >> str;
func.fuse(Halide::Var(str), dstVar, dstVar);
}
}
static void applyParallel(const FileNode& directive, Halide::Func& func)
{
std::string varName;
for (int i = 0, n = directive.size(); i < n; ++i)
{
directive[i] >> varName;
func.parallel(Halide::Var(varName));
}
}
static void applyUnroll(const FileNode& directive, Halide::Func& func)
{
std::string varName;
for (int i = 0, n = directive.size(); i < n; ++i)
{
directive[i] >> varName;
func.unroll(Halide::Var(varName));
}
}
static void applyVectorize(const FileNode& directive, Halide::Func& func,
const FileNode& params)
{
for (const auto& varNode : directive)
{
const std::string varName = varNode.name();
const std::string factorName = (std::string)varNode;
// If split factor is integer or parameters map has parameter value.
CV_Assert(varNode.isString() && !params[factorName].empty() ||
varNode.isInt());
int factor = (int)(varNode.isInt() ? varNode : params[factorName]);
Halide::Var var(varName);
Halide::Var inner(varName + "v");
func.split(var, var, inner, factor);
func.vectorize(inner);
}
}
static void applyStoreAt(const FileNode& directive, Halide::Func& func,
std::map<std::string, Halide::Func>& funcsMap)
{
for (const auto& funcNode : directive)
{
const std::string targetFuncName = funcNode.name();
if (funcsMap.find(targetFuncName) == funcsMap.end())
CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
" is not represented in Halide pipeline");
Halide::Func targetFunc = funcsMap[targetFuncName];
func.store_at(targetFunc, (std::string)funcNode);
break;
}
}
static void applyComputeAt(const FileNode& directive, Halide::Func& func,
std::map<std::string, Halide::Func>& funcsMap)
{
for (const auto& funcNode : directive)
{
const std::string targetFuncName = funcNode.name();
if (funcsMap.find(targetFuncName) == funcsMap.end())
CV_Error(cv::Error::StsParseError, "Function " + targetFuncName +
" is not represented in Halide pipeline");
Halide::Func targetFunc = funcsMap[targetFuncName];
func.compute_at(targetFunc, (std::string)funcNode);
break;
}
}
static void applyComputeRoot(const FileNode& directive, Halide::Func& func)
{
bool compute_root;
directive >> compute_root;
if (compute_root)
func.compute_root();
}
static void apply(const FileNode& directives, Halide::Func& func,
std::map<std::string, Halide::Func>& funcsMap,
const FileNode& params)
{
for (const auto& directive : directives)
{
if (directive.name() == "split")
applySplit(directive, func, params);
else if (directive.name() == "reorder")
applyReorder(directive, func);
else if (directive.name() == "fuse")
applyFuse(directive, func);
else if (directive.name() == "parallel")
applyParallel(directive, func);
else if (directive.name() == "unroll")
applyUnroll(directive, func);
else if (directive.name() == "vectorize")
applyVectorize(directive, func, params);
else if (directive.name() == "store_at")
applyStoreAt(directive, func, funcsMap);
else if (directive.name() == "compute_at")
applyComputeAt(directive, func, funcsMap);
else if (directive.name() == "compute_root")
applyComputeRoot(directive, func);
else
CV_Error(Error::StsNotImplemented, "Scheduling directive " +
directive.name() + " is not implemented.");
}
}
// Remove any numeric symbols after '$' sign.
static std::string Deunique(std::string str)
{
int pos = -1;
do
{
pos = str.find('$');
if (pos != -1)
{
int len = str.find_first_not_of("0123456789", pos + 1) - pos;
str = str.replace(pos, len, "");
}
}
while (pos != -1);
return str;
}
#endif // HAVE_HALIDE
HalideScheduler::HalideScheduler(const std::string& configFile)
{
if (!configFile.empty())
fs = FileStorage(configFile, FileStorage::READ);
}
HalideScheduler::~HalideScheduler()
{
if (fs.isOpened())
fs.release();
}
bool HalideScheduler::process(Ptr<BackendNode>& node)
{
#ifdef HAVE_HALIDE
if (!fs.isOpened())
return false;
const FileNode& scheduleNode = fs["scheduling"];
if (scheduleNode.empty())
CV_Error(cv::Error::StsParseError, "Scheduling file should has scheduling node");
std::string str;
std::map<std::string, Halide::Func> funcsMap; // Scheduled functions.
// For every function, from top to bottom, we try to find a scheduling node.
// Scheduling is successful (return true) if for the first function (top)
// node is respresented.
CV_Assert(!node.empty());
std::vector<Halide::Func>& funcs = node.dynamicCast<HalideBackendNode>()->funcs;
for (int i = funcs.size() - 1; i >= 0; --i)
{
Halide::Func& func = funcs[i];
// For functions with the same name Halide generates unique names
// for example func, func$1, func$2.
// They are always formed with '$' and number.
std::string funcName = Deunique(func.name());
const FileNode& funcNode = scheduleNode[funcName];
if (!funcNode.empty())
{
if (!funcNode["pattern"].empty())
{
funcNode["pattern"] >> str;
if (fs["patterns"][str].empty())
CV_Error(cv::Error::StsParseError, "Scheduling pattern " + str +
" is not defined");
apply(fs["patterns"][str], func, funcsMap, funcNode["params"]);
}
else
{
apply(funcNode, func, funcsMap, funcNode["params"]);
}
}
else
{
if (funcsMap.empty())
return false;
}
funcsMap[funcName] = func;
}
return true;
#endif // HAVE_HALIDE
return false;
}
} // namespace dnn
} // namespace cv
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_HALIDE_SCHEDULER_HPP__
#define __OPENCV_DNN_HALIDE_SCHEDULER_HPP__
#include <opencv2/dnn.hpp>
namespace cv
{
namespace dnn
{
class HalideScheduler
{
public:
HalideScheduler(const std::string& configFile);
~HalideScheduler();
// Returns true if pipeline found in scheduling file.
// If more than one function, returns true if the top function scheduled.
// Other functions are optional to scheduling.
bool process(Ptr<BackendNode>& node);
private:
FileStorage fs;
};
} // namespace dnn
} // namespace cv
#endif // __OPENCV_DNN_HALIDE_SCHEDULER_HPP__
......@@ -10,6 +10,7 @@ Implementation of Batch Normalization layer.
*/
#include "../precomp.hpp"
#include "op_halide.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv
......@@ -39,6 +40,12 @@ public:
return true;
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
CV_Assert(blobs.size() >= 2);
......@@ -88,6 +95,73 @@ public:
}
}
virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node)
{
switch (node->backendId)
{
case DNN_BACKEND_HALIDE:
{
#ifdef HAVE_HALIDE
auto base = node.dynamicCast<HalideBackendNode>();
Halide::Func& input = base->funcs.back();
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = attachHalide(input(x, y, c, n));
return Ptr<BackendNode>(new HalideBackendNode(base, top));
#endif // HAVE_HALIDE
break;
}
}
return Ptr<BackendNode>();
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> input = halideBuffer(inputs[0]);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = attachHalide(input(x, y, c, n));
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
#ifdef HAVE_HALIDE
// attachHalide can work both with Halide::Buffer and Halide::Func. In the
// second case it will be a fusion.
Halide::Func attachHalide(const Halide::Expr& input)
{
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Var x("x"), y("y"), c("c"), n("n");
const int weightsBlobIndex = 2;
const int biasBlobIndex = weightsBlobIndex + hasWeights;
const int numChannels = blobs[0].total();
float* meanData = (float*)blobs[0].data;
float* stdData = (float*)blobs[1].data;
float* weightsData = (hasWeights ? (float*)blobs[weightsBlobIndex].data : NULL);
float* biasData = (hasBias ? (float*)blobs[biasBlobIndex].data : NULL);
float varMeanScale = 1.f;
if (!hasWeights && !hasBias) {
varMeanScale = *blobs[2].ptr<float>();
if (varMeanScale != 0)
varMeanScale = 1/varMeanScale;
}
Halide::Buffer<float> weights(numChannels);
Halide::Buffer<float> bias(numChannels);
for (int i = 0; i < numChannels; ++i)
{
weights(i) = (hasWeights ? weightsData[i] : 1.0f) /
sqrt(stdData[i] * varMeanScale + epsilon);
bias(i) = (hasBias ? biasData[i] : 0.0f) -
weights(i) * meanData[i] * varMeanScale;
}
top(x, y, c, n) = input * weights(c) + bias(c);
return top;
}
#endif // HAVE_HALIDE
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......
......@@ -41,6 +41,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
namespace cv
{
......@@ -86,6 +87,12 @@ public:
return false;
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1; // By channels
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
int cAxis = clamp(axis, inputs[0]->dims);
......@@ -100,6 +107,52 @@ public:
ranges[cAxis].start = ranges[cAxis].end;
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &input)
{
#ifdef HAVE_HALIDE
std::vector<Halide::Buffer<> > inputBuffers = halideBuffers(input);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
int offset = inputBuffers[0].channels();
Halide::Expr topExpr = select(c < offset,
inputBuffers[0](x, y, c, n),
inputBuffers[1](x, y, c - offset, n));
for (int i = 2; i < input.size(); ++i)
{
offset += inputBuffers[i - 1].channels();
topExpr = select(c < offset, topExpr,
inputBuffers[i](x, y, c - offset, n));
}
top(x, y, c, n) = topExpr;
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"), yi("yi"), yo("yo");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
if (outW == 1 || outH <= 2)
return;
top.reorder(x, c, y)
.split(y, yo, yi, 2)
.fuse(yo, n, tile)
.parallel(tile)
.unroll(yi)
.vectorize(x, outW >= 16 ? 16 : outW);
#endif // HAVE_HALIDE
}
};
Ptr<ConcatLayer> ConcatLayer::create(const LayerParams& params)
......
......@@ -43,6 +43,7 @@
#include "layers_common.hpp"
#include "op_im2col.hpp"
#include "op_blas.hpp"
#include "op_halide.hpp"
#include "opencv2/core/hal/intrin.hpp"
#include <iostream>
......@@ -64,6 +65,13 @@ public:
}
#endif
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
void finalize(const std::vector<Mat*> &inputs, std::vector<Mat> &outputs)
{
CV_Assert(inputs.size() > 0);
......@@ -98,6 +106,40 @@ public:
(dilation.height == 1 && dilation.width == 1);
}
bool setActivation(const Ptr<ActivationLayer>& ) { return false; }
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"), yi("yi"), yo("yo"), co("co"), ci("ci");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs[1];
Halide::Func& padded_input = node.dynamicCast<HalideBackendNode>()->funcs[0];
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
if (outW == 1 || outH <= 2)
return;
if (is1x1() || outC <= 16)
top.reorder(x, c, y)
.split(y, yo, yi, 2)
.fuse(yo, n, tile)
.parallel(tile)
.unroll(yi)
.vectorize(x, outW >= 16 ? 16 : outW);
else
top.reorder(x, c, y)
.split(y, yo, yi, 2)
.split(c, co, ci, 16)
.fuse(yo, co, tile).fuse(n, tile, tile)
.parallel(tile)
.unroll(yi)
.vectorize(x, outW >= 16 ? 16 : outW);
padded_input.compute_at(top, yi);
#endif // HAVE_HALIDE
}
};
//TODO: simultaneously convolution and bias addition for cache optimization
......@@ -155,6 +197,66 @@ public:
bool setActivation(const Ptr<ActivationLayer>& layer) { activ = layer; return true; }
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
const int inpCn = inputBuffer.channels();
const int outCn = blobs[0].size[0];
const int inpGroupCn = blobs[0].size[1];
const int group = inpCn / inpGroupCn;
const int outGroupCn = outCn / group;
Halide::Buffer<float> weights = wrapToHalideBuffer(blobs[0]);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Func padded_input(name + "_constant_exterior");
if (pad.width || pad.height)
{
Halide::Func bounded =
Halide::BoundaryConditions::constant_exterior(inputBuffer, 0);
padded_input(x, y, c, n) = bounded(x, y, c, n);
}
else
{
padded_input(x, y, c, n) = inputBuffer(x, y, c, n);
}
Halide::RDom r(0, kernel.width, 0, kernel.height, 0, inpGroupCn);
Halide::Expr kc = r.z;
if (group > 1)
{
int outCnBound = outGroupCn;
int inpChBound = inpGroupCn;
Halide::Expr shift = select(c < outCnBound, 0, inpChBound);
for (int i = 2; i < group; ++i)
{
outCnBound += outGroupCn;
inpChBound += inpGroupCn;
shift = select(c < outCnBound, shift, inpChBound);
}
kc += shift;
}
Halide::Expr kx = x * stride.width - pad.width + r.x * dilation.width;
Halide::Expr ky = y * stride.height - pad.height + r.y * dilation.height;
Halide::Expr topExpr = sum(padded_input(kx, ky, kc, n) *
weights(r.x, r.y, r.z, c));
if (hasBias())
{
Halide::Buffer<float> bias = wrapToHalideBuffer(blobs[1], {outCn});
topExpr += bias(c);
}
top(x, y, c, n) = topExpr;
Ptr<BackendNode> pp(new HalideBackendNode({ padded_input, top }));
return Ptr<BackendNode>(new HalideBackendNode({ padded_input, top }));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
class ParallelConv : public cv::ParallelLoopBody
{
public:
......@@ -644,6 +746,53 @@ public:
dilation.height, dilation.width, dstImg.ptr<float>(), &ofsbuf[0]);
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
int inW, inH, inC, inN, outC = blobs[0].size[0];
getCanonicalSize(inputBuffer, &inW, &inH, &inC, &inN);
if (inC / blobs[0].size[1] != 1)
CV_Error(cv::Error::StsNotImplemented,
"Halide backend for Deconvolution with group > 1 is not implemented");
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Func padded_input(name + "_constant_exterior");
auto weights = wrapToHalideBuffer(blobs[0], {kernel.width,
kernel.height, outC, inC});
Halide::Func dilated_input("dilated_input");
dilated_input(x, y, c, n) = 0.0f;
Halide::RDom r1(0, inW, 0, inH);
dilated_input(r1.x * stride.width, r1.y * stride.height, c, n) =
inputBuffer(r1.x, r1.y, c, n);
dilated_input.compute_root();
Halide::Func bounded =
Halide::BoundaryConditions::constant_exterior(dilated_input, 0,
0, (inW - 1) * stride.width + 1,
0, (inH - 1) * stride.height + 1,
0, inC, 0, inN);
padded_input(x, y, c, n) = bounded(x, y, c, n);
Halide::RDom r(0, kernel.width, 0, kernel.height, 0, inC);
Halide::Expr topExpr = sum(
padded_input(x + pad.width - r.x, y + pad.height - r.y, r.z, n) *
weights(r.x, r.y, c, r.z));
if (hasBias())
{
auto bias = wrapToHalideBuffer(blobs[1], {outC});
topExpr += bias(c);
}
top(x, y, c, n) = topExpr;
return Ptr<BackendNode>(new HalideBackendNode({ padded_input, top }));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......@@ -680,6 +829,8 @@ static void initConvDeconvLayerFromCaffe(Ptr<BaseConvolutionLayer> l, const Laye
CV_Assert(numOutput % ngroups == 0);
CV_Assert((bias && l->blobs.size() == 2) || (!bias && l->blobs.size() == 1));
CV_Assert(l->adjustPad.width < l->stride.width &&
l->adjustPad.height < l->stride.height);
}
Ptr<BaseConvolutionLayer> ConvolutionLayer::create(const LayerParams &params)
......
#include "../precomp.hpp"
#include "op_halide.hpp"
#include "opencv2/imgproc.hpp"
#include <opencv2/dnn/shape_utils.hpp>
......@@ -64,6 +65,44 @@ public:
ElementWiseLayer(const Func &f=Func()) { func = f; }
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node)
{
switch (node->backendId)
{
case DNN_BACKEND_HALIDE:
{
#ifdef HAVE_HALIDE
auto base = node.dynamicCast<HalideBackendNode>();
Halide::Func& input = base->funcs.back();
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (this->name.empty() ? Halide::Func() : Halide::Func(this->name));
func.attachHalide(input(x, y, c, n), top);
return Ptr<BackendNode>(new HalideBackendNode(base, top));
#endif // HAVE_HALIDE
break;
}
}
return Ptr<BackendNode>();
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> input = halideBuffer(inputs[0]);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (this->name.empty() ? Halide::Func() : Halide::Func(this->name));
func.attachHalide(input(x, y, c, n), top);
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
......@@ -147,6 +186,21 @@ struct ReLUFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
if (slope)
{
top(x, y, c, n) = select(input >= 0.0f, input, slope);
}
else
{
top(x, y, c, n) = max(input, 0.0f);
}
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 1; }
};
......@@ -166,6 +220,14 @@ struct TanHFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
top(x, y, c, n) = tanh(input);
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 1; }
};
......@@ -185,6 +247,14 @@ struct SigmoidFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
top(x, y, c, n) = 1.0f / (1.0f + exp(-input));
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 3; }
};
......@@ -204,6 +274,14 @@ struct AbsValFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
top(x, y, c, n) = abs(input);
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 1; }
};
......@@ -223,6 +301,14 @@ struct BNLLFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
top(x, y, c, n) = log(1.0f + exp(-abs(input)));
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 5; }
};
......@@ -264,6 +350,23 @@ struct PowerFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Expr topExpr = (scale == 1.0f ? input : input * scale);
if (shift)
{
topExpr += shift;
}
if (power != 1.0f)
{
topExpr = pow(topExpr, power);
}
top(x, y, c, n) = topExpr;
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return power == 1 ? 2 : 10; }
};
......@@ -314,6 +417,15 @@ struct ChannelsPReLUFunctor
}
}
#ifdef HAVE_HALIDE
void attachHalide(const Halide::Expr& input, Halide::Func& top)
{
Halide::Var x("x"), y("y"), c("c"), n("n");
auto weights = wrapToHalideBuffer(scale, {(int)scale.total()});
top(x, y, c, n) = select(input > 0.0f, input, weights(c) * input);
}
#endif // HAVE_HALIDE
int64 getFLOPSPerElement() const { return 1; }
};
......
......@@ -41,6 +41,8 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
namespace cv
{
namespace dnn
......@@ -81,6 +83,12 @@ public:
}
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
......@@ -144,6 +152,75 @@ public:
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &input)
{
#ifdef HAVE_HALIDE
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Expr topExpr;
std::vector<Halide::Buffer<> > inputBuffers = halideBuffers(input);
switch (op)
{
case SUM:
if (coeffs.empty())
{
topExpr = inputBuffers[0](x, y, c, n) +
inputBuffers[1](x, y, c, n);
for (int i = 2; i < inputBuffers.size(); ++i)
topExpr += inputBuffers[i](x, y, c, n);
}
else
{
topExpr = coeffs[0] * inputBuffers[0](x, y, c, n) +
coeffs[1] * inputBuffers[1](x, y, c, n);
for (int i = 2; i < inputBuffers.size(); ++i)
topExpr += coeffs[i] * inputBuffers[i](x, y, c, n);
}
break;
case PROD:
topExpr = inputBuffers[0](x, y, c, n) *
inputBuffers[1](x, y, c, n);
for (int i = 2; i < inputBuffers.size(); ++i)
topExpr *= inputBuffers[i](x, y, c, n);
break;
case MAX:
topExpr = max(inputBuffers[0](x, y, c, n),
inputBuffers[1](x, y, c, n));
for (int i = 2; i < inputBuffers.size(); ++i)
topExpr = max(topExpr, inputBuffers[i](x, y, c, n));
break;
default:
return Ptr<BackendNode>();
}
top(x, y, c, n) = topExpr;
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"), yi("yi"), yo("yo");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
if (outW == 1 || outH <= 2)
return;
top.reorder(x, c, y)
.split(y, yo, yi, 2)
.fuse(yo, n, tile)
.parallel(tile)
.unroll(yi)
.vectorize(x, outW >= 16 ? 16 : outW);
#endif // HAVE_HALIDE
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......
......@@ -42,6 +42,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_blas.hpp"
#include "op_halide.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv
......@@ -104,6 +105,12 @@ public:
return false;
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1;
}
class FullConnected : public ParallelLoopBody
{
public:
......@@ -213,6 +220,55 @@ public:
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
int inW, inH, inC, inN, outC = blobs[0].size[0];
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
getCanonicalSize(inputBuffer, &inW, &inH, &inC, &inN);
auto weights = wrapToHalideBuffer(blobs[0], {inW, inH, inC, outC});
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::RDom r(0, inW, 0, inH, 0, inC);
Halide::Expr topExpr = sum(inputBuffer(r.x, r.y, r.z, n) *
weights(r.x, r.y, r.z, c));
if (bias)
{
Halide::Buffer<float> bias = wrapToHalideBuffer(blobs[1], {outC});
topExpr += bias(c);
}
top(x, y, c, n) = topExpr;
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
Halide::Var x("x"), y("y"), c("c"), n("n"), co("co"), ci("ci"), tile("tile");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
if (outC + outN == 1)
return;
if (outC > 8)
top.split(c, co, ci, 8)
.fuse(x, y, tile).fuse(co, tile, tile).fuse(n, tile, tile)
.parallel(tile)
.vectorize(ci, 8);
else
top.fuse(x, y, tile).fuse(c, tile, tile).fuse(n, tile, tile)
.parallel(tile);
#endif // HAVE_HALIDE
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......
......@@ -41,6 +41,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/dnn/shape_utils.hpp"
#include "opencv2/core/hal/hal.hpp"
......@@ -76,6 +77,12 @@ public:
normBySize = params.get<bool>("norm_by_size", true);
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
CV_Assert(inputs.size() == outputs.size());
......@@ -222,6 +229,73 @@ public:
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
float alphaSize = alpha;
if (normBySize)
alphaSize /= (type == CHANNEL_NRM ? size : size * size);
int width, height, channels, numImgs;
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
getCanonicalSize(inputBuffer, &width, &height, &channels, &numImgs);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Func padded_sq(name + "_padded_sq");
Halide::Func sq("sq");
sq(x, y, c, n) = inputBuffer(x, y, c, n) * inputBuffer(x, y, c, n);
Halide::Func bounded =
Halide::BoundaryConditions::constant_exterior(sq, 0, 0, width,
0, height,
0, channels,
0, numImgs);
padded_sq(x, y, c, n) = bounded(x, y, c, n);
Halide::Expr base;
if (type == CHANNEL_NRM)
{
Halide::RDom r((1 - size) / 2, size);
base = alphaSize * sum(padded_sq(x, y, c + r, n));
}
else // SPATIAL_NRM
{
Halide::RDom r((1 - size) / 2, size, (1 - size) / 2, size);
base = alphaSize * sum(padded_sq(x + r.x, y + r.y, c, n));
}
base += static_cast<float>(bias);
top(x, y, c, n) = inputBuffer(x, y, c, n) / pow(base, beta);
return Ptr<BackendNode>(new HalideBackendNode({ padded_sq, top }));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
Halide::Var x("x"), y("y"), c("c"), n("n"), yo("yo"), yi("yi"), tile("tile");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs[1];
Halide::Func& padded_sq = node.dynamicCast<HalideBackendNode>()->funcs[0];
if (outW < 8 || outH <= 2)
return;
top.reorder(x, c, y, n)
.split(y, yo, yi, 2)
.fuse(yo, n, tile)
.parallel(tile)
.unroll(yi)
.vectorize(x, 8);
padded_sq.store_at(top, tile)
.compute_at(top, yi);
#endif // HAVE_HALIDE
}
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......
......@@ -11,6 +11,7 @@ Implementation of Batch Normalization layer.
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv
......@@ -29,6 +30,13 @@ public:
poolStride = Size(params.get<int>("pool_stride_w"), params.get<int>("pool_stride_h"));
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide() &&
!poolPad.width && !poolPad.height;
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
......@@ -81,6 +89,54 @@ public:
}
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &input)
{
#ifdef HAVE_HALIDE
// Meaningless operation if false because if kernel > stride
// it is not deterministic and if kernel < stride we just
// skip a part of input data (you'd better change your model).
if (poolKernel.width != poolStride.width ||
poolKernel.height != poolStride.height)
CV_Error(cv::Error::StsNotImplemented,
"Halide backend for maximum unpooling "
"is not support cases when kernel != stride");
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Buffer<float> inputBuffer = halideBuffer(input[0]);
Halide::Buffer<float> indices = halideBuffer(input[1]);
Halide::Expr pooledX = x / poolKernel.width;
Halide::Expr pooledY = y / poolKernel.height;
const int outW = inputBuffer.width() * poolKernel.width;
top(x, y, c, n) = select(y * outW + x == indices(pooledX, pooledY, c, n),
inputBuffer(pooledX, pooledY, c, n), 0.0f);
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"), yi("yi"), yo("yo");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
top.reorder(x, c, y)
.split(y, yo, yi, 2)
.fuse(yo, n, tile)
.parallel(tile)
.unroll(yi)
.vectorize(x, outW >= 16 ? 16 : outW);
#endif // HAVE_HALIDE
}
};
Ptr<MaxUnpoolLayer> MaxUnpoolLayer::create(const LayerParams& params)
......
......@@ -41,6 +41,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
#include <float.h>
#include <algorithm>
using std::max;
......@@ -92,6 +93,14 @@ public:
getConvPoolPaddings(inp, out, kernel, stride, padMode, pad);
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide() &&
(type == PoolingLayer::MAX ||
type == PoolingLayer::AVE && !pad.width && !pad.height);
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
for (size_t ii = 0; ii < inputs.size(); ii++)
......@@ -111,6 +120,16 @@ public:
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
if (type == PoolingLayer::MAX)
return initMaxPoolingHalide(inputs);
else if (type == PoolingLayer::AVE)
return initAvePoolingHalide(inputs);
else
return Ptr<BackendNode>();
}
void maxPooling(Mat &src, Mat &dst, Mat &mask)
{
Size inp(src.size[3], src.size[2]),
......@@ -195,6 +214,120 @@ public:
}
}
virtual Ptr<BackendNode> initMaxPoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
const int inWidth = inputBuffer.width();
const int inHeight = inputBuffer.height();
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::RDom r(0, kernel.width, 0, kernel.height);
Halide::Expr kx, ky;
if (pad.width || pad.height)
{
kx = clamp(x * stride.width + r.x - pad.width, 0, inWidth - 1);
ky = clamp(y * stride.height + r.y - pad.height, 0, inHeight - 1);
}
else
{
kx = min(x * stride.width + r.x, inWidth - 1);
ky = min(y * stride.height + r.y, inHeight - 1);
}
// Halide::argmax returns tuple (r.x, r.y, max).
Halide::Tuple res = argmax(inputBuffer(kx, ky, c, n));
// Compute offset from argmax in range [0, kernel_size).
Halide::Expr max_index;
if (pad.width || pad.height)
{
max_index = clamp(y * stride.height + res[1] - pad.height,
0, inHeight - 1) * inWidth +
clamp(x * stride.width + res[0] - pad.width,
0, inWidth - 1);
}
else
{
max_index = min(y * stride.height + res[1], inHeight - 1) * inWidth +
min(x * stride.width + res[0], inWidth - 1);
}
top(x, y, c, n) = { res[2], Halide::cast<float>(max_index) };
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual Ptr<BackendNode> initAvePoolingHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
const int inW = inputBuffer.width(), inH = inputBuffer.height();
if ((inW - kernel.width) % stride.width || (inH - kernel.height) % stride.height)
{
CV_Error(cv::Error::StsNotImplemented,
"Halide backend for average pooling with partial "
"kernels is not implemented");
}
const float norm = 1.0f / (kernel.width * kernel.height);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::RDom r(0, kernel.width, 0, kernel.height);
top(x, y, c, n) = sum(
inputBuffer(x * stride.width + r.x,
y * stride.height + r.y, c, n)) * norm;
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
Halide::Var x("x"), y("y"), c("c"), n("n"), tile("tile"),
xi("xi"), yi("yi"), ci("ci"), xo("xo"), yo("yo"), co("co");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
if (outW < 8 || outH < 8)
{
if (outC > 8)
top.split(c, co, ci, 8)
.fuse(x, y, tile).fuse(co, tile, tile).fuse(n, tile, tile)
.parallel(tile)
.vectorize(ci);
else
{
top.fuse(y, c, tile).fuse(n, tile, tile)
.parallel(tile);
if (outW > 1)
top.vectorize(x);
}
}
else
{
if (outC > 8)
top.split(x, xo, xi, 8).split(y, yo, yi, 8).split(c, co, ci, 8)
.fuse(xo, yo, tile).fuse(co, tile, tile).fuse(n, tile, tile)
.parallel(tile)
.vectorize(xi);
else
top.split(x, xo, xi, 8).split(y, yo, yi, 8)
.fuse(xo, yo, tile).fuse(c, tile, tile).fuse(n, tile, tile)
.parallel(tile)
.vectorize(xi);
}
#endif // HAVE_HALIDE
}
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
......
......@@ -11,6 +11,7 @@ Implementation of Scale layer.
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv
......@@ -36,6 +37,12 @@ public:
return true;
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide();
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
CV_Assert(blobs.size() == 1 + hasBias);
......@@ -65,6 +72,58 @@ public:
}
}
virtual Ptr<BackendNode> tryAttach(const Ptr<BackendNode>& node)
{
switch (node->backendId)
{
case DNN_BACKEND_HALIDE:
{
#ifdef HAVE_HALIDE
auto base = node.dynamicCast<HalideBackendNode>();
Halide::Func& input = base->funcs.back();
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = attachHalide(input(x, y, c, n));
return Ptr<BackendNode>(new HalideBackendNode(base, top));
#endif // HAVE_HALIDE
break;
}
}
return Ptr<BackendNode>();
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> input = halideBuffer(inputs[0]);
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = attachHalide(input(x, y, c, n));
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
#ifdef HAVE_HALIDE
// attachHalide can work both with Halide::Buffer and Halide::Func. In the
// second case it will be a fusion.
Halide::Func attachHalide(const Halide::Expr& input)
{
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Var x("x"), y("y"), c("c"), n("n");
const int numChannels = blobs[0].total();
auto weights = wrapToHalideBuffer(blobs[0], {numChannels});
Halide::Expr topExpr = input * weights(c);
if (hasBias)
{
auto bias = wrapToHalideBuffer(blobs[1], {numChannels});
topExpr += bias(c);
}
top(x, y, c, n) = topExpr;
return top;
}
#endif // HAVE_HALIDE
virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......
......@@ -41,6 +41,7 @@
#include "../precomp.hpp"
#include "layers_common.hpp"
#include "op_halide.hpp"
#include <algorithm>
#include <stdlib.h>
using std::max;
......@@ -74,6 +75,12 @@ public:
return inplace;
}
virtual bool supportBackend(int backendId)
{
return backendId == DNN_BACKEND_DEFAULT ||
backendId == DNN_BACKEND_HALIDE && haveHalide() && axisRaw == 1;
}
void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals)
{
const Mat &src = *inputs[0];
......@@ -155,6 +162,58 @@ public:
}
}
virtual Ptr<BackendNode> initHalide(const std::vector<Ptr<BackendWrapper> > &inputs)
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
int inW, inH, inC, inN;
getCanonicalSize(inputBuffer, &inW, &inH, &inC, &inN);
if (inW != 1 || inH != 1)
CV_Error(cv::Error::StsNotImplemented,
"Halide backend for SoftMax with spatial size "
"more than 1x1 is not implemented");
Halide::Var x("x"), y("y"), c("c"), n("n");
Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
Halide::Func expInput("expInput");
Halide::RDom r(0, inW, 0, inH, 0, inC);
expInput(x, y, c, n) = exp(inputBuffer(x, y, c, n));
Halide::Expr globalSum = sum(expInput(r.x, r.y, r.z, n));
top(x, y, c, n) = expInput(x, y, c, n) / globalSum;
return Ptr<BackendNode>(new HalideBackendNode(top));
#endif // HAVE_HALIDE
return Ptr<BackendNode>();
}
virtual void applyHalideScheduler(Ptr<BackendNode>& node,
const std::vector<Mat*> &inputs,
const std::vector<Mat> &outputs) const
{
#ifdef HAVE_HALIDE
int outW, outH, outC, outN;
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
// Most common case when SoftMax is a layer after fully-connected.
// So we just schedule it in the same way.
Halide::Var x("x"), y("y"), c("c"), n("n"), co("co"), ci("ci"), tile("tile");
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
if (outC + outN == 1)
return;
if (outC > 8)
top.split(c, co, ci, 8)
.fuse(x, y, tile).fuse(co, tile, tile).fuse(n, tile, tile)
.parallel(tile)
.vectorize(ci, 8);
else
top.fuse(x, y, tile).fuse(c, tile, tile).fuse(n, tile, tile)
.parallel(tile);
#endif // HAVE_HALIDE
}
int64 getFLOPS(const std::vector<MatShape> &inputs,
const std::vector<MatShape> &outputs) const
{
......
......@@ -72,7 +72,7 @@ public:
{
CV_Assert(inputs.size() == 1);
Layer::getMemoryShapes(inputs, outputsCount >= 0 ? outputsCount : requiredOutputs,
Layer::getMemoryShapes(inputs, max(1, outputsCount >= 0 ? outputsCount : requiredOutputs),
outputs, internals);
return true;
}
......@@ -81,6 +81,7 @@ public:
{
for (size_t i = 0; i < outputs.size(); i++)
{
CV_Assert(inputs[0]->total() == outputs[i].total());
if (outputs[i].data != inputs[0]->data)
inputs[0]->copyTo(outputs[i]);
}
......
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#include "op_halide.hpp"
namespace cv
{
namespace dnn
{
#ifdef HAVE_HALIDE
Halide::Buffer<float> wrapToHalideBuffer(const Mat& mat)
{
int n, c, w, h;
getCanonicalSize(mat.size, &w, &h, &c, &n);
return wrapToHalideBuffer(mat, {w, h, c, n});
}
Halide::Buffer<float> wrapToHalideBuffer(const Mat& mat,
const std::vector<int>& sizes)
{
Halide::Buffer<float> buffer((float*)mat.data, sizes);
buffer.set_host_dirty(); // Indicate that data is on CPU.
return buffer;
}
Halide::Buffer<> halideBuffer(const Ptr<BackendWrapper>& ptr)
{
CV_Assert(!ptr.empty());
return ptr.dynamicCast<HalideBackendWrapper>()->buffer;
}
std::vector<Halide::Buffer<> > halideBuffers(const std::vector<Ptr<BackendWrapper> >& ptrs)
{
std::vector<Halide::Buffer<> > vec;
vec.reserve(ptrs.size());
for (const Ptr<BackendWrapper>& ptr : ptrs)
{
vec.push_back(halideBuffer(ptr));
}
return vec;
}
void getCanonicalSize(const Halide::Buffer<>& buffer, int* width, int* height,
int* channels, int* batch)
{
CV_Assert(buffer.dimensions() == 4);
*width = buffer.extent(0);
*height = buffer.extent(1);
*channels = buffer.extent(2);
*batch = buffer.extent(3);
}
HalideBackendNode::HalideBackendNode(const Halide::Func& func)
: BackendNode(DNN_BACKEND_HALIDE), funcs(1, func) {}
HalideBackendNode::HalideBackendNode(const std::vector<Halide::Func>& funcs)
: BackendNode(DNN_BACKEND_HALIDE), funcs(funcs) {}
HalideBackendNode::HalideBackendNode(const Ptr<HalideBackendNode>& base,
const Halide::Func& top)
: BackendNode(DNN_BACKEND_HALIDE), funcs(base->funcs)
{
funcs.back() = top;
}
HalideBackendWrapper::HalideBackendWrapper(int targetId, const cv::Mat& m)
: BackendWrapper(DNN_BACKEND_HALIDE, targetId)
{
buffer = wrapToHalideBuffer(m);
if (targetId != DNN_TARGET_CPU)
CV_Error(Error::StsNotImplemented, "Unknown target identifier");
}
HalideBackendWrapper::HalideBackendWrapper(const Ptr<BackendWrapper>& base,
const MatShape& shape)
: BackendWrapper(DNN_BACKEND_HALIDE, base->targetId)
{
if (base->targetId != DNN_TARGET_CPU)
CV_Error(Error::StsNotImplemented, "Unknown target identifier");
int w, h, c, n;
getCanonicalSize(shape, &w, &h, &c, &n);
Halide::Buffer<float> baseBuffer = halideBuffer(base);
buffer = Halide::Buffer<float>((float*)baseBuffer.raw_buffer()->host,
{w, h, c, n});
buffer.set_host_dirty(); // Indicate that data is on CPU.
}
#endif // HAVE_HALIDE
void getCanonicalSize(const MatSize& size, int* width, int* height,
int* channels, int* batch)
{
const int dims = size.p[-1];
CV_Assert(dims == 2 || dims == 4);
*batch = size[0];
*channels = size[1];
if (dims == 4)
{
*width = size[3];
*height = size[2];
}
else
{
*width = 1;
*height = 1;
}
}
void getCanonicalSize(const MatShape& shape, int* width, int* height,
int* channels, int* batch)
{
const int dims = shape.size();
CV_Assert(dims == 2 || dims == 4);
*batch = shape[0];
*channels = shape[1];
if (dims == 4)
{
*width = shape[3];
*height = shape[2];
}
else
{
*width = 1;
*height = 1;
}
}
void compileHalide(std::vector<Mat> &outputs, Ptr<BackendNode>& node, int targetId)
{
#ifdef HAVE_HALIDE
CV_Assert(!node.empty());
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
int outW, outH, outC, outN;
Halide::Var x("x"), y("y"), c("c"), n("n");
getCanonicalSize(outputs[0].size, &outW, &outH, &outC, &outN);
top.bound(x, 0, outW).bound(y, 0, outH)
.bound(c, 0, outC).bound(n, 0, outN);
Halide::Target target = Halide::get_host_target();
target.set_feature(Halide::Target::NoAsserts);
top.compile_jit(target);
#endif // HAVE_HALIDE
}
void forwardHalide(std::vector<Ptr<BackendWrapper> > &outputs,
const Ptr<BackendNode>& node)
{
#ifdef HAVE_HALIDE
CV_Assert(!node.empty());
Halide::Func& top = node.dynamicCast<HalideBackendNode>()->funcs.back();
auto outputBuffers = halideBuffers(outputs);
top.realize(Halide::Realization(outputBuffers));
#endif // HAVE_HALIDE
}
bool haveHalide()
{
#ifdef HAVE_HALIDE
return true;
#else
return false;
#endif // HAVE_HALIDE
}
} // namespace dnn
} // namespace cv
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
#ifndef __OPENCV_DNN_OP_HALIDE_HPP__
#define __OPENCV_DNN_OP_HALIDE_HPP__
#include "precomp.hpp"
#ifdef HAVE_HALIDE
#include <Halide.h>
#endif // HAVE_HALIDE
namespace cv
{
namespace dnn
{
#ifdef HAVE_HALIDE
// Returns four-dimensional buffer with float32 type that wrap cv::Mat data.
// No data copy here.
Halide::Buffer<float> wrapToHalideBuffer(const Mat& mat);
Halide::Buffer<float> wrapToHalideBuffer(const Mat& mat,
const std::vector<int>& shape);
// Extract batch size, number of channels, width and height from buffer.
void getCanonicalSize(const Halide::Buffer<>& buffer, int* width, int* height,
int* channels, int* batch);
// Cast pointer and create copy of Halide buffer. No data copy.
Halide::Buffer<> halideBuffer(const Ptr<BackendWrapper>& ptr);
std::vector<Halide::Buffer<> > halideBuffers(const std::vector<Ptr<BackendWrapper> >& ptrs);
class HalideBackendNode : public BackendNode
{
public:
HalideBackendNode(const Halide::Func& func);
HalideBackendNode(const std::vector<Halide::Func>& funcs);
// Initialize from the <base> node but replace last function to <top>.
// It's using in case of layers fusing when we want to keep functions of
// root layer but replace top by fused one (i.e. conv+padding to relu+padding).
HalideBackendNode(const Ptr<HalideBackendNode>& base, const Halide::Func& top);
std::vector<Halide::Func> funcs;
};
class HalideBackendWrapper : public BackendWrapper
{
public:
HalideBackendWrapper(int targetId, const cv::Mat& m);
HalideBackendWrapper(const Ptr<BackendWrapper>& base, const MatShape& shape);
Halide::Buffer<float> buffer;
};
#endif // HAVE_HALIDE
// Extract batch size, number of channels, width and height from MatSize.
void getCanonicalSize(const MatSize& size, int* width, int* height,
int* channels, int* batch);
void getCanonicalSize(const MatShape& shape, int* width, int* height,
int* channels, int* batch);
// Realize Halide pipeline into output blobs.
void forwardHalide(std::vector<Ptr<BackendWrapper> > &outputs,
const Ptr<BackendNode>& node);
// Compile Halide pipeline to specific target. Use outputs to set bounds of functions.
void compileHalide(std::vector<Mat> &outputs, Ptr<BackendNode>& node, int targetId);
bool haveHalide();
} // namespace dnn
} // namespace cv
#endif // __OPENCV_DNN_OP_HALIDE_HPP__
......@@ -574,7 +574,6 @@ void TFImporter::populateNet(Net dstNet)
{
CV_Assert(layer.input_size() == 2);
layerParams.set("axis", 0);
layerParams.set("bias_term", false);
layerParams.blobs.resize(1);
......@@ -622,7 +621,6 @@ void TFImporter::populateNet(Net dstNet)
}
else if (type == "Softmax")
{
layerParams.set("axis", -1);
int id = dstNet.addLayer(name, "Softmax", layerParams);
layer_id[name] = id;
......
This diff is collapsed.
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
namespace cvtest
{
#ifdef HAVE_HALIDE
using namespace cv;
using namespace dnn;
static void loadNet(const std::string& weights, const std::string& proto,
const std::string& framework, Net* net)
{
if (framework == "caffe")
{
*net = cv::dnn::readNetFromCaffe(proto, weights);
}
else if (framework == "torch")
{
*net = cv::dnn::readNetFromTorch(weights);
}
else if (framework == "tensorflow")
{
*net = cv::dnn::readNetFromTensorflow(weights);
}
else
CV_Error(Error::StsNotImplemented, "Unknown framework " + framework);
}
static void test(const std::string& weights, const std::string& proto,
const std::string& scheduler, int inWidth, int inHeight,
const std::string& outputLayer, const std::string& framework,
int targetId)
{
Mat input(inHeight, inWidth, CV_32FC3), outputDefault, outputHalide;
randu(input, 0.0f, 1.0f);
Net netDefault, netHalide;
loadNet(weights, proto, framework, &netDefault);
loadNet(weights, proto, framework, &netHalide);
netDefault.setBlob("", blobFromImage(input.clone(), 1.0f, false));
netDefault.forward(netDefault.getLayerId(outputLayer));
outputDefault = netDefault.getBlob(outputLayer).clone();
netHalide.setBlob("", blobFromImage(input.clone(), 1.0f, false));
netHalide.setPreferableBackend(DNN_BACKEND_HALIDE);
netHalide.compileHalide(scheduler);
netHalide.forward(netHalide.getLayerId(outputLayer));
outputHalide = netHalide.getBlob(outputLayer).clone();
normAssert(outputDefault, outputHalide);
// An extra test: change input.
input *= 0.1f;
netDefault.setBlob("", blobFromImage(input.clone(), 1.0, false));
netHalide.setBlob("", blobFromImage(input.clone(), 1.0, false));
normAssert(outputDefault, outputHalide);
// Swap backends.
netHalide.setPreferableBackend(DNN_BACKEND_DEFAULT);
netHalide.forward(netHalide.getLayerId(outputLayer));
netDefault.setPreferableBackend(DNN_BACKEND_HALIDE);
netDefault.compileHalide(scheduler);
netDefault.forward(netDefault.getLayerId(outputLayer));
outputDefault = netHalide.getBlob(outputLayer).clone();
outputHalide = netDefault.getBlob(outputLayer).clone();
normAssert(outputDefault, outputHalide);
}
TEST(Reproducibility_GoogLeNet_Halide, Accuracy)
{
test(findDataFile("dnn/bvlc_googlenet.caffemodel"),
findDataFile("dnn/bvlc_googlenet.prototxt"),
"", 227, 227, "prob", "caffe", DNN_TARGET_CPU);
};
TEST(Reproducibility_AlexNet_Halide, Accuracy)
{
test(getOpenCVExtraDir() + "/dnn/bvlc_alexnet.caffemodel",
getOpenCVExtraDir() + "/dnn/bvlc_alexnet.prototxt",
getOpenCVExtraDir() + "/dnn/halide_scheduler_alexnet.yml",
227, 227, "prob", "caffe", DNN_TARGET_CPU);
};
// TEST(Reproducibility_ResNet_50_Halide, Accuracy)
// {
// test(getOpenCVExtraDir() + "/dnn/ResNet-50-model.caffemodel",
// getOpenCVExtraDir() + "/dnn/ResNet-50-deploy.prototxt",
// getOpenCVExtraDir() + "/dnn/halide_scheduler_resnet_50.yml",
// 224, 224, "prob", "caffe", DNN_TARGET_CPU);
// };
// TEST(Reproducibility_SqueezeNet_v1_1_Halide, Accuracy)
// {
// test(getOpenCVExtraDir() + "/dnn/squeezenet_v1_1.caffemodel",
// getOpenCVExtraDir() + "/dnn/squeezenet_v1_1.prototxt",
// getOpenCVExtraDir() + "/dnn/halide_scheduler_squeezenet_v1_1.yml",
// 227, 227, "prob", "caffe", DNN_TARGET_CPU);
// };
TEST(Reproducibility_Inception_5h_Halide, Accuracy)
{
test(getOpenCVExtraDir() + "/dnn/tensorflow_inception_graph.pb", "",
getOpenCVExtraDir() + "/dnn/halide_scheduler_inception_5h.yml",
224, 224, "softmax2", "tensorflow", DNN_TARGET_CPU);
};
TEST(Reproducibility_ENet_Halide, Accuracy)
{
test(getOpenCVExtraDir() + "/dnn/Enet-model-best.net", "",
getOpenCVExtraDir() + "/dnn/halide_scheduler_enet.yml",
512, 512, "l367_Deconvolution", "torch", DNN_TARGET_CPU);
};
#endif // HAVE_HALIDE
} // namespace cvtest
# How to enable Halide backend for improve efficiency {#tutorial_dnn_halide}
## Introduction
This tutorial guidelines how to run your models in OpenCV deep learning module
using Halide language backend. Halide is an open-source project that let us
write image processing algorithms in well-readable format, schedule computations
according to specific device and evaluate it with a quite good efficiency.
An official website of the Halide project: http://halide-lang.org/.
## Efficiency comparison
Measured on Intel&reg; Core&trade; i7-6700K CPU @ 4.00GHz x 8.
Single image forward pass (in milliseconds):
| Architecture | MKL backend | Halide backend | Speed Up ratio |
|-----------------:|------------:|---------------:|---------------:|
| AlexNet | 16.55 | 22.38 | x0.73 |
| ResNet-50 | 63.69 | 73.91 | x0.86 |
| SqueezeNet v1.1 | 10.11 | 8.21 | x1.23 |
| Inception-5h | 35.38 | 37.06 | x0.95 |
| ENet @ 3x512x256 | 82.26 | 41.21 | x1.99 |
Scheduling directives might be found @ [opencv_extra/testdata/dnn](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn).
## Requirements
### LLVM compiler
@note LLVM compilation might take a long time.
- Download LLVM source code from http://releases.llvm.org/4.0.0/llvm-4.0.0.src.tar.xz.
Unpack it. Let **llvm_root** is a root directory of source code.
- Create directory **llvm_root**/tools/clang
- Download Clang with the same version as LLVM. In our case it will be from
http://releases.llvm.org/4.0.0/cfe-4.0.0.src.tar.xz. Unpack it into
**llvm_root**/tools/clang. Note that it should be a root for Clang source code.
- Build LLVM on Linux
@code
cd llvm_root
mkdir build && cd build
cmake -DLLVM_ENABLE_TERMINFO=OFF -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=Release ..
make -j4
@endcode
- Build LLVM on Windows (Developer Command Prompt)
@code
mkdir \\path-to-llvm-build\\ && cd \\path-to-llvm-build\\
cmake.exe -DLLVM_ENABLE_TERMINFO=OFF -DLLVM_TARGETS_TO_BUILD=X86 -DLLVM_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=\\path-to-llvm-install\\ -G "Visual Studio 14 Win64" \\path-to-llvm-src\\
MSBuild.exe /m:4 /t:Build /p:Configuration=Release .\\INSTALL.vcxproj
@endcode
@note `\\path-to-llvm-build\\` and `\\path-to-llvm-install\\` are different directories.
### Halide language.
- Download source code from GitHub repository, https://github.com/halide/Halide
or using git. The root directory will be a **halide_root**.
@code
git clone https://github.com/halide/Halide.git
@endcode
- Build Halide on Linux
@code
cd halide_root
mkdir build && cd build
cmake -DLLVM_DIR=llvm_root/build/lib/cmake/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_VERSION=40 -DWITH_TESTS=OFF -DWITH_APPS=OFF -DWITH_TUTORIALS=OFF ..
make -j4
@endcode
- Build Halide on Windows (Developer Command Prompt)
@code
cd halide_root
mkdir build && cd build
cmake.exe -DLLVM_DIR=\\path-to-llvm-install\\lib\\cmake\\llvm -DLLVM_VERSION=40 -DWITH_TESTS=OFF -DWITH_APPS=OFF -DWITH_TUTORIALS=OFF -DCMAKE_BUILD_TYPE=Release -G "Visual Studio 14 Win64" ..
MSBuild.exe /m:4 /t:Build /p:Configuration=Release .\\ALL_BUILD.vcxproj
@endcode
## Build OpenCV with Halide backend
When you build OpenCV add the following configuration flags:
- `WITH_HALIDE` - enable Halide linkage
- `HALIDE_ROOT_DIR` - path to Halide build directory
How to build OpenCV with DNN module you may find in @ref tutorial_dnn_build.
## Sample
@include dnn/samples/squeezenet_halide.cpp
## Explanation
Download Caffe model from SqueezeNet repository: [train_val.prototxt](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/train_val.prototxt) and [squeezenet_v1.1.caffemodel](https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/squeezenet_v1.1.caffemodel).
Also you need file with names of [ILSVRC2012](http://image-net.org/challenges/LSVRC/2012/browse-synsets) classes:
[synset_words.txt](https://raw.githubusercontent.com/ludv1x/opencv_contrib/master/modules/dnn/samples/synset_words.txt).
Put these files into working dir of this program example.
-# Read and initialize network using path to .prototxt and .caffemodel files
@snippet dnn/samples/squeezenet_halide.cpp Read and initialize network
-# Check that network was read successfully
@snippet dnn/samples/squeezenet_halide.cpp Check that network was read successfully
-# Read input image and convert to the 4-dimensional blob, acceptable by SqueezeNet v1.1
@snippet dnn/samples/squeezenet_halide.cpp Prepare blob
-# Pass the blob to the network
@snippet dnn/samples/squeezenet_halide.cpp Set input blob
-# Enable using Halide backend for layers where it is implemented
@snippet dnn/samples/squeezenet_halide.cpp Enable Halide backend
-# Compile Halide functions to execute on CPU
@snippet dnn/samples/squeezenet_halide.cpp Compile Halide pipeline
-# Make forward pass
@snippet dnn/samples/squeezenet_halide.cpp Make forward pass
Remember that the first forward pass after initialization require quite more
time that the next ones. It's because of runtime compilation of Halide pipelines
at the first invocation.
-# Determine the best class
@snippet dnn/samples/squeezenet_halide.cpp Gather output
-# Print results
@snippet dnn/samples/squeezenet_halide.cpp Print results
For our image we get:
> Best class: #812 'space shuttle'
>
> Probability: 97.9812%
# How to schedule your network for Halide backend {#tutorial_dnn_halide_scheduling}
## Introduction
Halide code is the same for every device we use. But for achieving the satisfied
efficiency we should schedule computations properly. In this tutorial we describe
the ways to schedule your networks using Halide backend in OpenCV deep learning module.
For better understanding of Halide scheduling you might want to read tutorials @ http://halide-lang.org/tutorials.
If it's your first meeting with Halide in OpenCV, we recommend to start from @ref tutorial_dnn_halide.
## Configuration files
When you call ```cv::dnn::Net::compileHalide```, you can pass a path to textual file
contains scheduling directives for specific device.
Scheduling configuration files represented as YAML files where each node is a
scheduled function or a scheduling directive.
@code
relu1:
reorder: [x, c, y]
split: { y: 2, c: 8 }
parallel: [yo, co]
unroll: yi
vectorize: { x: 4 }
conv1_constant_exterior:
compute_at: { relu1: yi }
@endcode
Considered use variables `n` for batch dimension, `c` for channels,
`y` for rows and `x` for columns. For variables after split are used names
with the same prefix but `o` and `i` suffixes for outer and inner variables
correspondingly. In example, for variable `x` in range `[0, 10)` directive
`split: { x: 2 }` gives new ones `xo` in range `[0, 5)` and `xi` in range `[0, 2)`.
Variable name `x` is no longer available in the same scheduling node.
You can find scheduling examples at [opencv_extra/testdata/dnn](https://github.com/opencv/opencv_extra/tree/master/testdata/dnn)
and use it for schedule your networks.
## Layers fusing
Thanks to layers fusing we can schedule only the top layers of fused sets.
Because for every output value we use the fused formula.
In example, if you have three layers Convolution + Scale + ReLU one by one,
@code
conv(x, y, c, n) = sum(...) + bias(c);
scale(x, y, c, n) = conv(x, y, c, n) * weights(c);
relu(x, y, c, n) = max(scale(x, y, c, n), 0);
@endcode
fused function is something like
@code
relu(x, y, c, n) = max((sum(...) + bias(c)) * weights(c), 0);
@endcode
So only function called `relu` require scheduling.
## Scheduling patterns
Sometimes networks built using blocked structure that means some layer are
identical or quite similar. If you want to apply the same scheduling for
different layers accurate to tiling or vectorization factors, define scheduling
patterns in section `patterns` at the beginning of scheduling file.
Also, your patters may use some parametric variables.
@code
# At the beginning of the file
patterns:
fully_connected:
split: { c: c_split }
fuse: { src: [x, y, co], dst: block }
parallel: block
vectorize: { ci: c_split }
# Somewhere below
fc8:
pattern: fully_connected
params: { c_split: 8 }
@endcode
## Automatic scheduling
Based on manual scheduling experience, proposed way to schedule layers
automatically. Just skip scheduling file path argument at ```cv::dnn::Net::compileHalide```
for let DNN schedule your network. Sometimes it might be even better
than manual scheduling.
You can mix both manual and automatic scheduling ways. Write scheduling file
and skip layers that you want to be scheduled automatically.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment