Commit b964d3a1 authored by arrybn

Added a few tests for torch

parent 5d9808b0
......@@ -23,6 +23,7 @@ const String keys =
"{c_names c || path to file with classnames for channels (optional, categories.txt) }"
"{result r || path to save output blob (optional, binary format, NCHW order) }"
"{show s || whether to show all output channels or not}"
"{o_blob || output blob's name. If empty, last blob's name in net is used}"
;
std::vector<String> readClassNames(const char *filename);
......@@ -112,7 +113,13 @@ int main(int argc, char **argv)
//! [Gather output]
dnn::Blob prob = net.getBlob(net.getLayerNames().back()); //gather output of "prob" layer
String oBlob = net.getLayerNames().back();
if (!parser.get<String>("o_blob").empty())
{
oBlob = parser.get<String>("o_blob");
}
dnn::Blob prob = net.getBlob(oBlob); //gather output of the selected blob (last layer by default)
Mat& result = prob.matRef();
......
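The sample change above makes the output blob selectable from the command line: if `o_blob` is given it is used directly, otherwise the name of the last layer is taken, as before. A hypothetical helper (chooseOutputBlob is not part of the commit; it assumes the sample's `using namespace cv;` and the same dnn API used above) sketching that fallback pattern:

// Hypothetical helper, not in the commit: use the user-supplied blob name
// when present, otherwise fall back to the net's last layer name.
static String chooseOutputBlob(const CommandLineParser &parser, dnn::Net &net)
{
    String name = parser.get<String>("o_blob");   // empty when the option is not set
    if (name.empty())
        name = net.getLayerNames().back();        // previous behaviour: last blob in the net
    return name;
}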
......@@ -28,6 +28,8 @@ void MaxUnpoolLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<
outShape[2] = outSize.height;
outShape[3] = outSize.width;
CV_Assert(inputs[0]->total() == inputs[1]->total());
outputs.resize(1);
outputs[0].create(outShape);
}
......
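The new assert in MaxUnpoolLayerImpl::allocate requires the pooled data blob and the indices blob to hold the same number of elements, because unpooling scatters every pooled value to the flat position recorded by its index. A minimal sketch of that scatter follows; it is illustrative only, not the actual layer implementation, and the names are invented:

// Minimal max-unpooling sketch: dst is pre-sized to the unpooled shape;
// each src value lands at the flat offset stored in idx, the rest stays zero.
#include <algorithm>
#include <cstddef>
#include <vector>

static void maxUnpool(const std::vector<float> &src,
                      const std::vector<int>   &idx,  // same total as src
                      std::vector<float>       &dst)
{
    std::fill(dst.begin(), dst.end(), 0.f);
    for (std::size_t i = 0; i < src.size(); ++i)
        dst[static_cast<std::size_t>(idx[i])] = src[i];  // idx values are 0-based
}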
......@@ -72,7 +72,8 @@ TEST(Torch_Importer, simple_read)
importer->populateNet(net);
}
static void runTorchNet(String prefix, String outLayerName, bool isBinary)
static void runTorchNet(String prefix, String outLayerName = "",
bool check2ndBlob = false, bool isBinary = false)
{
String suffix = (isBinary) ? ".dat" : ".txt";
......@@ -92,52 +93,69 @@ static void runTorchNet(String prefix, String outLayerName, bool isBinary)
Blob out = net.getBlob(outLayerName);
normAssert(outRef, out);
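// Optionally check a second output blob, named "<outLayerName>.1", against the
// "<prefix>_output_2" reference file (used for the max-pooling indices below).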
if (check2ndBlob)
{
Blob out2 = net.getBlob(outLayerName + ".1");
Blob ref2 = readTorchBlob(_tf(prefix + "_output_2" + suffix), isBinary);
normAssert(out2, ref2);
}
}
TEST(Torch_Importer, run_convolution)
{
runTorchNet("net_conv", "l1_Convolution", false);
runTorchNet("net_conv");
}
TEST(Torch_Importer, run_pool_max)
{
runTorchNet("net_pool_max", "l1_Pooling", false);
runTorchNet("net_pool_max", "", true);
}
TEST(Torch_Importer, run_pool_ave)
{
runTorchNet("net_pool_ave", "l1_Pooling", false);
runTorchNet("net_pool_ave");
}
TEST(Torch_Importer, run_reshape)
{
runTorchNet("net_reshape", "l1_Reshape", false);
runTorchNet("net_reshape_batch", "l1_Reshape", false);
runTorchNet("net_reshape");
runTorchNet("net_reshape_batch");
}
TEST(Torch_Importer, run_linear)
{
runTorchNet("net_linear_2d", "l1_InnerProduct", false);
runTorchNet("net_linear_2d");
}
TEST(Torch_Importer, run_paralel)
{
runTorchNet("net_parallel", "l2_torchMerge", false);
runTorchNet("net_parallel", "l2_torchMerge");
}
TEST(Torch_Importer, run_concat)
{
runTorchNet("net_concat", "l2_torchMerge", false);
runTorchNet("net_concat", "l2_torchMerge");
}
TEST(Torch_Importer, run_deconv)
{
runTorchNet("net_deconv", "", false);
runTorchNet("net_deconv");
}
TEST(Torch_Importer, run_batch_norm)
{
runTorchNet("net_batch_norm", "", false);
runTorchNet("net_batch_norm");
}
TEST(Torch_Importer, net_prelu)
{
runTorchNet("net_prelu");
}
TEST(Torch_Importer, net_cadd_table)
{
runTorchNet("net_cadd_table");
}
#if defined(ENABLE_TORCH_ENET_TESTS)
......
......@@ -27,6 +27,8 @@ function save(net, input, label)
torch.save(label .. '_input.txt', input, 'ascii')
--torch.save(label .. '_output.dat', output)
torch.save(label .. '_output.txt', output, 'ascii')
return net
end
local net_simple = nn.Sequential()
......@@ -38,7 +40,8 @@ save(net_simple, torch.Tensor(2, 3, 25, 35), 'net_simple')
local net_pool_max = nn.Sequential()
net_pool_max:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2):ceil()) --TODO: add ceil and floor modes
save(net_pool_max, torch.rand(2, 3, 50, 30), 'net_pool_max')
local net = save(net_pool_max, torch.rand(2, 3, 50, 30), 'net_pool_max')
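-- shift the pooling indices by -1 so the saved reference blob uses 0-based positions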
torch.save('net_pool_max_output_2.txt', net.modules[1].indices - 1, 'ascii')
local net_pool_ave = nn.Sequential()
net_pool_ave:add(nn.SpatialAveragePooling(4,5, 2,1, 1,2))
......@@ -74,5 +77,15 @@ net_deconv:add(nn.SpatialFullConvolution(3, 9, 4, 5, 1, 2, 0, 1, 0, 1))
save(net_deconv, torch.rand(2, 3, 4, 3) - 0.5, 'net_deconv')
local net_batch_norm = nn.Sequential()
net_batch_norm:add(nn.SpatialBatchNormalization(3))
save(net_batch_norm, torch.rand(1, 3, 4, 3) - 0.5, 'net_batch_norm')
\ No newline at end of file
net_batch_norm:add(nn.SpatialBatchNormalization(4, 1e-3))
save(net_batch_norm, torch.rand(1, 4, 5, 6) - 0.5, 'net_batch_norm')
local net_prelu = nn.Sequential()
net_prelu:add(nn.PReLU(5))
save(net_prelu, torch.rand(1, 5, 40, 50) - 0.5, 'net_prelu')
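-- CAddTable test: ConcatTable duplicates the input through two Identity branches,
-- then CAddTable sums them, so the expected output is simply 2 * input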
local net_cadd_table = nn.Sequential()
local sum = nn.ConcatTable()
sum:add(nn.Identity()):add(nn.Identity())
net_cadd_table:add(sum):add(nn.CAddTable())
save(net_cadd_table, torch.rand(1, 5, 40, 50) - 0.5, 'net_cadd_table')
\ No newline at end of file