Commit 13770af2 authored by tsocha's avatar tsocha Committed by Nick Korovaiko

[Py] Enable Pad op, hotfix for tox (#804)

- Enable Padding op
- Supress multiline comment warning
- improve tox configuration
parent e189f9c6
...@@ -50,6 +50,7 @@ from ngraph.ops import minimum ...@@ -50,6 +50,7 @@ from ngraph.ops import minimum
from ngraph.ops import multiply from ngraph.ops import multiply
from ngraph.ops import negative from ngraph.ops import negative
from ngraph.ops import not_equal from ngraph.ops import not_equal
from ngraph.ops import pad
from ngraph.ops import parameter from ngraph.ops import parameter
from ngraph.ops import prod from ngraph.ops import prod
from ngraph.ops import reshape from ngraph.ops import reshape
......
...@@ -71,6 +71,7 @@ from _pyngraph.op import NotEqual ...@@ -71,6 +71,7 @@ from _pyngraph.op import NotEqual
from _pyngraph.op import Not from _pyngraph.op import Not
from _pyngraph.op import OneHot from _pyngraph.op import OneHot
from _pyngraph.op import Op from _pyngraph.op import Op
from _pyngraph.op import Pad
from _pyngraph.op import Parameter from _pyngraph.op import Parameter
from _pyngraph.op import ParameterVector from _pyngraph.op import ParameterVector
from _pyngraph.op import Power from _pyngraph.op import Power
......
...@@ -21,9 +21,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N ...@@ -21,9 +21,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Node, N
Shape, Strides Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \ from ngraph.impl.op import Abs, Acos, Add, Asin, Atan, AvgPool, Broadcast, Ceiling, Concat, \
Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq, \ Constant, Convert, Convolution, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, Greater, GreaterEq,\
Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \ Less, LessEq, Log, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, \
Parameter, Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh Pad, Parameter, Product, Reshape, Slice, Softmax, Sqrt, Subtract, Sum, Tanh
from typing import Iterable, List from typing import Iterable, List
...@@ -565,3 +565,33 @@ def softmax(node, axes): # type: (Node, Iterable[int]) -> Node ...@@ -565,3 +565,33 @@ def softmax(node, axes): # type: (Node, Iterable[int]) -> Node
if type(axes) is not set: if type(axes) is not set:
axes = set(axes) axes = set(axes)
return Softmax(node, AxisSet(axes)) return Softmax(node, AxisSet(axes))
@nameable_op
def pad(data_batch, # type: Node
value, # type: Node
padding_below=None, # type: TensorShape
padding_above=None, # type: TensorShape
padding_in=None, # type: TensorShape
name=None, # type: str
):
# type: (...) -> Node
"""Return padding node.
:param data_batch: The input node providing data.
:param value: The node producing the scalar value to be inserted for padding.
:param padding_below: The padding-below widths.
:param padding_above: The padding-above widths.
:param padding_in: The interior-padding widths.
:param name: The optional new name for output node.
:return: Return node that represents a padding of input nodes data.
"""
dim_count = len(data_batch.shape)
if padding_above is None:
padding_above = [0] * dim_count
if padding_below is None:
padding_below = [0] * dim_count
if padding_in is None:
padding_in = [0] * dim_count
return Pad(data_batch, value, Shape(padding_below), Shape(padding_above), Shape(padding_in))
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <pybind11/pybind11.h>
#include "ngraph/op/pad.hpp"
#include "ngraph/shape.hpp"
#include "pyngraph/ops/pad.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Pad(py::module m)
{
py::class_<ngraph::op::Pad,
std::shared_ptr<ngraph::op::Pad>,
ngraph::op::util::RequiresTensorViewArgs>
pad(m, "Pad");
pad.doc() = "ngraph.impl.op.Pad wraps ngraph::op::Pad";
pad.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const ngraph::Shape&,
const ngraph::Shape&,
const ngraph::Shape&>());
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Pad(py::module m);
...@@ -56,6 +56,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -56,6 +56,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Negative(m_op); regclass_pyngraph_op_Negative(m_op);
regclass_pyngraph_op_Not(m_op); regclass_pyngraph_op_Not(m_op);
regclass_pyngraph_op_NotEqual(m_op); regclass_pyngraph_op_NotEqual(m_op);
regclass_pyngraph_op_Pad(m_op);
regclass_pyngraph_op_ParameterVector(m_op); regclass_pyngraph_op_ParameterVector(m_op);
regclass_pyngraph_op_Parameter(m_op); regclass_pyngraph_op_Parameter(m_op);
regclass_pyngraph_op_Power(m_op); regclass_pyngraph_op_Power(m_op);
......
...@@ -56,6 +56,7 @@ ...@@ -56,6 +56,7 @@
#include "pyngraph/ops/max.hpp" #include "pyngraph/ops/max.hpp"
#include "pyngraph/ops/min.hpp" #include "pyngraph/ops/min.hpp"
#include "pyngraph/ops/one_hot.hpp" #include "pyngraph/ops/one_hot.hpp"
#include "pyngraph/ops/pad.hpp"
#include "pyngraph/ops/parameter.hpp" #include "pyngraph/ops/parameter.hpp"
#include "pyngraph/ops/parameter_vector.hpp" #include "pyngraph/ops/parameter_vector.hpp"
#include "pyngraph/ops/power.hpp" #include "pyngraph/ops/power.hpp"
......
...@@ -177,6 +177,7 @@ sources = ['pyngraph/function.cpp', ...@@ -177,6 +177,7 @@ sources = ['pyngraph/function.cpp',
'pyngraph/ops/not_equal.cpp', 'pyngraph/ops/not_equal.cpp',
'pyngraph/ops/op.cpp', 'pyngraph/ops/op.cpp',
'pyngraph/ops/one_hot.cpp', 'pyngraph/ops/one_hot.cpp',
'pyngraph/ops/pad.cpp',
'pyngraph/ops/parameter.cpp', 'pyngraph/ops/parameter.cpp',
'pyngraph/ops/parameter_vector.cpp', 'pyngraph/ops/parameter_vector.cpp',
'pyngraph/ops/power.cpp', 'pyngraph/ops/power.cpp',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment