//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

20
#include "ngraph/op/batch_norm.hpp"
21 22 23 24
#include "pyngraph/ops/batch_norm.hpp"

namespace py = pybind11;

25
void regclass_pyngraph_op_BatchNormTraining(py::module m)
26
{
27 28 29 30 31 32
    py::class_<ngraph::op::BatchNormTraining,
               std::shared_ptr<ngraph::op::BatchNormTraining>,
               ngraph::op::Op>
        batch_norm_training(m, "BatchNormTraining");
    batch_norm_training.doc() =
        "ngraph.impl.op.BatchNormTraining wraps ngraph::op::BatchNormTraining";
33
    batch_norm_training.def(py::init<const std::shared_ptr<ngraph::Node>&,
34 35
                                     const std::shared_ptr<ngraph::Node>&,
                                     const std::shared_ptr<ngraph::Node>&,
36
                                     double>());
37 38 39 40 41 42 43 44 45 46
}

// Registers ngraph::op::BatchNormInference with module `m` as the Python class
// "BatchNormInference".
//
// The binding uses std::shared_ptr as its holder type and declares
// ngraph::op::Op as the base class. pybind11 therefore requires Op to be
// registered before this function is called.
//
// The only exposed constructor takes five Node handles followed by a double.
// No py::arg names are attached, so Python callers pass arguments positionally.
// NOTE(review): the order is presumably
// (input, gamma, beta, mean, variance, epsilon); confirm against
// ngraph/op/batch_norm.hpp.
void regclass_pyngraph_op_BatchNormInference(py::module m)
{
    py::class_<ngraph::op::BatchNormInference,
               std::shared_ptr<ngraph::op::BatchNormInference>,
               ngraph::op::Op>
        batch_norm_inference(m, "BatchNormInference");
    batch_norm_inference.doc() =
        "ngraph.impl.op.BatchNormInference wraps ngraph::op::BatchNormInference";

    batch_norm_inference.def(py::init<const std::shared_ptr<ngraph::Node>&,
                                      const std::shared_ptr<ngraph::Node>&,
                                      const std::shared_ptr<ngraph::Node>&,
                                      const std::shared_ptr<ngraph::Node>&,
                                      const std::shared_ptr<ngraph::Node>&,
                                      double>());
}

56
void regclass_pyngraph_op_BatchNormTrainingBackprop(py::module m)
57
{
58 59
    py::class_<ngraph::op::BatchNormTrainingBackprop,
               std::shared_ptr<ngraph::op::BatchNormTrainingBackprop>,
60
               ngraph::op::Op>
61 62 63
        batch_norm_training_backprop(m, "BatchNormTrainingBackprop");
    batch_norm_training_backprop.doc() =
        "ngraph.impl.op.BatchNormTrainingBackprop wraps ngraph::op::BatchNormTrainingBackprop";
64
    batch_norm_training_backprop.def(py::init<const std::shared_ptr<ngraph::Node>&,
65 66 67 68 69
                                              const std::shared_ptr<ngraph::Node>&,
                                              const std::shared_ptr<ngraph::Node>&,
                                              const std::shared_ptr<ngraph::Node>&,
                                              const std::shared_ptr<ngraph::Node>&,
                                              const std::shared_ptr<ngraph::Node>&,
70
                                              double>());
71
}