// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <builders/ie_simpler_nms_layer.hpp>
#include <ie_cnn_layer_builder.h>
#include <vector>
#include <string>

using namespace InferenceEngine;

/**
 * @brief Creates a builder for a new layer of type "SimplerNMS".
 * @param name Name given to the new layer
 *
 * The layer's output-port list is resized to exactly one entry so that
 * getOutputPort()/setOutputPort() have a slot to work with. Input ports are
 * not touched here.
 */
Builder::SimplerNMSLayer::SimplerNMSLayer(const std::string& name): LayerDecorator("SimplerNMS", name) {
    std::vector<Port>& outputs = getLayer()->getOutputPorts();
    outputs.resize(1);
}

/**
 * @brief Wraps an existing (mutable) layer as a SimplerNMS builder.
 * @param layer Layer to decorate; its type is validated with checkType("SimplerNMS")
 *
 * NOTE(review): unlike the name-based constructor, this one does not resize the
 * output-port list, so the wrapped layer may have no output port.
 */
Builder::SimplerNMSLayer::SimplerNMSLayer(const Layer::Ptr& layer): LayerDecorator(layer) {
    this->checkType("SimplerNMS");
}

/**
 * @brief Wraps an existing read-only layer as a SimplerNMS builder.
 * @param layer Layer to decorate; its type is validated with checkType("SimplerNMS")
 */
Builder::SimplerNMSLayer::SimplerNMSLayer(const Layer::CPtr& layer): LayerDecorator(layer) {
    this->checkType("SimplerNMS");
}

/**
 * @brief Renames the wrapped layer.
 * @param name New layer name
 * @return Reference to this builder, for call chaining
 */
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setName(const std::string& name) {
    const auto& layer = getLayer();
    layer->setName(name);
    return *this;
}
const std::vector<Port>& Builder::SimplerNMSLayer::getInputPorts() const {
29
    return getLayer()->getInputPorts();
30 31
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setInputPorts(const std::vector<Port>& ports) {
32
    getLayer()->getInputPorts() = ports;
33 34 35
    return *this;
}
/**
 * @brief Returns the single output port of the wrapped layer.
 * @return Reference to the first output port
 * @throws std::out_of_range if the wrapped layer has no output ports
 *
 * Uses at() rather than operator[]: only the name-based constructor resizes the
 * output ports to one entry, so a layer wrapped through the Layer::Ptr /
 * Layer::CPtr constructors may have none, and operator[] on an empty vector is
 * undefined behavior.
 */
const Port& Builder::SimplerNMSLayer::getOutputPort() const {
    return getLayer()->getOutputPorts().at(0);
}
/**
 * @brief Sets the single output port of the wrapped layer.
 * @param port Port description to store
 * @return Reference to this builder, for call chaining
 *
 * If the wrapped layer has no output port yet (possible when it was wrapped via
 * the Layer::Ptr constructor), the port is appended instead of writing past the
 * end of an empty vector.
 */
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setOutputPort(const Port& port) {
    std::vector<Port>& outputs = getLayer()->getOutputPorts();
    if (outputs.empty()) {
        outputs.push_back(port);
    } else {
        outputs[0] = port;
    }
    return *this;
}

size_t Builder::SimplerNMSLayer::getPreNMSTopN() const {
44
    return getLayer()->getParameters().at("pre_nms_topn");
45 46
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setPreNMSTopN(size_t topN) {
47
    getLayer()->getParameters()["pre_nms_topn"] = topN;
48 49 50
    return *this;
}
size_t Builder::SimplerNMSLayer::getPostNMSTopN() const {
51
    return getLayer()->getParameters().at("post_nms_topn");
52 53
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setPostNMSTopN(size_t topN) {
54
    getLayer()->getParameters()["post_nms_topn"] = topN;
55 56 57
    return *this;
}
size_t Builder::SimplerNMSLayer::getFeatStride() const {
58
    return getLayer()->getParameters().at("feat_stride");
59 60
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setFeatStride(size_t featStride) {
61
    getLayer()->getParameters()["feat_stride"] = featStride;
62 63 64
    return *this;
}
size_t Builder::SimplerNMSLayer::getMinBoxSize() const {
65
    return getLayer()->getParameters().at("min_bbox_size");
66 67
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setMinBoxSize(size_t minSize) {
68
    getLayer()->getParameters()["min_bbox_size"] = minSize;
69 70 71
    return *this;
}
size_t Builder::SimplerNMSLayer::getScale() const {
72
    return getLayer()->getParameters().at("scale");
73 74
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setScale(size_t scale) {
75
    getLayer()->getParameters()["scale"] = scale;
76 77 78 79
    return *this;
}

float Builder::SimplerNMSLayer::getCLSThreshold() const {
80
    return getLayer()->getParameters().at("cls_threshold");
81 82
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setCLSThreshold(float threshold) {
83
    getLayer()->getParameters()["cls_threshold"] = threshold;
84 85 86
    return *this;
}
float Builder::SimplerNMSLayer::getIOUThreshold() const {
87
    return getLayer()->getParameters().at("iou_threshold");
88 89
}
Builder::SimplerNMSLayer& Builder::SimplerNMSLayer::setIOUThreshold(float threshold) {
90
    getLayer()->getParameters()["iou_threshold"] = threshold;
91 92
    return *this;
}

REG_CONVERTER_FOR(SimplerNMS, [](const CNNLayerPtr& cnnLayer, Builder::Layer& layer) {
    // Copies every SimplerNMS attribute from the legacy CNNLayer into the builder
    // layer's parameter map (presumably registered as the converter for the
    // "SimplerNMS" type by the macro). Thresholds are read as floats. The integer
    // attributes are read with GetParamAsUInt and widened to size_t, matching the
    // size_t accessors of Builder::SimplerNMSLayer. The assignment order is kept
    // as-is, so a failure on one attribute leaves the same partial state as before.
    auto& params = layer.getParameters();
    params["iou_threshold"] = cnnLayer->GetParamAsFloat("iou_threshold");
    params["cls_threshold"] = cnnLayer->GetParamAsFloat("cls_threshold");
    params["scale"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("scale"));
    params["min_bbox_size"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("min_bbox_size"));
    params["feat_stride"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("feat_stride"));
    params["pre_nms_topn"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("pre_nms_topn"));
    params["post_nms_topn"] = static_cast<size_t>(cnnLayer->GetParamAsUInt("post_nms_topn"));
});