Unverified Commit 1e100a5a authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/25tomaster

parents 6403865a 3fec597e
......@@ -37,6 +37,7 @@ ngraph.ops
elu
equal
exp
fake_quantize
floor
gelu
gemm
......
......@@ -50,6 +50,7 @@ from ngraph.ops import dot
from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import fake_quantize
from ngraph.ops import floor
from ngraph.ops import gelu
from ngraph.ops import gemm
......
......@@ -74,6 +74,7 @@ from _pyngraph.op import Dot
from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import FakeQuantize
from _pyngraph.op import Floor
from _pyngraph.op import Gelu
from _pyngraph.op import Gemm
......
......@@ -22,9 +22,9 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, Equal, Exp, \
Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, LRN, Max, \
Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Convolution, ConvolutionBackpropData, Cos, Cosh, DepthToSpace, Divide, Dot, Elu, FakeQuantize, \
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \
LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK
......@@ -537,6 +537,38 @@ def broadcast_to(node, new_shape, axis=None, name=None):
@nameable_op
def fake_quantize(data, input_low, input_high, output_low, output_high, levels, name=None):
# type: (Node, Node, Node, Node, Node, int, str) -> Node
r"""Perform an element-wise linear quantization on input data.
Input floating point values are quantized into a discrete set of floating point values.
.. code-block:: python
if x <= input_low:
output = output_low
if x > input_high:
output = output_high
else:
output = fake_quantize(output)
Fake quantize uses the following logic:
.. math:: output =
\dfrac{round( \dfrac{data - input\_low}{(input\_high - input\_low)\cdot (levels-1)})}
{(levels-1)\cdot (output\_high - output\_low)} + output\_low
:param data: The node with data tensor.
:param input_low: The node with the minimum for input values.
:param input_high: The node with the maximum for input values.
:param output_low: The node with the minimum quantized value.
:param output_high: The node with the maximum quantized value.
:param levels: The number of quantization levels. Integer value.
:return: New node with quantized value.
"""
return FakeQuantize(data, input_low, input_high, output_low, output_high, levels)
def gemm(A, # type: Node
B, # type: Node
C, # type: Node
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/fake_quantize.hpp"
#include "pyngraph/ops/fused/fake_quantize.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_FakeQuantize(py::module m)
{
py::class_<ngraph::op::FakeQuantize, std::shared_ptr<ngraph::op::FakeQuantize>, ngraph::op::Op>
fakequantize(m, "FakeQuantize");
fakequantize.doc() = "ngraph.impl.op.FakeQuantize wraps ngraph::op::FakeQuantize";
fakequantize.def(py::init<const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
const std::shared_ptr<ngraph::Node>&,
int&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_FakeQuantize(py::module m);
......@@ -54,6 +54,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Elu(m_op);
regclass_pyngraph_op_Equal(m_op);
regclass_pyngraph_op_Exp(m_op);
regclass_pyngraph_op_FakeQuantize(m_op);
regclass_pyngraph_op_Floor(m_op);
regclass_pyngraph_op_Gelu(m_op);
regclass_pyngraph_op_Gemm(m_op);
......
......@@ -45,6 +45,7 @@
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/depth_to_space.hpp"
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/fake_quantize.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
......
......@@ -184,6 +184,7 @@ sources = [
'pyngraph/ops/fused/elu.cpp',
'pyngraph/ops/equal.cpp',
'pyngraph/ops/exp.cpp',
'pyngraph/ops/fused/fake_quantize.cpp',
'pyngraph/ops/floor.cpp',
'pyngraph/ops/fused/gelu.cpp',
'pyngraph/ops/fused/gemm.cpp',
......
......@@ -69,6 +69,52 @@ def test_elu_operator_with_scalar():
assert np.allclose(result, expected)
def test_fake_quantize():
runtime = get_runtime()
data_value = np.arange(24.0, dtype=np.float32).reshape(1, 2, 3, 4)
input_low_value = np.float32(0)
input_high_value = np.float32(23)
output_low_value = np.float32(2)
output_high_value = np.float32(16)
levels = np.float32(4)
data_shape = [1, 2, 3, 4]
bound_shape = []
parameter_data = ng.parameter(data_shape, name='data', dtype=np.float32)
parameter_input_low = ng.parameter(bound_shape, name='input_low', dtype=np.float32)
parameter_input_high = ng.parameter(bound_shape, name='input_high', dtype=np.float32)
parameter_output_low = ng.parameter(bound_shape, name='output_low', dtype=np.float32)
parameter_output_high = ng.parameter(bound_shape, name='output_high', dtype=np.float32)
model = ng.fake_quantize(parameter_data,
parameter_input_low,
parameter_input_high,
parameter_output_low,
parameter_output_high,
levels)
computation = runtime.computation(model,
parameter_data,
parameter_input_low,
parameter_input_high,
parameter_output_low,
parameter_output_high)
result = computation(data_value,
input_low_value,
input_high_value,
output_low_value,
output_high_value)
expected = np.array([[[[[2., 2., 2., 2.],
[6.6666669, 6.6666669, 6.6666669, 6.6666669],
[6.6666669, 6.6666669, 6.6666669, 6.6666669]],
[[11.33333301, 11.33333301, 11.33333301, 11.33333301],
[11.33333301, 11.33333301, 11.33333301, 11.33333301],
[16., 16., 16., 16.]]]]], dtype=np.float32)
assert np.allclose(result, expected)
def test_depth_to_space():
runtime = get_runtime()
......@@ -219,4 +265,5 @@ def test_grn_operator():
[[0.9970545, 0.98994946, 0.9805807, 0.97014254],
[0.9593655, 0.9486833, 0.9383431, 0.9284767],
[0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32)
assert np.allclose(result, expected)
......@@ -26,8 +26,8 @@ using namespace ngraph;
const string op::BatchMatMul::type_name{"BatchMatMul"};
op::BatchMatMul::BatchMatMul(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
: Op(check_single_output_args({arg0, arg1}))
op::BatchMatMul::BatchMatMul(const Output<Node>& arg0, const Output<Node>& arg1)
: Op({arg0, arg1})
{
constructor_validate_and_infer_types();
}
......
......@@ -35,11 +35,12 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchMatMul() = default;
/// \brief Constructs a batch of matmul product operation.
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
BatchMatMul(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
BatchMatMul(const Output<Node>& arg0, const Output<Node>& arg1);
virtual void validate_and_infer_types() override;
......
......@@ -63,6 +63,13 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
return std::make_shared<CompiledKernel>(new_node_list, new_outputs, new_args);
}
ngraph::op::CompiledKernel::CompiledKernel(const OutputVector& node_list,
const OutputVector& outputs,
const OutputVector& args)
: CompiledKernel(as_node_vector(node_list), as_node_vector(outputs), as_node_vector(args))
{
}
ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs,
const NodeVector& args)
......
......@@ -35,9 +35,13 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
CompiledKernel() = default;
CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs,
const NodeVector& args);
CompiledKernel(const OutputVector& node_list,
const OutputVector& outputs,
const OutputVector& args);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -23,10 +23,10 @@ using namespace ngraph;
const string op::DynBroadcast::type_name{"DynBroadcast"};
op::DynBroadcast::DynBroadcast(const shared_ptr<Node>& arg,
const shared_ptr<Node>& shape,
const shared_ptr<Node>& broadcast_axes)
: Op(check_single_output_args({arg, shape, broadcast_axes}))
op::DynBroadcast::DynBroadcast(const Output<Node>& arg,
const Output<Node>& shape,
const Output<Node>& broadcast_axes)
: Op({arg, shape, broadcast_axes})
{
constructor_validate_and_infer_types();
}
......
......@@ -31,15 +31,16 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
DynBroadcast() = default;
/// \brief Constructs a dynamic broadcast operation.
///
/// \param arg Node that produces the input tensor to be broadcast.
/// \param shape Node that produces shape of the output tensor.
/// \param broadcast_axes Node that produces the axis positions (0-based) in the result that are being broadcast. The
/// remaining axes in shape must be the same as the shape of arg.
DynBroadcast(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& shape,
const std::shared_ptr<Node>& broadcast_axes);
DynBroadcast(const Output<Node>& arg,
const Output<Node>& shape,
const Output<Node>& broadcast_axes);
void validate_and_infer_types() override;
......
......@@ -21,12 +21,13 @@ using namespace ngraph;
const string op::DynPad::type_name{"DynPad"};
op::DynPad::DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value,
op::DynPad::DynPad(const Output<Node>& arg,
const Output<Node>& padding_below,
const Output<Node>& padding_above,
const Output<Node>& padding_value,
op::PadMode pad_mode)
: Op(check_single_output_args({arg, padding_below, padding_above, padding_value}))
: Op({arg, padding_below, padding_above, padding_value})
, m_pad_mode(pad_mode)
{
constructor_validate_and_infer_types();
}
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
DynPad() = default;
/// \brief Perform dynamic padding of a tensor
///
/// \param arg The node producing input tensor to be padded.
......@@ -37,10 +38,10 @@ namespace ngraph
/// \param padding_above The node producing the padding-above widths.
/// \param padding_value The value to be used for padding. Must be scalar.
/// \param pad_mode The padding mode: CONSTANT(default), EDGE or REFLECT.
DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value,
DynPad(const Output<Node>& arg,
const Output<Node>& padding_below,
const Output<Node>& padding_above,
const Output<Node>& padding_value,
PadMode pad_mode = PadMode::CONSTANT);
PadMode get_pad_mode() const { return m_pad_mode; }
......
......@@ -26,17 +26,17 @@ using namespace ngraph;
const string op::DynReplaceSlice::type_name{"DynReplaceSlice"};
op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& replacement,
const shared_ptr<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds,
const shared_ptr<Node>& strides,
op::DynReplaceSlice::DynReplaceSlice(const Output<Node>& arg,
const Output<Node>& replacement,
const Output<Node>& lower_bounds,
const Output<Node>& upper_bounds,
const Output<Node>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
: Op(check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides}))
: Op({arg, replacement, lower_bounds, upper_bounds, strides})
, m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis)
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
DynReplaceSlice() = default;
/// \brief Constructs a dynamic tensor replace-slice operation.
///
/// \param arg The tensor in which to replace the slice.
......@@ -43,11 +44,11 @@ namespace ngraph
/// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position
DynReplaceSlice(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& replacement,
const std::shared_ptr<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds,
const std::shared_ptr<Node>& strides,
DynReplaceSlice(const Output<Node>& arg,
const Output<Node>& replacement,
const Output<Node>& lower_bounds,
const Output<Node>& upper_bounds,
const Output<Node>& strides,
const AxisSet& lower_bounds_mask = AxisSet{},
const AxisSet& upper_bounds_mask = AxisSet{},
const AxisSet& new_axis = AxisSet{},
......
......@@ -26,10 +26,8 @@ using namespace ngraph;
const string op::DynReshape::type_name{"DynReshape"};
op::DynReshape::DynReshape(const shared_ptr<Node>& arg,
const shared_ptr<Node>& pattern,
bool zero_flag)
: Op(check_single_output_args({arg, pattern}))
op::DynReshape::DynReshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
: Op({arg, pattern})
, m_zero_flag(zero_flag)
{
constructor_validate_and_infer_types();
......
......@@ -34,6 +34,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
DynReshape() = default;
/// \brief Constructs a dynamic reshape operation. This operation does not perform transpose.
///
/// \param arg The tensor to be reshaped.
......@@ -44,8 +45,8 @@ namespace ngraph
/// size is inferred based on element count of input tensor.
/// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy from input
/// shape at the same index.
DynReshape(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& pattern,
DynReshape(const Output<Node>& arg,
const Output<Node>& pattern,
bool zero_flag = false);
void validate_and_infer_types() override;
......
......@@ -26,16 +26,16 @@ using namespace ngraph;
const string op::DynSlice::type_name{"DynSlice"};
op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds,
const shared_ptr<Node>& strides,
op::DynSlice::DynSlice(const Output<Node>& arg,
const Output<Node>& lower_bounds,
const Output<Node>& upper_bounds,
const Output<Node>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
: Op(check_single_output_args({arg, lower_bounds, upper_bounds, strides}))
: Op({arg, lower_bounds, upper_bounds, strides})
, m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis)
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
DynSlice() = default;
/// \brief Constructs a dynamic tensor slice operation.
///
/// \param arg The tensor to be sliced.
......@@ -42,10 +43,10 @@ namespace ngraph
/// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position
DynSlice(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds,
const std::shared_ptr<Node>& strides,
DynSlice(const Output<Node>& arg,
const Output<Node>& lower_bounds,
const Output<Node>& upper_bounds,
const Output<Node>& strides,
const AxisSet& lower_bounds_mask = AxisSet{},
const AxisSet& upper_bounds_mask = AxisSet{},
const AxisSet& new_axis = AxisSet{},
......
......@@ -21,11 +21,6 @@ using namespace ngraph;
const string op::GenerateMask::type_name{"GenerateMask"};
op::GenerateMask::GenerateMask()
: Op()
{
}
#if 0
// Not supported until all transformers use nodes instead of attributes
op::GenerateMask::GenerateMask(const Output<Node>& training,
......
......@@ -34,7 +34,7 @@ namespace ngraph
const std::string& description() const override { return type_name; }
/// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode
GenerateMask();
GenerateMask() = default;
#if 0
/// Switch to dynamic arguments when all transformers have switched to using the node values
......
......@@ -25,6 +25,13 @@ using namespace ngraph;
const string op::QuantizedConcat::type_name{"QuantizedConcat"};
op::QuantizedConcat::QuantizedConcat(const OutputVector& args, size_t concatenation_axis)
: Op(args)
, m_concatenation_axis(concatenation_axis)
{
constructor_validate_and_infer_types();
}
op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
: Op(check_single_output_args(args))
, m_concatenation_axis(concatenation_axis)
......
......@@ -31,12 +31,19 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConcat() = default;
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
QuantizedConcat(const NodeVector& args, size_t concatenation_axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
QuantizedConcat(const OutputVector& args, size_t concatenation_axis);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionBias() = default;
QuantizedConvolutionBias(const Output<Node>& data_batch,
const Output<Node>& filters,
const Output<Node>& bias,
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedConvolutionRelu() = default;
QuantizedConvolutionRelu(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedDot() = default;
QuantizedDot(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& scale,
......
......@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedDotBias() = default;
QuantizedDotBias(const Output<Node>& data,
const Output<Node>& weights,
const Output<Node>& bias,
......
......@@ -29,6 +29,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
QuantizedMaxPool() = default;
/// \brief Constructs a batched max pooling operation.
///
/// \param arg The node producing the input data batch tensor.
......
......@@ -24,10 +24,6 @@ using namespace ngraph;
const string op::Range::type_name = "Range";
op::Range::Range()
{
}
op::Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step)
: Op({start, stop, step})
{
......
......@@ -31,7 +31,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an unitialized range operation.
Range();
Range() = default;
/// \brief Constructs a range operation.
///
......
......@@ -22,8 +22,8 @@ using namespace ngraph;
const string op::ShapeOf::type_name{"ShapeOf"};
op::ShapeOf::ShapeOf(const shared_ptr<Node>& arg)
: Op(check_single_output_args({arg}))
op::ShapeOf::ShapeOf(const Output<Node>& arg)
: Op({arg})
{
constructor_validate_and_infer_types();
}
......
......@@ -29,8 +29,9 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
ShapeOf() = default;
/// \brief Constructs a shape-of operation.
ShapeOf(const std::shared_ptr<Node>& arg);
ShapeOf(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -23,8 +23,8 @@ using namespace ngraph;
const string op::Tile::type_name{"Tile"};
op::Tile::Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats)
: Op(check_single_output_args({arg, repeats}))
op::Tile::Tile(const Output<Node>& arg, const Output<Node>& repeats)
: Op({arg, repeats})
{
constructor_validate_and_infer_types();
}
......
......@@ -30,11 +30,12 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Tile() = default;
/// \brief Perform dynamic padding of a tensor
///
/// \param arg The node producing input tensor to be padded.
/// \param repeats The node producing the per-dimension replication factor
Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats);
Tile(const Output<Node>& arg, const Output<Node>& repeats);
void validate_and_infer_types() override;
......
......@@ -24,8 +24,8 @@ using namespace ngraph;
const string op::Transpose::type_name{"Transpose"};
op::Transpose::Transpose(const shared_ptr<Node>& arg, const shared_ptr<Node>& input_order)
: Op(check_single_output_args({arg, input_order}))
op::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_order)
: Op({arg, input_order})
{
constructor_validate_and_infer_types();
}
......
......@@ -31,6 +31,7 @@ namespace ngraph
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Transpose() = default;
/// \brief Constructs a transpose operation.
///
/// \param arg Node producing the tensor to be transposed.
......@@ -38,7 +39,7 @@ namespace ngraph
/// input shape. Must be a vector of element type element::i64,
/// with shape [n], where n is the rank of arg. The tensor's
/// value must contain every integer in the range [0,n-1].
Transpose(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& input_order);
Transpose(const Output<Node>& arg, const Output<Node>& input_order);
void validate_and_infer_types() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment