Commit 13770af2 authored by tsocha's avatar tsocha Committed by Nick Korovaiko

[Py] Enable Pad op, hotfix for tox (#804)

- Enable Padding op
- Supress multiline comment warning
- improve tox configuration
parent e189f9c6
......@@ -50,6 +50,7 @@ from ngraph.ops import minimum
from ngraph.ops import multiply
from ngraph.ops import negative
from ngraph.ops import not_equal
from ngraph.ops import pad
from ngraph.ops import parameter
from ngraph.ops import prod
from ngraph.ops import reshape
......
......@@ -71,6 +71,7 @@ from _pyngraph.op import NotEqual
from _pyngraph.op import Not
from _pyngraph.op import OneHot
from _pyngraph.op import Op
from _pyngraph.op import Pad
from _pyngraph.op import Parameter
from _pyngraph.op import ParameterVector
from _pyngraph.op import Power
......
......@@ -21,9 +21,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N
Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, \
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\
Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \
Parameter, Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh
Pad, Parameter, Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List
......@@ -565,3 +565,33 @@ def softmax(node, axes): # type: (Node, Iterable[int]) -> Node
if type(axes) is not set:
axes = set(axes)
return Softmax(node, AxisSet(axes))
@nameable_op
def pad(data_batch, # type: Node
value, # type: Node
padding_below=None, # type: TensorShape
padding_above=None, # type: TensorShape
padding_in=None, # type: TensorShape
name=None, # type: str
):
# type: (...) -> Node
"""Return padding node.
:param data_batch: The input node providing data.
:param value: The node producing the scalar value to be inserted for padding.
:param padding_below: The padding-below widths.
:param padding_above: The padding-above widths.
:param padding_in: The interior-padding widths.
:param name: The optional new name for output node.
:return: Return node that represents a padding of input nodes data.
"""
dim_count = len(data_batch.shape)
if padding_above is None:
padding_above = [0] * dim_count
if padding_below is None:
padding_below = [0] * dim_count
if padding_in is None:
padding_in = [0] * dim_count
return Pad(data_batch, value, Shape(padding_below), Shape(padding_above), Shape(padding_in))
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include "ngraph/op/pad.hpp"
#include "ngraph/shape.hpp"
#include "pyngraph/ops/pad.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Pad(py::module m)
{
py::class_<ngraph::op::Pad,
std::shared_ptr<ngraph::op::Pad>,
ngraph::op::util::RequiresTensorViewArgs>
pad(m, "Pad");
pad.doc() = "ngraph.impl.op.Pad wraps ngraph::op::Pad";
pad.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::Shape&,
const ngraph::Shape&,
const ngraph::Shape&>());
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Pad(py::module m);
......@@ -56,6 +56,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Negative(m_op);
regclass_pyngraph_op_Not(m_op);
regclass_pyngraph_op_NotEqual(m_op);
regclass_pyngraph_op_Pad(m_op);
regclass_pyngraph_op_ParameterVector(m_op);
regclass_pyngraph_op_Parameter(m_op);
regclass_pyngraph_op_Power(m_op);
......
......@@ -56,6 +56,7 @@
#include "pyngraph/ops/max.hpp"
#include "pyngraph/ops/min.hpp"
#include "pyngraph/ops/one_hot.hpp"
#include "pyngraph/ops/pad.hpp"
#include "pyngraph/ops/parameter.hpp"
#include "pyngraph/ops/parameter_vector.hpp"
#include "pyngraph/ops/power.hpp"
......
......@@ -177,6 +177,7 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/not_equal.cpp',
'pyngraph/ops/op.cpp',
'pyngraph/ops/one_hot.cpp',
'pyngraph/ops/pad.cpp',
'pyngraph/ops/parameter.cpp',
'pyngraph/ops/parameter_vector.cpp',
'pyngraph/ops/power.cpp',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment