// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ie_common.h>
#include <mkldnn_node.h>
#include <memory>
#include <string>
#include <vector>

namespace MKLDNNPlugin {

class MKLDNNFullyConnectedNode : public MKLDNNNode {
public:
17
    MKLDNNFullyConnectedNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, int socket);
openvino-pushbot's avatar
openvino-pushbot committed
18 19 20 21 22 23 24 25 26 27 28 29 30
    ~MKLDNNFullyConnectedNode() override = default;

    void getSupportedDescriptors() override;
    void createPrimitive() override;
    bool created() const override;
    bool canBeInPlace() const override {
        return false;
    }

    const std::vector<impl_desc_type>& getPrimitivesPriority() override;
    void createDescriptor(const std::vector<InferenceEngine::TensorDesc>& inputDesc,
                          const std::vector<InferenceEngine::TensorDesc>& outputDesc) override;

31 32 33 34 35 36 37 38 39
    size_t descInputNumbers(MKLDNNDescriptor desc) override {
        return static_cast<size_t>(baseInputsNumber);
    }

    MKLDNNMemoryDesc getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) override;

    const mkldnn::memory& getWeights() const;
    const mkldnn::memory& getBias() const;

40 41 42
protected:
    std::shared_ptr<mkldnn::primitive_attr> initPrimitiveAttr() const override;

openvino-pushbot's avatar
openvino-pushbot committed
43 44 45 46
private:
    InferenceEngine::SizeVector weightsDims;
    InferenceEngine::SizeVector biasesDims;
    mkldnn::memory::format weightsFormatForSrcFormat(mkldnn::memory::format sourceFormat);
47 48

    InferenceEngine::Blob::Ptr wScale, oScale;
49 50 51

    bool withBiases;
    int baseInputsNumber;
openvino-pushbot's avatar
openvino-pushbot committed
52 53 54 55
};

}  // namespace MKLDNNPlugin