Commit b964d3a1 authored by arrybn's avatar arrybn

Added few tests for torch

parent 5d9808b0
...@@ -23,6 +23,7 @@ const String keys = ...@@ -23,6 +23,7 @@ const String keys =
"{c_names c || path to file with classnames for channels (optional, categories.txt) }" "{c_names c || path to file with classnames for channels (optional, categories.txt) }"
"{result r || path to save output blob (optional, binary format, NCHW order) }" "{result r || path to save output blob (optional, binary format, NCHW order) }"
"{show s || whether to show all output channels or not}" "{show s || whether to show all output channels or not}"
"{o_blob || output blob's name. If empty, last blob's name in net is used}"
; ;
std::vector<String> readClassNames(const char *filename); std::vector<String> readClassNames(const char *filename);
...@@ -112,7 +113,13 @@ int main(int argc, char **argv) ...@@ -112,7 +113,13 @@ int main(int argc, char **argv)
//! [Gather output] //! [Gather output]
dnn::Blob prob = net.getBlob(net.getLayerNames().back()); //gather output of "prob" layer String oBlob = net.getLayerNames().back();
if (!parser.get<String>("o_blob").empty())
{
oBlob = parser.get<String>("o_blob");
}
dnn::Blob prob = net.getBlob(oBlob); //gather output of "prob" layer
Mat& result = prob.matRef(); Mat& result = prob.matRef();
......
...@@ -28,6 +28,8 @@ void MaxUnpoolLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector< ...@@ -28,6 +28,8 @@ void MaxUnpoolLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<
outShape[2] = outSize.height; outShape[2] = outSize.height;
outShape[3] = outSize.width; outShape[3] = outSize.width;
CV_Assert(inputs[0]->total() == inputs[1]->total());
outputs.resize(1); outputs.resize(1);
outputs[0].create(outShape); outputs[0].create(outShape);
} }
......
...@@ -72,7 +72,8 @@ TEST(Torch_Importer, simple_read) ...@@ -72,7 +72,8 @@ TEST(Torch_Importer, simple_read)
importer->populateNet(net); importer->populateNet(net);
} }
static void runTorchNet(String prefix, String outLayerName, bool isBinary) static void runTorchNet(String prefix, String outLayerName = "",
bool check2ndBlob = false, bool isBinary = false)
{ {
String suffix = (isBinary) ? ".dat" : ".txt"; String suffix = (isBinary) ? ".dat" : ".txt";
...@@ -92,52 +93,69 @@ static void runTorchNet(String prefix, String outLayerName, bool isBinary) ...@@ -92,52 +93,69 @@ static void runTorchNet(String prefix, String outLayerName, bool isBinary)
Blob out = net.getBlob(outLayerName); Blob out = net.getBlob(outLayerName);
normAssert(outRef, out); normAssert(outRef, out);
if (check2ndBlob)
{
Blob out2 = net.getBlob(outLayerName + ".1");
Blob ref2 = readTorchBlob(_tf(prefix + "_output_2" + suffix), isBinary);
normAssert(out2, ref2);
}
} }
TEST(Torch_Importer, run_convolution) TEST(Torch_Importer, run_convolution)
{ {
runTorchNet("net_conv", "l1_Convolution", false); runTorchNet("net_conv");
} }
TEST(Torch_Importer, run_pool_max) TEST(Torch_Importer, run_pool_max)
{ {
runTorchNet("net_pool_max", "l1_Pooling", false); runTorchNet("net_pool_max", "", true);
} }
TEST(Torch_Importer, run_pool_ave) TEST(Torch_Importer, run_pool_ave)
{ {
runTorchNet("net_pool_ave", "l1_Pooling", false); runTorchNet("net_pool_ave");
} }
TEST(Torch_Importer, run_reshape) TEST(Torch_Importer, run_reshape)
{ {
runTorchNet("net_reshape", "l1_Reshape", false); runTorchNet("net_reshape");
runTorchNet("net_reshape_batch", "l1_Reshape", false); runTorchNet("net_reshape_batch");
} }
TEST(Torch_Importer, run_linear) TEST(Torch_Importer, run_linear)
{ {
runTorchNet("net_linear_2d", "l1_InnerProduct", false); runTorchNet("net_linear_2d");
} }
TEST(Torch_Importer, run_paralel) TEST(Torch_Importer, run_paralel)
{ {
runTorchNet("net_parallel", "l2_torchMerge", false); runTorchNet("net_parallel", "l2_torchMerge");
} }
TEST(Torch_Importer, run_concat) TEST(Torch_Importer, run_concat)
{ {
runTorchNet("net_concat", "l2_torchMerge", false); runTorchNet("net_concat", "l2_torchMerge");
} }
TEST(Torch_Importer, run_deconv) TEST(Torch_Importer, run_deconv)
{ {
runTorchNet("net_deconv", "", false); runTorchNet("net_deconv");
} }
TEST(Torch_Importer, run_batch_norm) TEST(Torch_Importer, run_batch_norm)
{ {
runTorchNet("net_batch_norm", "", false); runTorchNet("net_batch_norm");
}
TEST(Torch_Importer, net_prelu)
{
runTorchNet("net_prelu");
}
TEST(Torch_Importer, net_cadd_table)
{
runTorchNet("net_cadd_table");
} }
#if defined(ENABLE_TORCH_ENET_TESTS) #if defined(ENABLE_TORCH_ENET_TESTS)
......
...@@ -27,6 +27,8 @@ function save(net, input, label) ...@@ -27,6 +27,8 @@ function save(net, input, label)
torch.save(label .. '_input.txt', input, 'ascii') torch.save(label .. '_input.txt', input, 'ascii')
--torch.save(label .. '_output.dat', output) --torch.save(label .. '_output.dat', output)
torch.save(label .. '_output.txt', output, 'ascii') torch.save(label .. '_output.txt', output, 'ascii')
return net
end end
local net_simple = nn.Sequential() local net_simple = nn.Sequential()
...@@ -38,7 +40,8 @@ save(net_simple, torch.Tensor(2, 3, 25, 35), 'net_simple') ...@@ -38,7 +40,8 @@ save(net_simple, torch.Tensor(2, 3, 25, 35), 'net_simple')
local net_pool_max = nn.Sequential() local net_pool_max = nn.Sequential()
net_pool_max:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2):ceil()) --TODO: add ceil and floor modes net_pool_max:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2):ceil()) --TODO: add ceil and floor modes
save(net_pool_max, torch.rand(2, 3, 50, 30), 'net_pool_max') local net = save(net_pool_max, torch.rand(2, 3, 50, 30), 'net_pool_max')
torch.save('net_pool_max_output_2.txt', net.modules[1].indices - 1, 'ascii')
local net_pool_ave = nn.Sequential() local net_pool_ave = nn.Sequential()
net_pool_ave:add(nn.SpatialAveragePooling(4,5, 2,1, 1,2)) net_pool_ave:add(nn.SpatialAveragePooling(4,5, 2,1, 1,2))
...@@ -74,5 +77,15 @@ net_deconv:add(nn.SpatialFullConvolution(3, 9, 4, 5, 1, 2, 0, 1, 0, 1)) ...@@ -74,5 +77,15 @@ net_deconv:add(nn.SpatialFullConvolution(3, 9, 4, 5, 1, 2, 0, 1, 0, 1))
save(net_deconv, torch.rand(2, 3, 4, 3) - 0.5, 'net_deconv') save(net_deconv, torch.rand(2, 3, 4, 3) - 0.5, 'net_deconv')
local net_batch_norm = nn.Sequential() local net_batch_norm = nn.Sequential()
net_batch_norm:add(nn.SpatialBatchNormalization(3)) net_batch_norm:add(nn.SpatialBatchNormalization(4, 1e-3))
save(net_batch_norm, torch.rand(1, 3, 4, 3) - 0.5, 'net_batch_norm') save(net_batch_norm, torch.rand(1, 4, 5, 6) - 0.5, 'net_batch_norm')
\ No newline at end of file
local net_prelu = nn.Sequential()
net_prelu:add(nn.PReLU(5))
save(net_prelu, torch.rand(1, 5, 40, 50) - 0.5, 'net_prelu')
local net_cadd_table = nn.Sequential()
local sum = nn.ConcatTable()
sum:add(nn.Identity()):add(nn.Identity())
net_cadd_table:add(sum):add(nn.CAddTable())
save(net_cadd_table, torch.rand(1, 5, 40, 50) - 0.5, 'net_cadd_table')
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment