// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <string.h>
#include <ie_builders.hpp>
#include <builders/ie_detection_output_layer.hpp>

#include "builder_test.hpp"

using namespace testing;
using namespace InferenceEngine;

// Fixture grouping the Builder::DetectionOutputLayer tests. It declares no
// members of its own; everything it provides is inherited from BuilderTestCommon.
struct DetectionOutputLayerBuilderTest : BuilderTestCommon {};

// Round-trip check: a fully valid DetectionOutputLayer is added to a
// Builder::Network, fetched back by its index, and every parameter read from
// the fetched copy must match the value set on the original builder.
TEST_F(DetectionOutputLayerBuilderTest, getExistsLayerFromNetworkBuilder) {
    Builder::Network network("network");
    Builder::DetectionOutputLayer layer("detection output layer");
    layer.setNumClasses(2);
    layer.setShareLocation(true);
    layer.setBackgroudLabelId(-1);
    layer.setNMSThreshold(0.45f);
    layer.setTopK(400);
    layer.setCodeType("caffe.PriorBoxParameter.CENTER_SIZE");
    layer.setVariantEncodedInTarget(false);
    layer.setKeepTopK(200);
    layer.setConfidenceThreshold(0.01f);

    size_t ind = 0;
    ASSERT_NO_THROW(ind = network.addLayer(layer));
    Builder::DetectionOutputLayer layerFromNet(network.getLayer(ind));

    ASSERT_EQ(layerFromNet.getName(), layer.getName());
    ASSERT_EQ(layerFromNet.getNumClasses(), layer.getNumClasses());
    ASSERT_EQ(layerFromNet.getShareLocation(), layer.getShareLocation());
    ASSERT_EQ(layerFromNet.getBackgroudLabelId(), layer.getBackgroudLabelId());
    // Thresholds are floating point: compare within gtest's ULP tolerance
    // instead of exact equality, so storage/conversion rounding in the
    // network's parameter map cannot cause spurious failures.
    ASSERT_FLOAT_EQ(layerFromNet.getNMSThreshold(), layer.getNMSThreshold());
    ASSERT_EQ(layerFromNet.getTopK(), layer.getTopK());
    ASSERT_EQ(layerFromNet.getCodeType(), layer.getCodeType());
    ASSERT_EQ(layerFromNet.getVariantEncodedInTarget(), layer.getVariantEncodedInTarget());
    ASSERT_EQ(layerFromNet.getKeepTopK(), layer.getKeepTopK());
    ASSERT_FLOAT_EQ(layerFromNet.getConfidenceThreshold(), layer.getConfidenceThreshold());
}

// Negative case: the layer uses the same configuration that the round-trip test
// accepts, except that the class count is zero. Adding it to the network must
// throw InferenceEngineException.
TEST_F(DetectionOutputLayerBuilderTest, cannotCreateLayerWithWrongNumClasses) {
    Builder::Network net("network");
    Builder::DetectionOutputLayer doLayer("detection output layer");
    doLayer.setNumClasses(0)  // invalid: zero classes
        .setShareLocation(true)
        .setBackgroudLabelId(-1)
        .setNMSThreshold(0.45)
        .setTopK(400)
        .setCodeType("caffe.PriorBoxParameter.CENTER_SIZE")
        .setVariantEncodedInTarget(false)
        .setKeepTopK(200)
        .setConfidenceThreshold(0.01);
    ASSERT_THROW(net.addLayer(doLayer), InferenceEngine::details::InferenceEngineException);
}

// Negative case: every parameter matches the accepted round-trip configuration,
// except that the code type is an unrecognised string. Adding the layer to the
// network must throw InferenceEngineException.
TEST_F(DetectionOutputLayerBuilderTest, cannotCreateLayerWithWrongCodeType) {
    Builder::Network net("network");
    Builder::DetectionOutputLayer doLayer("detection output layer");
    doLayer.setNumClasses(2)
        .setShareLocation(true)
        .setBackgroudLabelId(-1)
        .setNMSThreshold(0.45)
        .setTopK(400)
        .setCodeType("trololo")  // invalid: not a known code type
        .setVariantEncodedInTarget(false)
        .setKeepTopK(200)
        .setConfidenceThreshold(0.01);
    ASSERT_THROW(net.addLayer(doLayer), InferenceEngine::details::InferenceEngineException);
}

// Negative case: every parameter matches the accepted round-trip configuration,
// except that the background label id is -100 (the accepted test uses -1).
// Adding the layer to the network must throw InferenceEngineException.
TEST_F(DetectionOutputLayerBuilderTest, cannotCreateLayerWithWrongBackLabelId) {
    Builder::Network net("network");
    Builder::DetectionOutputLayer doLayer("detection output layer");
    doLayer.setNumClasses(2)
        .setShareLocation(true)
        .setBackgroudLabelId(-100)  // invalid background label id
        .setNMSThreshold(0.45)
        .setTopK(400)
        .setCodeType("caffe.PriorBoxParameter.CENTER_SIZE")
        .setVariantEncodedInTarget(false)
        .setKeepTopK(200)
        .setConfidenceThreshold(0.01);
    ASSERT_THROW(net.addLayer(doLayer), InferenceEngine::details::InferenceEngineException);
}

// Negative case: every parameter matches the accepted round-trip configuration,
// except that the NMS threshold is negative. Adding the layer to the network
// must throw InferenceEngineException.
TEST_F(DetectionOutputLayerBuilderTest, cannotCreateLayerWithWrongNMSThreshold) {
    Builder::Network network("network");
    Builder::DetectionOutputLayer layer("detection output layer");
    layer.setNumClasses(2);
    layer.setShareLocation(true);
    layer.setBackgroudLabelId(-1);
    layer.setNMSThreshold(-0.02f);  // invalid: negative NMS threshold
    layer.setTopK(400);
    layer.setCodeType("caffe.PriorBoxParameter.CENTER_SIZE");
    layer.setVariantEncodedInTarget(false);
    layer.setKeepTopK(200);
    layer.setConfidenceThreshold(0.01f);
    ASSERT_THROW(network.addLayer(layer), InferenceEngine::details::InferenceEngineException);
}

// Negative case: every parameter matches the accepted round-trip configuration,
// except that the confidence threshold is negative. Adding the layer to the
// network must throw InferenceEngineException.
TEST_F(DetectionOutputLayerBuilderTest, cannotCreateLayerWithWrongConfidenceThreshold) {
    Builder::Network network("network");
    Builder::DetectionOutputLayer layer("detection output layer");
    layer.setNumClasses(2);
    layer.setShareLocation(true);
    layer.setBackgroudLabelId(-1);
    layer.setNMSThreshold(0.45f);
    layer.setTopK(400);
    layer.setCodeType("caffe.PriorBoxParameter.CENTER_SIZE");
    layer.setVariantEncodedInTarget(false);
    layer.setKeepTopK(200);
    layer.setConfidenceThreshold(-0.1f);  // invalid: negative confidence threshold
    ASSERT_THROW(network.addLayer(layer), InferenceEngine::details::InferenceEngineException);
}