Commit 045be61c authored by Scott Cyphers's avatar Scott Cyphers

Convert remaininf experimental op dir to output args

parent 0bcf5142
...@@ -26,8 +26,8 @@ using namespace ngraph; ...@@ -26,8 +26,8 @@ using namespace ngraph;
const string op::BatchMatMul::type_name{"BatchMatMul"}; const string op::BatchMatMul::type_name{"BatchMatMul"};
op::BatchMatMul::BatchMatMul(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) op::BatchMatMul::BatchMatMul(const Output<Node>& arg0, const Output<Node>& arg1)
: Op(check_single_output_args({arg0, arg1})) : Op({arg0, arg1})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -35,11 +35,12 @@ namespace ngraph ...@@ -35,11 +35,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
BatchMatMul() = default;
/// \brief Constructs a batch of matmul product operation. /// \brief Constructs a batch of matmul product operation.
/// ///
/// \param arg0 The node producing the first argument. /// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument. /// \param arg1 The node producing the second argument.
BatchMatMul(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); BatchMatMul(const Output<Node>& arg0, const Output<Node>& arg1);
virtual void validate_and_infer_types() override; virtual void validate_and_infer_types() override;
......
...@@ -63,6 +63,13 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector ...@@ -63,6 +63,13 @@ shared_ptr<Node> ngraph::op::CompiledKernel::copy_with_new_args(const NodeVector
return std::make_shared<CompiledKernel>(new_node_list, new_outputs, new_args); return std::make_shared<CompiledKernel>(new_node_list, new_outputs, new_args);
} }
ngraph::op::CompiledKernel::CompiledKernel(const OutputVector& node_list,
const OutputVector& outputs,
const OutputVector& args)
: CompiledKernel(as_node_vector(node_list), as_node_vector(outputs), as_node_vector(args))
{
}
ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list, ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs, const NodeVector& outputs,
const NodeVector& args) const NodeVector& args)
......
...@@ -35,9 +35,13 @@ namespace ngraph ...@@ -35,9 +35,13 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
CompiledKernel() = default;
CompiledKernel(const NodeVector& node_list, CompiledKernel(const NodeVector& node_list,
const NodeVector& outputs, const NodeVector& outputs,
const NodeVector& args); const NodeVector& args);
CompiledKernel(const OutputVector& node_list,
const OutputVector& outputs,
const OutputVector& args);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -23,10 +23,10 @@ using namespace ngraph; ...@@ -23,10 +23,10 @@ using namespace ngraph;
const string op::DynBroadcast::type_name{"DynBroadcast"}; const string op::DynBroadcast::type_name{"DynBroadcast"};
op::DynBroadcast::DynBroadcast(const shared_ptr<Node>& arg, op::DynBroadcast::DynBroadcast(const Output<Node>& arg,
const shared_ptr<Node>& shape, const Output<Node>& shape,
const shared_ptr<Node>& broadcast_axes) const Output<Node>& broadcast_axes)
: Op(check_single_output_args({arg, shape, broadcast_axes})) : Op({arg, shape, broadcast_axes})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -31,15 +31,16 @@ namespace ngraph ...@@ -31,15 +31,16 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
DynBroadcast() = default;
/// \brief Constructs a dynamic broadcast operation. /// \brief Constructs a dynamic broadcast operation.
/// ///
/// \param arg Node that produces the input tensor to be broadcast. /// \param arg Node that produces the input tensor to be broadcast.
/// \param shape Node that produces shape of the output tensor. /// \param shape Node that produces shape of the output tensor.
/// \param broadcast_axes Node that produces the axis positions (0-based) in the result that are being broadcast. The /// \param broadcast_axes Node that produces the axis positions (0-based) in the result that are being broadcast. The
/// remaining axes in shape must be the same as the shape of arg. /// remaining axes in shape must be the same as the shape of arg.
DynBroadcast(const std::shared_ptr<Node>& arg, DynBroadcast(const Output<Node>& arg,
const std::shared_ptr<Node>& shape, const Output<Node>& shape,
const std::shared_ptr<Node>& broadcast_axes); const Output<Node>& broadcast_axes);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -21,12 +21,12 @@ using namespace ngraph; ...@@ -21,12 +21,12 @@ using namespace ngraph;
const string op::DynPad::type_name{"DynPad"}; const string op::DynPad::type_name{"DynPad"};
op::DynPad::DynPad(const std::shared_ptr<Node>& arg, op::DynPad::DynPad(const Output<Node>& arg,
const std::shared_ptr<Node>& padding_below, const Output<Node>& padding_below,
const std::shared_ptr<Node>& padding_above, const Output<Node>& padding_above,
const std::shared_ptr<Node>& padding_value, const Output<Node>& padding_value,
op::PadMode pad_mode) op::PadMode pad_mode)
: Op(check_single_output_args({arg, padding_below, padding_above, padding_value})) : Op({arg, padding_below, padding_above, padding_value})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
DynPad() = default;
/// \brief Perform dynamic padding of a tensor /// \brief Perform dynamic padding of a tensor
/// ///
/// \param arg The node producing input tensor to be padded. /// \param arg The node producing input tensor to be padded.
...@@ -37,10 +38,10 @@ namespace ngraph ...@@ -37,10 +38,10 @@ namespace ngraph
/// \param padding_above The node producing the padding-above widths. /// \param padding_above The node producing the padding-above widths.
/// \param padding_value The value to be used for padding. Must be scalar. /// \param padding_value The value to be used for padding. Must be scalar.
/// \param pad_mode The padding mode: CONSTANT(default), EDGE or REFLECT. /// \param pad_mode The padding mode: CONSTANT(default), EDGE or REFLECT.
DynPad(const std::shared_ptr<Node>& arg, DynPad(const Output<Node>& arg,
const std::shared_ptr<Node>& padding_below, const Output<Node>& padding_below,
const std::shared_ptr<Node>& padding_above, const Output<Node>& padding_above,
const std::shared_ptr<Node>& padding_value, const Output<Node>& padding_value,
PadMode pad_mode = PadMode::CONSTANT); PadMode pad_mode = PadMode::CONSTANT);
PadMode get_pad_mode() const { return m_pad_mode; } PadMode get_pad_mode() const { return m_pad_mode; }
......
...@@ -26,17 +26,17 @@ using namespace ngraph; ...@@ -26,17 +26,17 @@ using namespace ngraph;
const string op::DynReplaceSlice::type_name{"DynReplaceSlice"}; const string op::DynReplaceSlice::type_name{"DynReplaceSlice"};
op::DynReplaceSlice::DynReplaceSlice(const shared_ptr<Node>& arg, op::DynReplaceSlice::DynReplaceSlice(const Output<Node>& arg,
const shared_ptr<Node>& replacement, const Output<Node>& replacement,
const shared_ptr<Node>& lower_bounds, const Output<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds, const Output<Node>& upper_bounds,
const shared_ptr<Node>& strides, const Output<Node>& strides,
const AxisSet& lower_bounds_mask, const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask, const AxisSet& upper_bounds_mask,
const AxisSet& new_axis, const AxisSet& new_axis,
const AxisSet& shrink_axis, const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask) const AxisSet& ellipsis_mask)
: Op(check_single_output_args({arg, replacement, lower_bounds, upper_bounds, strides})) : Op({arg, replacement, lower_bounds, upper_bounds, strides})
, m_lower_bounds_mask(lower_bounds_mask) , m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask) , m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis) , m_new_axis(new_axis)
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
DynReplaceSlice() = default;
/// \brief Constructs a dynamic tensor replace-slice operation. /// \brief Constructs a dynamic tensor replace-slice operation.
/// ///
/// \param arg The tensor in which to replace the slice. /// \param arg The tensor in which to replace the slice.
...@@ -43,11 +44,11 @@ namespace ngraph ...@@ -43,11 +44,11 @@ namespace ngraph
/// \param new_axis Add dimension one axis at the set positions /// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions /// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position /// \param ellipsis_mask Inserts missing dimensions on the set position
DynReplaceSlice(const std::shared_ptr<Node>& arg, DynReplaceSlice(const Output<Node>& arg,
const std::shared_ptr<Node>& replacement, const Output<Node>& replacement,
const std::shared_ptr<Node>& lower_bounds, const Output<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds, const Output<Node>& upper_bounds,
const std::shared_ptr<Node>& strides, const Output<Node>& strides,
const AxisSet& lower_bounds_mask = AxisSet{}, const AxisSet& lower_bounds_mask = AxisSet{},
const AxisSet& upper_bounds_mask = AxisSet{}, const AxisSet& upper_bounds_mask = AxisSet{},
const AxisSet& new_axis = AxisSet{}, const AxisSet& new_axis = AxisSet{},
......
...@@ -26,10 +26,8 @@ using namespace ngraph; ...@@ -26,10 +26,8 @@ using namespace ngraph;
const string op::DynReshape::type_name{"DynReshape"}; const string op::DynReshape::type_name{"DynReshape"};
op::DynReshape::DynReshape(const shared_ptr<Node>& arg, op::DynReshape::DynReshape(const Output<Node>& arg, const Output<Node>& pattern, bool zero_flag)
const shared_ptr<Node>& pattern, : Op({arg, pattern})
bool zero_flag)
: Op(check_single_output_args({arg, pattern}))
, m_zero_flag(zero_flag) , m_zero_flag(zero_flag)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -34,6 +34,7 @@ namespace ngraph ...@@ -34,6 +34,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
DynReshape() = default;
/// \brief Constructs a dynamic reshape operation. This operation does not perform transpose. /// \brief Constructs a dynamic reshape operation. This operation does not perform transpose.
/// ///
/// \param arg The tensor to be reshaped. /// \param arg The tensor to be reshaped.
...@@ -44,8 +45,8 @@ namespace ngraph ...@@ -44,8 +45,8 @@ namespace ngraph
/// size is inferred based on element count of input tensor. /// size is inferred based on element count of input tensor.
/// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy from input /// \param zero_flag Treats zeros in `pattern` as wildcard flags indicating a copy from input
/// shape at the same index. /// shape at the same index.
DynReshape(const std::shared_ptr<Node>& arg, DynReshape(const Output<Node>& arg,
const std::shared_ptr<Node>& pattern, const Output<Node>& pattern,
bool zero_flag = false); bool zero_flag = false);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -26,16 +26,16 @@ using namespace ngraph; ...@@ -26,16 +26,16 @@ using namespace ngraph;
const string op::DynSlice::type_name{"DynSlice"}; const string op::DynSlice::type_name{"DynSlice"};
op::DynSlice::DynSlice(const shared_ptr<Node>& arg, op::DynSlice::DynSlice(const Output<Node>& arg,
const shared_ptr<Node>& lower_bounds, const Output<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds, const Output<Node>& upper_bounds,
const shared_ptr<Node>& strides, const Output<Node>& strides,
const AxisSet& lower_bounds_mask, const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask, const AxisSet& upper_bounds_mask,
const AxisSet& new_axis, const AxisSet& new_axis,
const AxisSet& shrink_axis, const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask) const AxisSet& ellipsis_mask)
: Op(check_single_output_args({arg, lower_bounds, upper_bounds, strides})) : Op({arg, lower_bounds, upper_bounds, strides})
, m_lower_bounds_mask(lower_bounds_mask) , m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask) , m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis) , m_new_axis(new_axis)
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
DynSlice() = default;
/// \brief Constructs a dynamic tensor slice operation. /// \brief Constructs a dynamic tensor slice operation.
/// ///
/// \param arg The tensor to be sliced. /// \param arg The tensor to be sliced.
...@@ -42,10 +43,10 @@ namespace ngraph ...@@ -42,10 +43,10 @@ namespace ngraph
/// \param new_axis Add dimension one axis at the set positions /// \param new_axis Add dimension one axis at the set positions
/// \param shrink_axis Delete dimensions at the set positions /// \param shrink_axis Delete dimensions at the set positions
/// \param ellipsis_mask Inserts missing dimensions on the set position /// \param ellipsis_mask Inserts missing dimensions on the set position
DynSlice(const std::shared_ptr<Node>& arg, DynSlice(const Output<Node>& arg,
const std::shared_ptr<Node>& lower_bounds, const Output<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds, const Output<Node>& upper_bounds,
const std::shared_ptr<Node>& strides, const Output<Node>& strides,
const AxisSet& lower_bounds_mask = AxisSet{}, const AxisSet& lower_bounds_mask = AxisSet{},
const AxisSet& upper_bounds_mask = AxisSet{}, const AxisSet& upper_bounds_mask = AxisSet{},
const AxisSet& new_axis = AxisSet{}, const AxisSet& new_axis = AxisSet{},
......
...@@ -21,11 +21,6 @@ using namespace ngraph; ...@@ -21,11 +21,6 @@ using namespace ngraph;
const string op::GenerateMask::type_name{"GenerateMask"}; const string op::GenerateMask::type_name{"GenerateMask"};
op::GenerateMask::GenerateMask()
: Op()
{
}
#if 0 #if 0
// Not supported until all transformers use nodes instead of attributes // Not supported until all transformers use nodes instead of attributes
op::GenerateMask::GenerateMask(const Output<Node>& training, op::GenerateMask::GenerateMask(const Output<Node>& training,
......
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs a GenerateMask node with a given shape, seed, /// \brief Constructs a GenerateMask node with a given shape, seed,
/// probability and training/inference mode /// probability and training/inference mode
GenerateMask(); GenerateMask() = default;
#if 0 #if 0
/// Switch to dynamic arguments when all transformers have switched to using the node values /// Switch to dynamic arguments when all transformers have switched to using the node values
......
...@@ -25,6 +25,13 @@ using namespace ngraph; ...@@ -25,6 +25,13 @@ using namespace ngraph;
const string op::QuantizedConcat::type_name{"QuantizedConcat"}; const string op::QuantizedConcat::type_name{"QuantizedConcat"};
op::QuantizedConcat::QuantizedConcat(const OutputVector& args, size_t concatenation_axis)
: Op(args)
, m_concatenation_axis(concatenation_axis)
{
constructor_validate_and_infer_types();
}
op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis) op::QuantizedConcat::QuantizedConcat(const NodeVector& args, size_t concatenation_axis)
: Op(check_single_output_args(args)) : Op(check_single_output_args(args))
, m_concatenation_axis(concatenation_axis) , m_concatenation_axis(concatenation_axis)
......
...@@ -31,12 +31,19 @@ namespace ngraph ...@@ -31,12 +31,19 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
QuantizedConcat() = default;
/// \brief Constructs a concatenation operation. /// \brief Constructs a concatenation operation.
/// ///
/// \param args The nodes producing the input tensors. /// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors. /// \param concatenation_axis The axis along which to concatenate the input tensors.
QuantizedConcat(const NodeVector& args, size_t concatenation_axis); QuantizedConcat(const NodeVector& args, size_t concatenation_axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
QuantizedConcat(const OutputVector& args, size_t concatenation_axis);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
QuantizedConvolutionBias() = default;
QuantizedConvolutionBias(const Output<Node>& data_batch, QuantizedConvolutionBias(const Output<Node>& data_batch,
const Output<Node>& filters, const Output<Node>& filters,
const Output<Node>& bias, const Output<Node>& bias,
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
QuantizedConvolutionRelu() = default;
QuantizedConvolutionRelu(const Output<Node>& data_batch, QuantizedConvolutionRelu(const Output<Node>& data_batch,
const Output<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
QuantizedDot() = default;
QuantizedDot(const Output<Node>& data, QuantizedDot(const Output<Node>& data,
const Output<Node>& weights, const Output<Node>& weights,
const Output<Node>& scale, const Output<Node>& scale,
......
...@@ -30,6 +30,7 @@ namespace ngraph ...@@ -30,6 +30,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
QuantizedDotBias() = default;
QuantizedDotBias(const Output<Node>& data, QuantizedDotBias(const Output<Node>& data,
const Output<Node>& weights, const Output<Node>& weights,
const Output<Node>& bias, const Output<Node>& bias,
......
...@@ -29,6 +29,7 @@ namespace ngraph ...@@ -29,6 +29,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
QuantizedMaxPool() = default;
/// \brief Constructs a batched max pooling operation. /// \brief Constructs a batched max pooling operation.
/// ///
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
......
...@@ -24,10 +24,6 @@ using namespace ngraph; ...@@ -24,10 +24,6 @@ using namespace ngraph;
const string op::Range::type_name = "Range"; const string op::Range::type_name = "Range";
op::Range::Range()
{
}
op::Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step) op::Range::Range(const Output<Node>& start, const Output<Node>& stop, const Output<Node>& step)
: Op({start, stop, step}) : Op({start, stop, step})
{ {
......
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs an unitialized range operation. /// \brief Constructs an unitialized range operation.
Range(); Range() = default;
/// \brief Constructs a range operation. /// \brief Constructs a range operation.
/// ///
......
...@@ -22,8 +22,8 @@ using namespace ngraph; ...@@ -22,8 +22,8 @@ using namespace ngraph;
const string op::ShapeOf::type_name{"ShapeOf"}; const string op::ShapeOf::type_name{"ShapeOf"};
op::ShapeOf::ShapeOf(const shared_ptr<Node>& arg) op::ShapeOf::ShapeOf(const Output<Node>& arg)
: Op(check_single_output_args({arg})) : Op({arg})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,8 +29,9 @@ namespace ngraph ...@@ -29,8 +29,9 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
ShapeOf() = default;
/// \brief Constructs a shape-of operation. /// \brief Constructs a shape-of operation.
ShapeOf(const std::shared_ptr<Node>& arg); ShapeOf(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -23,8 +23,8 @@ using namespace ngraph; ...@@ -23,8 +23,8 @@ using namespace ngraph;
const string op::Tile::type_name{"Tile"}; const string op::Tile::type_name{"Tile"};
op::Tile::Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats) op::Tile::Tile(const Output<Node>& arg, const Output<Node>& repeats)
: Op(check_single_output_args({arg, repeats})) : Op({arg, repeats})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -30,11 +30,12 @@ namespace ngraph ...@@ -30,11 +30,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Tile() = default;
/// \brief Perform dynamic padding of a tensor /// \brief Perform dynamic padding of a tensor
/// ///
/// \param arg The node producing input tensor to be padded. /// \param arg The node producing input tensor to be padded.
/// \param repeats The node producing the per-dimension replication factor /// \param repeats The node producing the per-dimension replication factor
Tile(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& repeats); Tile(const Output<Node>& arg, const Output<Node>& repeats);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -24,8 +24,8 @@ using namespace ngraph; ...@@ -24,8 +24,8 @@ using namespace ngraph;
const string op::Transpose::type_name{"Transpose"}; const string op::Transpose::type_name{"Transpose"};
op::Transpose::Transpose(const shared_ptr<Node>& arg, const shared_ptr<Node>& input_order) op::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_order)
: Op(check_single_output_args({arg, input_order})) : Op({arg, input_order})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -31,6 +31,7 @@ namespace ngraph ...@@ -31,6 +31,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Transpose() = default;
/// \brief Constructs a transpose operation. /// \brief Constructs a transpose operation.
/// ///
/// \param arg Node producing the tensor to be transposed. /// \param arg Node producing the tensor to be transposed.
...@@ -38,7 +39,7 @@ namespace ngraph ...@@ -38,7 +39,7 @@ namespace ngraph
/// input shape. Must be a vector of element type element::i64, /// input shape. Must be a vector of element type element::i64,
/// with shape [n], where n is the rank of arg. The tensor's /// with shape [n], where n is the rank of arg. The tensor's
/// value must contain every integer in the range [0,n-1]. /// value must contain every integer in the range [0,n-1].
Transpose(const std::shared_ptr<Node>& arg, const std::shared_ptr<Node>& input_order); Transpose(const Output<Node>& arg, const Output<Node>& input_order);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment