Unverified Commit b69d3fd6 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Convert fused ops to value-based constructors (#3385)

parent 31fae943
...@@ -674,6 +674,16 @@ OutputVector ngraph::as_output_vector(const NodeVector& args) ...@@ -674,6 +674,16 @@ OutputVector ngraph::as_output_vector(const NodeVector& args)
return output_vector; return output_vector;
} }
NodeVector ngraph::as_node_vector(const OutputVector& values)
{
NodeVector node_vector;
for (auto& value : values)
{
node_vector.push_back(value.as_single_output_node());
}
return node_vector;
}
std::tuple<element::Type, PartialShape> std::tuple<element::Type, PartialShape>
Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob) Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
{ {
......
...@@ -74,6 +74,7 @@ namespace ngraph ...@@ -74,6 +74,7 @@ namespace ngraph
const NodeVector& check_single_output_args(const NodeVector& args); const NodeVector& check_single_output_args(const NodeVector& args);
OutputVector as_output_vector(const NodeVector& args); OutputVector as_output_vector(const NodeVector& args);
NodeVector as_node_vector(const OutputVector& values);
/// Alias useful for cloning /// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>; using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
......
...@@ -40,11 +40,11 @@ void op::Clamp::pre_validate_and_infer_types() ...@@ -40,11 +40,11 @@ void op::Clamp::pre_validate_and_infer_types()
NodeVector op::Clamp::decompose_op() const NodeVector op::Clamp::decompose_op() const
{ {
const auto data = get_argument(0); const auto data = input(0).get_source_output();
const auto data_shape = data->get_shape(); const auto data_shape = data.get_shape();
const auto clamp_min = builder::make_constant(data->get_element_type(), data_shape, m_min); const auto clamp_min = builder::make_constant(data.get_element_type(), data_shape, m_min);
const auto clamp_max = builder::make_constant(data->get_element_type(), data_shape, m_max); const auto clamp_max = builder::make_constant(data.get_element_type(), data_shape, m_max);
return {std::make_shared<ngraph::op::Minimum>( return {std::make_shared<ngraph::op::Minimum>(
clamp_max, std::make_shared<ngraph::op::Maximum>(clamp_min, data))}; clamp_max, std::make_shared<ngraph::op::Maximum>(clamp_min, data))};
......
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param min - the lower bound of the <min;max> range /// \param min - the lower bound of the <min;max> range
/// \param max - the upper bound of the <min;max> range /// \param max - the upper bound of the <min;max> range
Clamp(const Output<ngraph::Node>& data, const double min, const double max); Clamp(const Output<Node>& data, const double min, const double max);
void pre_validate_and_infer_types() override; void pre_validate_and_infer_types() override;
......
This diff is collapsed.
...@@ -31,13 +31,14 @@ namespace ngraph ...@@ -31,13 +31,14 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
ConvolutionBias() = default;
ConvolutionBias(const std::shared_ptr<op::Convolution>& conv, ConvolutionBias(const std::shared_ptr<op::Convolution>& conv,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const bool with_relu = false); const bool with_relu = false);
ConvolutionBias(const std::shared_ptr<Node>& data_batch, ConvolutionBias(const Output<Node>& data_batch,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
...@@ -45,18 +46,18 @@ namespace ngraph ...@@ -45,18 +46,18 @@ namespace ngraph
const Strides& data_dilation_strides, const Strides& data_dilation_strides,
const bool with_relu = false); const bool with_relu = false);
ConvolutionBias(const std::shared_ptr<Node>& data_batch, ConvolutionBias(const Output<Node>& data_batch,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& bias); const Output<Node>& bias);
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; } const Strides& get_window_dilation_strides() const { return m_window_dilation_strides; }
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_bias() { return get_argument(2); } Output<Node> get_bias() { return input(2).get_source_output(); }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -85,10 +86,11 @@ namespace ngraph ...@@ -85,10 +86,11 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
ConvolutionBiasBackpropFiltersBias(const std::shared_ptr<Node>& data_batch, ConvolutionBiasBackpropFiltersBias() = default;
ConvolutionBiasBackpropFiltersBias(const Output<Node>& data_batch,
const Shape& filters_shape, const Shape& filters_shape,
const Shape& bias_shape, const Shape& bias_shape,
const std::shared_ptr<Node>& output_delta, const Output<Node>& output_delta,
const Strides& window_movement_strides_forward, const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward, const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward, const CoordinateDiff& padding_below_forward,
...@@ -178,14 +180,15 @@ namespace ngraph ...@@ -178,14 +180,15 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
ConvolutionBiasAdd() = default;
ConvolutionBiasAdd(const std::shared_ptr<op::ConvolutionBias>& conv, ConvolutionBiasAdd(const std::shared_ptr<op::ConvolutionBias>& conv,
const std::shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
bool with_relu = false); bool with_relu = false);
ConvolutionBiasAdd(const std::shared_ptr<Node>& data_batch, ConvolutionBiasAdd(const Output<Node>& data_batch,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::shared_ptr<Node>& bias, const Output<Node>& bias,
const std::shared_ptr<Node>& sum_input, const Output<Node>& sum_input,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Strides& window_dilation_strides, const Strides& window_dilation_strides,
const CoordinateDiff& padding_below, const CoordinateDiff& padding_below,
...@@ -198,8 +201,8 @@ namespace ngraph ...@@ -198,8 +201,8 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
bool with_relu() const { return m_with_relu; } bool with_relu() const { return m_with_relu; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -27,8 +27,8 @@ using namespace ngraph; ...@@ -27,8 +27,8 @@ using namespace ngraph;
const string op::DepthToSpace::type_name{"DepthToSpace"}; const string op::DepthToSpace::type_name{"DepthToSpace"};
op::DepthToSpace::DepthToSpace(const shared_ptr<Node>& data, const size_t block_size) op::DepthToSpace::DepthToSpace(const Output<Node>& data, const size_t block_size)
: FusedOp(check_single_output_args({data})) : FusedOp({data})
, m_blocksize(block_size) , m_blocksize(block_size)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -36,8 +36,8 @@ op::DepthToSpace::DepthToSpace(const shared_ptr<Node>& data, const size_t block_ ...@@ -36,8 +36,8 @@ op::DepthToSpace::DepthToSpace(const shared_ptr<Node>& data, const size_t block_
NodeVector op::DepthToSpace::decompose_op() const NodeVector op::DepthToSpace::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
const Shape& data_shape = data->get_shape(); const Shape& data_shape = data.get_shape();
// Set default values to each dimension to be able to work with both 3D or 4D data. // Set default values to each dimension to be able to work with both 3D or 4D data.
size_t n{1}, c{1}, h{1}, w{1}; size_t n{1}, c{1}, h{1}, w{1};
......
...@@ -37,11 +37,12 @@ namespace ngraph ...@@ -37,11 +37,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
DepthToSpace() = default;
/// \brief Constructs a DepthToSpace operation. /// \brief Constructs a DepthToSpace operation.
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param block_size - the size of the block of values to be moved /// \param block_size - the size of the block of values to be moved
DepthToSpace(const std::shared_ptr<ngraph::Node>& data, std::size_t block_size); DepthToSpace(const Output<Node>& data, std::size_t block_size);
std::size_t get_block_size() const { return m_blocksize; } std::size_t get_block_size() const { return m_blocksize; }
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -30,21 +30,21 @@ using namespace ngraph; ...@@ -30,21 +30,21 @@ using namespace ngraph;
const string op::Elu::type_name{"Elu"}; const string op::Elu::type_name{"Elu"};
op::Elu::Elu(const shared_ptr<Node>& data, const shared_ptr<Node>& alpha) op::Elu::Elu(const Output<Node>& data, const Output<Node>& alpha)
: FusedOp(check_single_output_args({data, alpha})) : FusedOp({data, alpha})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
NodeVector op::Elu::decompose_op() const NodeVector op::Elu::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto alpha_node = get_argument(1); auto alpha_node = input(1).get_source_output();
alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data->get_shape()); alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data.get_shape());
shared_ptr<ngraph::Node> zero_node = shared_ptr<ngraph::Node> zero_node =
builder::make_constant(data->get_element_type(), data->get_shape(), 0); builder::make_constant(data.get_element_type(), data.get_shape(), 0);
return {make_shared<ngraph::op::Maximum>(data, zero_node) + return {make_shared<ngraph::op::Maximum>(data, zero_node) +
alpha_node * alpha_node *
......
...@@ -34,12 +34,12 @@ namespace ngraph ...@@ -34,12 +34,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Elu() = default;
/// \brief Constructs an Elu operation. /// \brief Constructs an Elu operation.
/// ///
/// \param data Input tensor /// \param data Input tensor
/// \param alpha Multiplier for negative values /// \param alpha Multiplier for negative values
Elu(const std::shared_ptr<ngraph::Node>& data, Elu(const Output<Node>& data, const Output<Node>& alpha);
const std::shared_ptr<ngraph::Node>& alpha);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -38,13 +38,13 @@ using namespace ngraph; ...@@ -38,13 +38,13 @@ using namespace ngraph;
const string op::FakeQuantize::type_name{"FakeQuantize"}; const string op::FakeQuantize::type_name{"FakeQuantize"};
op::FakeQuantize::FakeQuantize(const shared_ptr<Node>& data, op::FakeQuantize::FakeQuantize(const Output<Node>& data,
const shared_ptr<Node>& input_low, const Output<Node>& input_low,
const shared_ptr<Node>& input_high, const Output<Node>& input_high,
const shared_ptr<Node>& output_low, const Output<Node>& output_low,
const shared_ptr<Node>& output_high, const Output<Node>& output_high,
size_t levels) size_t levels)
: FusedOp(check_single_output_args({data, input_low, input_high, output_low, output_high})) : FusedOp({data, input_low, input_high, output_low, output_high})
, m_levels(levels) , m_levels(levels)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -100,16 +100,16 @@ void op::FakeQuantize::pre_validate_and_infer_types() ...@@ -100,16 +100,16 @@ void op::FakeQuantize::pre_validate_and_infer_types()
NodeVector op::FakeQuantize::decompose_op() const NodeVector op::FakeQuantize::decompose_op() const
{ {
shared_ptr<Node> data{get_argument(0)}; Output<Node> data{input(0).get_source_output()};
shared_ptr<Node> input_low{get_argument(1)}; Output<Node> input_low{input(1).get_source_output()};
shared_ptr<Node> input_high{get_argument(2)}; Output<Node> input_high{input(2).get_source_output()};
shared_ptr<Node> output_low{get_argument(3)}; Output<Node> output_low{input(3).get_source_output()};
shared_ptr<Node> output_high{get_argument(4)}; Output<Node> output_high{input(4).get_source_output()};
if (input_low->get_shape().size() == 0) if (input_low.get_shape().size() == 0)
{ {
NodeVector broadcasted_nodes = OutputVector broadcasted_nodes = numpy_style_broadcast_values(
numpy_style_broadcast(NodeVector{data, input_low, input_high, output_low, output_high}); OutputVector{data, input_low, input_high, output_low, output_high});
data = broadcasted_nodes.at(0); data = broadcasted_nodes.at(0);
input_low = broadcasted_nodes.at(1); input_low = broadcasted_nodes.at(1);
...@@ -119,14 +119,15 @@ NodeVector op::FakeQuantize::decompose_op() const ...@@ -119,14 +119,15 @@ NodeVector op::FakeQuantize::decompose_op() const
} }
else else
{ {
input_low = legacy_style_broadcast_for_binary_operation(data, input_low, 1).at(1); input_low = legacy_style_broadcast_values_for_binary_operation(data, input_low, 1).at(1);
input_high = legacy_style_broadcast_for_binary_operation(data, input_high, 1).at(1); input_high = legacy_style_broadcast_values_for_binary_operation(data, input_high, 1).at(1);
output_low = legacy_style_broadcast_for_binary_operation(data, output_low, 1).at(1); output_low = legacy_style_broadcast_values_for_binary_operation(data, output_low, 1).at(1);
output_high = legacy_style_broadcast_for_binary_operation(data, output_high, 1).at(1); output_high =
legacy_style_broadcast_values_for_binary_operation(data, output_high, 1).at(1);
} }
const auto input_data_shape = data->get_shape(); const auto input_data_shape = data.get_shape();
const auto input_data_type = data->get_element_type(); const auto input_data_type = data.get_element_type();
const auto levels_minus_one = const auto levels_minus_one =
Constant::create(input_data_type, Constant::create(input_data_type,
...@@ -138,7 +139,7 @@ NodeVector op::FakeQuantize::decompose_op() const ...@@ -138,7 +139,7 @@ NodeVector op::FakeQuantize::decompose_op() const
const auto dequant_scale = (output_high - output_low) / levels_minus_one; const auto dequant_scale = (output_high - output_low) / levels_minus_one;
// zero_point type needs to match the quantization output type // zero_point type needs to match the quantization output type
const auto zero_point = Constant::create(element::i32, data->get_shape(), {0.0}); const auto zero_point = Constant::create(element::i32, data.get_shape(), {0.0});
const auto axes = get_default_order(input_data_shape); const auto axes = get_default_order(input_data_shape);
// clip the input data to the range <input_low;input_high> // clip the input data to the range <input_low;input_high>
......
...@@ -41,6 +41,7 @@ namespace ngraph ...@@ -41,6 +41,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
FakeQuantize() = default;
/// ///
/// \brief Constructs a FakeQuantize operation node. /// \brief Constructs a FakeQuantize operation node.
/// ///
...@@ -51,11 +52,11 @@ namespace ngraph ...@@ -51,11 +52,11 @@ namespace ngraph
/// \param[in] output_high The maximum quantized value. /// \param[in] output_high The maximum quantized value.
/// \param[in] levels The number of quantization levels. /// \param[in] levels The number of quantization levels.
/// ///
FakeQuantize(const std::shared_ptr<ngraph::Node>& data, FakeQuantize(const Output<Node>& data,
const std::shared_ptr<ngraph::Node>& input_low, const Output<Node>& input_low,
const std::shared_ptr<ngraph::Node>& input_high, const Output<Node>& input_high,
const std::shared_ptr<ngraph::Node>& output_low, const Output<Node>& output_low,
const std::shared_ptr<ngraph::Node>& output_high, const Output<Node>& output_high,
std::size_t levels); std::size_t levels);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -29,8 +29,8 @@ using namespace ngraph; ...@@ -29,8 +29,8 @@ using namespace ngraph;
const string op::Gelu::type_name{"Gelu"}; const string op::Gelu::type_name{"Gelu"};
op::Gelu::Gelu(const shared_ptr<Node>& data) op::Gelu::Gelu(const Output<Node>& data)
: FusedOp(check_single_output_args({data})) : FusedOp({data})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -38,16 +38,16 @@ op::Gelu::Gelu(const shared_ptr<Node>& data) ...@@ -38,16 +38,16 @@ op::Gelu::Gelu(const shared_ptr<Node>& data)
// f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) ) // f(x) = 0.5 * x * (1.0 + erf( x / sqrt(2.0) )
NodeVector op::Gelu::decompose_op() const NodeVector op::Gelu::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
shared_ptr<ngraph::Node> half = shared_ptr<ngraph::Node> half =
builder::make_constant(data->get_element_type(), data->get_shape(), 0.5); builder::make_constant(data.get_element_type(), data.get_shape(), 0.5);
shared_ptr<ngraph::Node> one = shared_ptr<ngraph::Node> one =
builder::make_constant(data->get_element_type(), data->get_shape(), 1.0); builder::make_constant(data.get_element_type(), data.get_shape(), 1.0);
shared_ptr<ngraph::Node> sqrt_two = shared_ptr<ngraph::Node> sqrt_two =
builder::make_constant(data->get_element_type(), data->get_shape(), std::sqrt(2.0)); builder::make_constant(data.get_element_type(), data.get_shape(), std::sqrt(2.0));
return {half * data * (one + make_shared<ngraph::op::Erf>(data / sqrt_two))}; return {half * data * (one + make_shared<ngraph::op::Erf>(data / sqrt_two))};
} }
......
...@@ -35,10 +35,11 @@ namespace ngraph ...@@ -35,10 +35,11 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Gelu() = default;
/// \brief Constructs an Gelu operation. /// \brief Constructs an Gelu operation.
/// ///
/// \param data Input tensor /// \param data Input tensor
Gelu(const std::shared_ptr<ngraph::Node>& data); Gelu(const Output<Node>& data);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -27,14 +27,14 @@ using namespace ngraph; ...@@ -27,14 +27,14 @@ using namespace ngraph;
const string op::Gemm::type_name{"Gemm"}; const string op::Gemm::type_name{"Gemm"};
op::Gemm::Gemm(const std::shared_ptr<ngraph::Node>& A, op::Gemm::Gemm(const Output<Node>& A,
const std::shared_ptr<ngraph::Node>& B, const Output<Node>& B,
const std::shared_ptr<ngraph::Node>& C, const Output<Node>& C,
double alpha, double alpha,
double beta, double beta,
bool transA, bool transA,
bool transB) bool transB)
: FusedOp(check_single_output_args({A, B, C})) : FusedOp({A, B, C})
, m_alpha{alpha} , m_alpha{alpha}
, m_beta{beta} , m_beta{beta}
, m_transA{transA} , m_transA{transA}
...@@ -45,9 +45,9 @@ op::Gemm::Gemm(const std::shared_ptr<ngraph::Node>& A, ...@@ -45,9 +45,9 @@ op::Gemm::Gemm(const std::shared_ptr<ngraph::Node>& A,
NodeVector op::Gemm::decompose_op() const NodeVector op::Gemm::decompose_op() const
{ {
auto A = get_argument(0); auto A = input(0).get_source_output();
auto B = get_argument(1); auto B = input(1).get_source_output();
auto C = get_argument(2); auto C = input(2).get_source_output();
if (m_transA) if (m_transA)
{ {
...@@ -72,7 +72,7 @@ NodeVector op::Gemm::decompose_op() const ...@@ -72,7 +72,7 @@ NodeVector op::Gemm::decompose_op() const
// beta * C // beta * C
std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>( std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
C->get_element_type(), C->get_shape(), std::vector<double>{m_beta}); C.get_element_type(), C.get_shape(), std::vector<double>{m_beta});
C = std::make_shared<ngraph::op::Multiply>(beta_node, C); C = std::make_shared<ngraph::op::Multiply>(beta_node, C);
// alpha * A' * B' + beta * C // alpha * A' * B' + beta * C
......
...@@ -39,6 +39,7 @@ namespace ngraph ...@@ -39,6 +39,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Gemm() = default;
/// \brief Constructs an Gemm operation. /// \brief Constructs an Gemm operation.
/// ///
/// \param A Input tensor A /// \param A Input tensor A
...@@ -48,9 +49,9 @@ namespace ngraph ...@@ -48,9 +49,9 @@ namespace ngraph
/// \param beta Scalar multiplier for input tensor C /// \param beta Scalar multiplier for input tensor C
/// \param transA Whether A should be transposed /// \param transA Whether A should be transposed
/// \param transB Whether B should be transposed /// \param transB Whether B should be transposed
Gemm(const std::shared_ptr<ngraph::Node>& A, Gemm(const Output<Node>& A,
const std::shared_ptr<ngraph::Node>& B, const Output<Node>& B,
const std::shared_ptr<ngraph::Node>& C, const Output<Node>& C,
double alpha = 1.0, double alpha = 1.0,
double beta = 1.0, double beta = 1.0,
bool transA = false, bool transA = false,
......
...@@ -29,8 +29,8 @@ using namespace ngraph; ...@@ -29,8 +29,8 @@ using namespace ngraph;
const string op::GRN::type_name{"GRN"}; const string op::GRN::type_name{"GRN"};
op::GRN::GRN(const shared_ptr<Node>& data, float bias) op::GRN::GRN(const Output<Node>& data, float bias)
: FusedOp(check_single_output_args({data})) : FusedOp({data})
, m_bias(bias) , m_bias(bias)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -56,8 +56,8 @@ void op::GRN::pre_validate_and_infer_types() ...@@ -56,8 +56,8 @@ void op::GRN::pre_validate_and_infer_types()
NodeVector op::GRN::decompose_op() const NodeVector op::GRN::decompose_op() const
{ {
shared_ptr<Node> data{get_argument(0)}; Output<Node> data{input(0).get_source_output()};
const Shape& input_shape{data->get_shape()}; const Shape& input_shape{data.get_shape()};
// Reshape to 4D tensor. // Reshape to 4D tensor.
if (input_shape.size() != 4) if (input_shape.size() != 4)
...@@ -70,7 +70,7 @@ NodeVector op::GRN::decompose_op() const ...@@ -70,7 +70,7 @@ NodeVector op::GRN::decompose_op() const
// Calculate l2 norm across channels. // Calculate l2 norm across channels.
shared_ptr<Node> norm = builder::l2_norm(data, AxisSet{1}, m_bias); shared_ptr<Node> norm = builder::l2_norm(data, AxisSet{1}, m_bias);
// Get back reduced axis. // Get back reduced axis.
norm = std::make_shared<Broadcast>(norm, data->get_shape(), AxisSet{1}); norm = std::make_shared<Broadcast>(norm, data.get_shape(), AxisSet{1});
data = data / norm; data = data / norm;
// get back original input tensor rank // get back original input tensor rank
...@@ -79,7 +79,7 @@ NodeVector op::GRN::decompose_op() const ...@@ -79,7 +79,7 @@ NodeVector op::GRN::decompose_op() const
data = builder::reshape(data, input_shape); data = builder::reshape(data, input_shape);
} }
return {data}; return as_node_vector({data});
} }
shared_ptr<Node> op::GRN::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GRN::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -33,12 +33,13 @@ namespace ngraph ...@@ -33,12 +33,13 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
GRN() = default;
/// \brief Constructs a GRN operation. /// \brief Constructs a GRN operation.
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param bias - The bias added to the variance. /// \param bias - The bias added to the variance.
/// ///
GRN(const std::shared_ptr<ngraph::Node>& data, float bias); GRN(const Output<Node>& data, float bias);
float get_bias() const { return m_bias; } float get_bias() const { return m_bias; }
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
......
...@@ -29,10 +29,6 @@ using namespace ngraph; ...@@ -29,10 +29,6 @@ using namespace ngraph;
const string op::GroupConvolution::type_name{"GroupConvolution"}; const string op::GroupConvolution::type_name{"GroupConvolution"};
op::GroupConvolution::GroupConvolution()
{
}
op::GroupConvolution::GroupConvolution(const Output<Node>& data_batch, op::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
const Output<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
......
...@@ -32,7 +32,7 @@ namespace ngraph ...@@ -32,7 +32,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
GroupConvolution(); GroupConvolution() = default;
GroupConvolution(const Output<Node>& data_batch, GroupConvolution(const Output<Node>& data_batch,
const Output<Node>& filters, const Output<Node>& filters,
const Strides& window_movement_strides, const Strides& window_movement_strides,
...@@ -49,8 +49,8 @@ namespace ngraph ...@@ -49,8 +49,8 @@ namespace ngraph
const CoordinateDiff& get_padding_below() const { return m_padding_below; } const CoordinateDiff& get_padding_below() const { return m_padding_below; }
const CoordinateDiff& get_padding_above() const { return m_padding_above; } const CoordinateDiff& get_padding_above() const { return m_padding_above; }
const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; } const Strides& get_data_dilation_strides() const { return m_data_dilation_strides; }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
std::shared_ptr<Node> get_data_batch() { return get_argument(0); } Output<Node> get_data_batch() { return input(0).get_source_output(); }
size_t get_groups() const { return m_groups; } size_t get_groups() const { return m_groups; }
const PadType& get_pad_type() const { return m_pad_type; } const PadType& get_pad_type() const { return m_pad_type; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -29,8 +29,8 @@ using namespace ngraph; ...@@ -29,8 +29,8 @@ using namespace ngraph;
const string op::GroupConvolutionTranspose::type_name{"GroupConvolutionTranspose"}; const string op::GroupConvolutionTranspose::type_name{"GroupConvolutionTranspose"};
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const shared_ptr<Node>& data, op::GroupConvolutionTranspose::GroupConvolutionTranspose(const Output<Node>& data,
const shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& strides, const Strides& strides,
const Strides& dilations, const Strides& dilations,
const CoordinateDiff& padding_begin, const CoordinateDiff& padding_begin,
...@@ -39,7 +39,7 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const shared_ptr<Node>& ...@@ -39,7 +39,7 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const shared_ptr<Node>&
const size_t groups, const size_t groups,
const PadType& pad_type, const PadType& pad_type,
const Shape& output_shape) const Shape& output_shape)
: FusedOp(check_single_output_args({data, filters})) : FusedOp({data, filters})
, m_strides(strides) , m_strides(strides)
, m_dilations(dilations) , m_dilations(dilations)
, m_padding_begin(padding_begin) , m_padding_begin(padding_begin)
...@@ -52,8 +52,8 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const shared_ptr<Node>& ...@@ -52,8 +52,8 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const shared_ptr<Node>&
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<Node>& data, op::GroupConvolutionTranspose::GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::size_t groups) const std::size_t groups)
: GroupConvolutionTranspose(data, : GroupConvolutionTranspose(data,
filters, filters,
...@@ -68,8 +68,8 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<N ...@@ -68,8 +68,8 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<N
{ {
} }
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<Node>& data, op::GroupConvolutionTranspose::GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& strides, const Strides& strides,
const Strides& dilations, const Strides& dilations,
const CoordinateDiff& output_padding, const CoordinateDiff& output_padding,
...@@ -88,8 +88,8 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<N ...@@ -88,8 +88,8 @@ op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<N
{ {
} }
op::GroupConvolutionTranspose::GroupConvolutionTranspose(const std::shared_ptr<Node>& data, op::GroupConvolutionTranspose::GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Shape& output_shape, const Shape& output_shape,
const std::size_t groups) const std::size_t groups)
: GroupConvolutionTranspose(data, : GroupConvolutionTranspose(data,
...@@ -232,8 +232,8 @@ shared_ptr<Node> op::GroupConvolutionTranspose::copy_with_new_args(const NodeVec ...@@ -232,8 +232,8 @@ shared_ptr<Node> op::GroupConvolutionTranspose::copy_with_new_args(const NodeVec
Shape op::GroupConvolutionTranspose::get_data_batch_shape() const Shape op::GroupConvolutionTranspose::get_data_batch_shape() const
{ {
const auto& data_shape = get_argument(0)->get_shape(); const auto& data_shape = input(0).get_shape();
const auto& filters_shape = get_argument(1)->get_shape(); const auto& filters_shape = input(1).get_shape();
const size_t num_spatial_dims = data_shape.size() - 2; const size_t num_spatial_dims = data_shape.size() - 2;
Shape data_batch_shape(data_shape.size(), 1); Shape data_batch_shape(data_shape.size(), 1);
...@@ -268,27 +268,27 @@ Shape op::GroupConvolutionTranspose::get_data_batch_shape() const ...@@ -268,27 +268,27 @@ Shape op::GroupConvolutionTranspose::get_data_batch_shape() const
NodeVector op::GroupConvolutionTranspose::decompose_op() const NodeVector op::GroupConvolutionTranspose::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto filters = get_argument(1); auto filters = input(1).get_source_output();
const Shape data_batch_shape = get_data_batch_shape(); const Shape data_batch_shape = get_data_batch_shape();
const size_t num_spatial_dims = data->get_shape().size() - 2; const size_t num_spatial_dims = data.get_shape().size() - 2;
if (m_groups > 1) if (m_groups > 1)
{ {
// Split one convolution op to N ops where N is the number of groups // Split one convolution op to N ops where N is the number of groups
// and concat results after computation. // and concat results after computation.
const size_t n_data_channels{data->get_shape().at(1)}; const size_t n_data_channels{data.get_shape().at(1)};
const size_t n_filters_channels{filters->get_shape().at(0)}; const size_t n_filters_channels{filters.get_shape().at(0)};
const size_t data_group_size{n_data_channels / m_groups}; const size_t data_group_size{n_data_channels / m_groups};
const size_t filters_group_size{n_filters_channels / m_groups}; const size_t filters_group_size{n_filters_channels / m_groups};
NodeVector convolution_nodes; NodeVector convolution_nodes;
// initial bounds for slice // initial bounds for slice
vector<size_t> data_lower_bounds(data->get_shape().size()); vector<size_t> data_lower_bounds(data.get_shape().size());
vector<size_t> data_upper_bounds{data->get_shape()}; vector<size_t> data_upper_bounds{data.get_shape()};
vector<size_t> filters_lower_bounds(filters->get_shape().size()); vector<size_t> filters_lower_bounds(filters.get_shape().size());
vector<size_t> filters_upper_bounds{filters->get_shape()}; vector<size_t> filters_upper_bounds{filters.get_shape()};
for (size_t group{0}; group < m_groups; ++group) for (size_t group{0}; group < m_groups; ++group)
{ {
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
GroupConvolutionTranspose() = default;
/// ///
/// \brief Constructs GroupConvolutionTranspose operation. /// \brief Constructs GroupConvolutionTranspose operation.
/// ///
...@@ -54,8 +55,8 @@ namespace ngraph ...@@ -54,8 +55,8 @@ namespace ngraph
/// \param[in] output_shape The output shape. When provided padding values are /// \param[in] output_shape The output shape. When provided padding values are
/// automatically inferred. /// automatically inferred.
/// ///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data, GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& strides, const Strides& strides,
const Strides& dilations, const Strides& dilations,
const CoordinateDiff& padding_begin, const CoordinateDiff& padding_begin,
...@@ -73,8 +74,8 @@ namespace ngraph ...@@ -73,8 +74,8 @@ namespace ngraph
/// \param[in] groups The number of groups the input channels and output channels /// \param[in] groups The number of groups the input channels and output channels
/// are divided into. /// are divided into.
/// ///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data, GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const std::size_t groups = 1UL); const std::size_t groups = 1UL);
/// ///
...@@ -90,8 +91,8 @@ namespace ngraph ...@@ -90,8 +91,8 @@ namespace ngraph
/// \param[in] groups The number of groups the input channels and output channels /// \param[in] groups The number of groups the input channels and output channels
/// are divided into. /// are divided into.
/// ///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data, GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Strides& strides, const Strides& strides,
const Strides& dilations, const Strides& dilations,
const CoordinateDiff& output_padding, const CoordinateDiff& output_padding,
...@@ -108,13 +109,13 @@ namespace ngraph ...@@ -108,13 +109,13 @@ namespace ngraph
/// \param[in] groups The number of groups the input channels and output channels /// \param[in] groups The number of groups the input channels and output channels
/// are divided into. /// are divided into.
/// ///
GroupConvolutionTranspose(const std::shared_ptr<Node>& data, GroupConvolutionTranspose(const Output<Node>& data,
const std::shared_ptr<Node>& filters, const Output<Node>& filters,
const Shape& output_shape, const Shape& output_shape,
const std::size_t groups = 1UL); const std::size_t groups = 1UL);
std::shared_ptr<Node> get_data() { return get_argument(0); } Output<Node> get_data() { return input(0).get_source_output(); }
std::shared_ptr<Node> get_filters() { return get_argument(1); } Output<Node> get_filters() { return input(1).get_source_output(); }
const Strides& get_strides() const { return m_strides; } const Strides& get_strides() const { return m_strides; }
const Strides& get_dilations() const { return m_dilations; } const Strides& get_dilations() const { return m_dilations; }
const CoordinateDiff& get_padding_begin() const { return m_padding_begin; } const CoordinateDiff& get_padding_begin() const { return m_padding_begin; }
......
...@@ -33,10 +33,10 @@ using namespace ngraph; ...@@ -33,10 +33,10 @@ using namespace ngraph;
const string op::GRUCell::type_name{"GRUCell"}; const string op::GRUCell::type_name{"GRUCell"};
op::GRUCell::GRUCell(const shared_ptr<Node>& X, op::GRUCell::GRUCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
size_t hidden_size) size_t hidden_size)
: GRUCell(X, : GRUCell(X,
W, W,
...@@ -51,17 +51,17 @@ op::GRUCell::GRUCell(const shared_ptr<Node>& X, ...@@ -51,17 +51,17 @@ op::GRUCell::GRUCell(const shared_ptr<Node>& X,
{ {
} }
op::GRUCell::GRUCell(const shared_ptr<Node>& X, op::GRUCell::GRUCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
size_t hidden_size, size_t hidden_size,
const vector<string>& activations, const vector<string>& activations,
const vector<float>& activation_alpha, const vector<float>& activation_alpha,
const vector<float>& activation_beta, const vector<float>& activation_beta,
float clip, float clip,
bool linear_before_reset) bool linear_before_reset)
: FusedOp(check_single_output_args({X, W, R, H_t})) : FusedOp({X, W, R, H_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta) , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)} , m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)} , m_activation_g{get_activation_function(1)}
...@@ -71,18 +71,18 @@ op::GRUCell::GRUCell(const shared_ptr<Node>& X, ...@@ -71,18 +71,18 @@ op::GRUCell::GRUCell(const shared_ptr<Node>& X,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::GRUCell::GRUCell(const shared_ptr<Node>& X, op::GRUCell::GRUCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
size_t hidden_size, size_t hidden_size,
const shared_ptr<Node>& B, const Output<Node>& B,
const vector<string>& activations, const vector<string>& activations,
const vector<float>& activation_alpha, const vector<float>& activation_alpha,
const vector<float>& activation_beta, const vector<float>& activation_beta,
float clip, float clip,
bool linear_before_reset) bool linear_before_reset)
: FusedOp(check_single_output_args({X, W, R, H_t, B})) : FusedOp({X, W, R, H_t, B})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta) , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)} , m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)} , m_activation_g{get_activation_function(1)}
...@@ -189,11 +189,11 @@ NodeVector op::GRUCell::decompose_op() const ...@@ -189,11 +189,11 @@ NodeVector op::GRUCell::decompose_op() const
// Ht = (1 - zt) (.) ht + zt (.) Ht-1 // Ht = (1 - zt) (.) ht + zt (.) Ht-1
// ------------------- // -------------------
std::shared_ptr<Node> X = get_argument(0); Output<Node> X = input(0).get_source_output();
std::shared_ptr<Node> W = get_argument(1); Output<Node> W = input(1).get_source_output();
std::shared_ptr<Node> R = get_argument(2); Output<Node> R = input(2).get_source_output();
std::shared_ptr<Node> H_t = get_argument(3); Output<Node> H_t = input(3).get_source_output();
std::shared_ptr<Node> B = get_argument(4); Output<Node> B = input(4).get_source_output();
// Get W and R biases separately. // Get W and R biases separately.
NodeVector b_W_R = builder::split(B, 2); NodeVector b_W_R = builder::split(B, 2);
...@@ -245,7 +245,7 @@ NodeVector op::GRUCell::decompose_op() const ...@@ -245,7 +245,7 @@ NodeVector op::GRUCell::decompose_op() const
const auto& z_t = zr_t_gates.at(0); const auto& z_t = zr_t_gates.at(0);
const auto& r_t = zr_t_gates.at(1); const auto& r_t = zr_t_gates.at(1);
shared_ptr<Node> h_t; Output<Node> h_t;
if (m_linear_before_reset) if (m_linear_before_reset)
{ {
...@@ -269,16 +269,16 @@ NodeVector op::GRUCell::decompose_op() const ...@@ -269,16 +269,16 @@ NodeVector op::GRUCell::decompose_op() const
// Ht = (1 - zt) (.) ht + zt (.) Ht-1 // Ht = (1 - zt) (.) ht + zt (.) Ht-1
H_t = add(mul(sub(one, z_t), h_t), mul(z_t, H_t)); H_t = add(mul(sub(one, z_t), h_t), mul(z_t, H_t));
return {H_t}; return {H_t.get_node_shared_ptr()};
} }
void op::GRUCell::add_default_bias_input() void op::GRUCell::add_default_bias_input()
{ {
shared_ptr<Node> B = Output<Node> B =
op::Constant::create(input(0).get_element_type(), op::Constant::create(input(0).get_element_type(),
Shape{2 * s_gates_count * get_hidden_size()}, Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f)); vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(4, B->output(0)); set_argument(4, B);
} }
shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GRUCell::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -44,6 +44,7 @@ namespace ngraph ...@@ -44,6 +44,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
GRUCell() = default;
/// ///
/// \brief Constructs GRUCell node. /// \brief Constructs GRUCell node.
/// ///
...@@ -56,10 +57,10 @@ namespace ngraph ...@@ -56,10 +57,10 @@ namespace ngraph
/// shape: [batch_size, hidden_size]. /// shape: [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell. /// \param[in] hidden_size The number of hidden units for recurrent cell.
/// ///
GRUCell(const std::shared_ptr<Node>& X, GRUCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
std::size_t hidden_size); std::size_t hidden_size);
/// ///
...@@ -82,10 +83,10 @@ namespace ngraph ...@@ -82,10 +83,10 @@ namespace ngraph
/// \param[in] clip The value defining clipping range [-clip, clip] on /// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions. /// input of activation functions.
/// ///
GRUCell(const std::shared_ptr<Node>& X, GRUCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::vector<std::string>& activations, const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha, const std::vector<float>& activation_alpha,
...@@ -115,12 +116,12 @@ namespace ngraph ...@@ -115,12 +116,12 @@ namespace ngraph
/// \param[in] clip The value defining clipping range [-clip, clip] on /// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions. /// input of activation functions.
/// ///
GRUCell(const std::shared_ptr<Node>& X, GRUCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::shared_ptr<Node>& B, const Output<Node>& B,
const std::vector<std::string>& activations = const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh"}, std::vector<std::string>{"sigmoid", "tanh"},
const std::vector<float>& activation_alpha = {}, const std::vector<float>& activation_alpha = {},
......
...@@ -29,8 +29,8 @@ using namespace ngraph; ...@@ -29,8 +29,8 @@ using namespace ngraph;
const string op::HardSigmoid::type_name{"HardSigmoid"}; const string op::HardSigmoid::type_name{"HardSigmoid"};
op::HardSigmoid::HardSigmoid(const shared_ptr<Node>& data, float alpha, float beta) op::HardSigmoid::HardSigmoid(const Output<Node>& data, float alpha, float beta)
: FusedOp(check_single_output_args({data})) : FusedOp({data})
, m_alpha(alpha) , m_alpha(alpha)
, m_beta(beta) , m_beta(beta)
{ {
...@@ -39,21 +39,21 @@ op::HardSigmoid::HardSigmoid(const shared_ptr<Node>& data, float alpha, float be ...@@ -39,21 +39,21 @@ op::HardSigmoid::HardSigmoid(const shared_ptr<Node>& data, float alpha, float be
NodeVector op::HardSigmoid::decompose_op() const NodeVector op::HardSigmoid::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto data_shape = data->get_shape(); auto data_shape = data.get_shape();
size_t elem_count = shape_size(data_shape); size_t elem_count = shape_size(data_shape);
std::shared_ptr<ngraph::Node> alpha_node = ngraph::op::Constant::create<float>( std::shared_ptr<ngraph::Node> alpha_node = ngraph::op::Constant::create<float>(
data->get_element_type(), data_shape, std::vector<float>(elem_count, m_alpha)); data.get_element_type(), data_shape, std::vector<float>(elem_count, m_alpha));
std::shared_ptr<ngraph::Node> beta_node = ngraph::op::Constant::create<float>( std::shared_ptr<ngraph::Node> beta_node = ngraph::op::Constant::create<float>(
data->get_element_type(), data_shape, std::vector<float>(elem_count, m_beta)); data.get_element_type(), data_shape, std::vector<float>(elem_count, m_beta));
std::shared_ptr<ngraph::Node> one_node = ngraph::op::Constant::create<float>( std::shared_ptr<ngraph::Node> one_node = ngraph::op::Constant::create<float>(
data->get_element_type(), data_shape, std::vector<float>(elem_count, 1.0)); data.get_element_type(), data_shape, std::vector<float>(elem_count, 1.0));
std::shared_ptr<ngraph::Node> zero_node = ngraph::op::Constant::create<float>( std::shared_ptr<ngraph::Node> zero_node = ngraph::op::Constant::create<float>(
data->get_element_type(), data_shape, std::vector<float>(elem_count, 0.0)); data.get_element_type(), data_shape, std::vector<float>(elem_count, 0.0));
return {std::make_shared<op::Minimum>( return {std::make_shared<op::Minimum>(
std::make_shared<op::Maximum>(alpha_node * data + beta_node, zero_node), one_node)}; std::make_shared<op::Maximum>(alpha_node * data + beta_node, zero_node), one_node)};
......
...@@ -33,13 +33,14 @@ namespace ngraph ...@@ -33,13 +33,14 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
HardSigmoid() = default;
/// \brief Constructs a HardSigmoid operation. /// \brief Constructs a HardSigmoid operation.
/// ///
/// \param data Input tensor. /// \param data Input tensor.
/// \param[in] alpha The alpha parameter. /// \param[in] alpha The alpha parameter.
/// \param[in] beta The beta parameter. /// \param[in] beta The beta parameter.
/// ///
HardSigmoid(const std::shared_ptr<ngraph::Node>& data, float alpha, float beta); HardSigmoid(const Output<Node>& data, float alpha, float beta);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -25,18 +25,18 @@ using namespace ngraph; ...@@ -25,18 +25,18 @@ using namespace ngraph;
const string op::LeakyRelu::type_name{"LeakyRelu"}; const string op::LeakyRelu::type_name{"LeakyRelu"};
op::LeakyRelu::LeakyRelu(const shared_ptr<Node>& data, const shared_ptr<Node>& alpha) op::LeakyRelu::LeakyRelu(const Output<Node>& data, const Output<Node>& alpha)
: FusedOp(check_single_output_args({data, alpha})) : FusedOp({data, alpha})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
NodeVector op::LeakyRelu::decompose_op() const NodeVector op::LeakyRelu::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto alpha_node = get_argument(1); auto alpha_node = input(1).get_source_output();
alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data->get_shape()); alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data.get_shape());
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)}; return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
} }
......
...@@ -33,8 +33,8 @@ namespace ngraph ...@@ -33,8 +33,8 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
LeakyRelu(const std::shared_ptr<ngraph::Node>& data, LeakyRelu() = default;
const std::shared_ptr<ngraph::Node>& alpha); LeakyRelu(const Output<Node>& data, const Output<Node>& alpha);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -33,11 +33,11 @@ using namespace ngraph; ...@@ -33,11 +33,11 @@ using namespace ngraph;
const string op::LSTMCell::type_name{"LSTMCell"}; const string op::LSTMCell::type_name{"LSTMCell"};
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, op::LSTMCell::LSTMCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
const shared_ptr<Node>& C_t, const Output<Node>& C_t,
size_t hidden_size) size_t hidden_size)
: LSTMCell(X, : LSTMCell(X,
W, W,
...@@ -53,18 +53,18 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, ...@@ -53,18 +53,18 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
{ {
} }
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, op::LSTMCell::LSTMCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
const shared_ptr<Node>& C_t, const Output<Node>& C_t,
size_t hidden_size, size_t hidden_size,
const vector<string>& activations, const vector<string>& activations,
const vector<float>& activation_alpha, const vector<float>& activation_alpha,
const vector<float>& activation_beta, const vector<float>& activation_beta,
float clip, float clip,
bool input_forget) bool input_forget)
: FusedOp(check_single_output_args({X, W, R, H_t, C_t})) : FusedOp({X, W, R, H_t, C_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta) , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)} , m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)} , m_activation_g{get_activation_function(1)}
...@@ -76,20 +76,20 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, ...@@ -76,20 +76,20 @@ op::LSTMCell::LSTMCell(const shared_ptr<Node>& X,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::LSTMCell::LSTMCell(const shared_ptr<Node>& X, op::LSTMCell::LSTMCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
const shared_ptr<Node>& C_t, const Output<Node>& C_t,
size_t hidden_size, size_t hidden_size,
const shared_ptr<Node>& B, const Output<Node>& B,
const shared_ptr<Node>& P, const Output<Node>& P,
const vector<string>& activations, const vector<string>& activations,
const vector<float>& activation_alpha, const vector<float>& activation_alpha,
const vector<float>& activation_beta, const vector<float>& activation_beta,
float clip, float clip,
bool input_forget) bool input_forget)
: FusedOp(check_single_output_args({X, W, R, H_t, C_t, B, P})) : FusedOp({X, W, R, H_t, C_t, B, P})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta) , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)} , m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)} , m_activation_g{get_activation_function(1)}
...@@ -226,13 +226,13 @@ NodeVector op::LSTMCell::decompose_op() const ...@@ -226,13 +226,13 @@ NodeVector op::LSTMCell::decompose_op() const
// Ht = ot (.) h(Ct) // Ht = ot (.) h(Ct)
// -------------------- // --------------------
shared_ptr<Node> X = get_argument(0); Output<Node> X = input(0).get_source_output();
shared_ptr<Node> W = get_argument(1); Output<Node> W = input(1).get_source_output();
shared_ptr<Node> R = get_argument(2); Output<Node> R = input(2).get_source_output();
shared_ptr<Node> H_t = get_argument(3); Output<Node> H_t = input(3).get_source_output();
shared_ptr<Node> C_t = get_argument(4); Output<Node> C_t = input(4).get_source_output();
shared_ptr<Node> bias = get_bias(); Output<Node> bias = get_bias();
NodeVector p_iof = get_peephole_weigths(); NodeVector p_iof = get_peephole_weights();
const auto& p_i = p_iof.at(0); const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1); const auto& p_o = p_iof.at(1);
...@@ -276,38 +276,38 @@ NodeVector op::LSTMCell::decompose_op() const ...@@ -276,38 +276,38 @@ NodeVector op::LSTMCell::decompose_op() const
return {H, C}; return {H, C};
} }
shared_ptr<Node> op::LSTMCell::get_bias() const Output<Node> op::LSTMCell::get_bias() const
{ {
shared_ptr<Node> bias; Output<Node> bias;
// Split B onto Wb an Rb and add them. // Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(get_argument(5), 2); NodeVector b_W_R = builder::split(input(5).get_source_output(), 2);
bias = b_W_R.at(0) + b_W_R.at(1); bias = b_W_R.at(0) + b_W_R.at(1);
return bias; return bias;
} }
NodeVector op::LSTMCell::get_peephole_weigths() const NodeVector op::LSTMCell::get_peephole_weights() const
{ {
shared_ptr<Node> P; Output<Node> P;
P = get_argument(6); P = input(6).get_source_output();
return builder::split(P, s_peepholes_count); return builder::split(P, s_peepholes_count);
} }
void op::LSTMCell::add_default_bias_input() void op::LSTMCell::add_default_bias_input()
{ {
shared_ptr<Node> B = Output<Node> B =
op::Constant::create(input(0).get_element_type(), op::Constant::create(input(0).get_element_type(),
Shape{2 * s_gates_count * get_hidden_size()}, Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f)); vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(5, B->output(0)); set_argument(5, B);
} }
void op::LSTMCell::add_default_peepholes_input() void op::LSTMCell::add_default_peepholes_input()
{ {
shared_ptr<Node> P = Output<Node> P =
op::Constant::create(input(0).get_element_type(), op::Constant::create(input(0).get_element_type(),
Shape{s_peepholes_count * get_hidden_size()}, Shape{s_peepholes_count * get_hidden_size()},
vector<float>(s_peepholes_count * get_hidden_size(), 0.f)); vector<float>(s_peepholes_count * get_hidden_size(), 0.f));
set_argument(6, P->output(0)); set_argument(6, P);
} }
shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -45,6 +45,7 @@ namespace ngraph ...@@ -45,6 +45,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
LSTMCell() = default;
/// ///
/// \brief Constructs LSTMCell node. /// \brief Constructs LSTMCell node.
/// ///
...@@ -58,11 +59,11 @@ namespace ngraph ...@@ -58,11 +59,11 @@ namespace ngraph
/// [batch_size, hidden_size]. /// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell. /// \param[in] hidden_size The number of hidden units for recurrent cell.
/// ///
LSTMCell(const std::shared_ptr<Node>& X, LSTMCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
const std::shared_ptr<Node>& C_t, const Output<Node>& C_t,
std::size_t hidden_size); std::size_t hidden_size);
/// ///
...@@ -87,11 +88,11 @@ namespace ngraph ...@@ -87,11 +88,11 @@ namespace ngraph
/// input of activation functions. /// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates. /// \param[in] input_forget Controls coupling input and forget gates.
/// ///
LSTMCell(const std::shared_ptr<Node>& X, LSTMCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
const std::shared_ptr<Node>& C_t, const Output<Node>& C_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::vector<std::string>& activations, const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha, const std::vector<float>& activation_alpha,
...@@ -124,14 +125,14 @@ namespace ngraph ...@@ -124,14 +125,14 @@ namespace ngraph
/// input of activation functions. /// input of activation functions.
/// \param[in] input_forget Controls coupling input and forget gates. /// \param[in] input_forget Controls coupling input and forget gates.
/// ///
LSTMCell(const std::shared_ptr<Node>& X, LSTMCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
const std::shared_ptr<Node>& C_t, const Output<Node>& C_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::shared_ptr<Node>& B, const Output<Node>& B,
const std::shared_ptr<Node>& P, const Output<Node>& P,
const std::vector<std::string>& activations = const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"}, std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activation_alpha = {}, const std::vector<float>& activation_alpha = {},
...@@ -146,8 +147,8 @@ namespace ngraph ...@@ -146,8 +147,8 @@ namespace ngraph
bool get_input_forget() const { return m_input_forget; } bool get_input_forget() const { return m_input_forget; }
private: private:
std::shared_ptr<Node> get_bias() const; Output<Node> get_bias() const;
NodeVector get_peephole_weigths() const; NodeVector get_peephole_weights() const;
/// brief Add and initialize bias input to all zeros. /// brief Add and initialize bias input to all zeros.
void add_default_bias_input(); void add_default_bias_input();
......
...@@ -29,11 +29,8 @@ using namespace ngraph; ...@@ -29,11 +29,8 @@ using namespace ngraph;
const string op::MVN::type_name{"MVN"}; const string op::MVN::type_name{"MVN"};
op::MVN::MVN(const std::shared_ptr<Node>& data, op::MVN::MVN(const Output<Node>& data, bool across_channels, bool normalize_variance, double eps)
bool across_channels, : FusedOp({data})
bool normalize_variance,
double eps)
: FusedOp(check_single_output_args({data}))
, m_eps{eps} , m_eps{eps}
, m_across_channels{across_channels} , m_across_channels{across_channels}
, m_normalize_variance{normalize_variance} , m_normalize_variance{normalize_variance}
...@@ -44,17 +41,14 @@ op::MVN::MVN(const std::shared_ptr<Node>& data, ...@@ -44,17 +41,14 @@ op::MVN::MVN(const std::shared_ptr<Node>& data,
// else we calculate these per channel // else we calculate these per channel
m_reduction_axes.insert(0); m_reduction_axes.insert(0);
size_t start_axis = m_across_channels ? 1 : 2; size_t start_axis = m_across_channels ? 1 : 2;
for (size_t i = start_axis; i < data->get_shape().size(); ++i) for (size_t i = start_axis; i < data.get_shape().size(); ++i)
{ {
m_reduction_axes.insert(i); m_reduction_axes.insert(i);
} }
} }
op::MVN::MVN(const std::shared_ptr<Node>& data, op::MVN::MVN(const Output<Node>& data, AxisSet reduction_axes, bool normalize_variance, double eps)
AxisSet reduction_axes, : FusedOp({data})
bool normalize_variance,
double eps)
: FusedOp(check_single_output_args({data}))
, m_eps{eps} , m_eps{eps}
, m_across_channels{false} , m_across_channels{false}
, m_normalize_variance{normalize_variance} , m_normalize_variance{normalize_variance}
...@@ -65,8 +59,8 @@ op::MVN::MVN(const std::shared_ptr<Node>& data, ...@@ -65,8 +59,8 @@ op::MVN::MVN(const std::shared_ptr<Node>& data,
NodeVector op::MVN::decompose_op() const NodeVector op::MVN::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto data_shape = data->get_shape(); // assume that data has n and c channels. auto data_shape = data.get_shape(); // assume that data has n and c channels.
// calculate mean normalization // calculate mean normalization
auto mean = builder::mean(data, m_reduction_axes); auto mean = builder::mean(data, m_reduction_axes);
...@@ -84,11 +78,11 @@ NodeVector op::MVN::decompose_op() const ...@@ -84,11 +78,11 @@ NodeVector op::MVN::decompose_op() const
variance = make_shared<op::Sqrt>(variance); variance = make_shared<op::Sqrt>(variance);
// add epsilon // add epsilon
auto eps_node = op::Constant::create( auto eps_node = op::Constant::create(
data->get_element_type(), variance->get_shape(), vector<double>{m_eps}); data.get_element_type(), Output<Node>(variance).get_shape(), vector<double>{m_eps});
variance = variance + eps_node; variance = variance + eps_node;
variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes); variance = std::make_shared<op::Broadcast>(variance, data_shape, m_reduction_axes);
return {mean_normalization / variance}; return as_node_vector({mean_normalization / variance});
} }
} }
......
...@@ -32,6 +32,7 @@ namespace ngraph ...@@ -32,6 +32,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
MVN() = default;
/// \brief Constructs an MVN operation. /// \brief Constructs an MVN operation.
/// ///
/// \param data Input tensor with data /// \param data Input tensor with data
...@@ -39,7 +40,7 @@ namespace ngraph ...@@ -39,7 +40,7 @@ namespace ngraph
/// \param across_channels flag that denotes if mean values are shared across channels. /// \param across_channels flag that denotes if mean values are shared across channels.
/// \param eps the number to be added to the variance to avoid division by zero when normalizing the value /// \param eps the number to be added to the variance to avoid division by zero when normalizing the value
/// ///
MVN(const std::shared_ptr<ngraph::Node>& data, MVN(const Output<Node>& data,
bool across_channels = true, bool across_channels = true,
bool normalize_variance = true, bool normalize_variance = true,
double eps = 1e-9); double eps = 1e-9);
...@@ -51,7 +52,7 @@ namespace ngraph ...@@ -51,7 +52,7 @@ namespace ngraph
/// \param normalize_variance flag that denotes whether to perform variance normalization. /// \param normalize_variance flag that denotes whether to perform variance normalization.
/// \param eps the number to be added to the variance to avoid division by zero when normalizing the value /// \param eps the number to be added to the variance to avoid division by zero when normalizing the value
/// ///
MVN(const std::shared_ptr<ngraph::Node>& data, MVN(const Output<Node>& data,
AxisSet reduction_axes, AxisSet reduction_axes,
bool normalize_variance = true, bool normalize_variance = true,
double eps = 1e-9); double eps = 1e-9);
......
...@@ -28,12 +28,12 @@ using namespace ngraph; ...@@ -28,12 +28,12 @@ using namespace ngraph;
const string op::Normalize::type_name{"Normalize"}; const string op::Normalize::type_name{"Normalize"};
op::Normalize::Normalize(const shared_ptr<ngraph::Node>& data, op::Normalize::Normalize(const Output<Node>& data,
const shared_ptr<ngraph::Node>& scale, const Output<Node>& scale,
bool across_spatial, bool across_spatial,
bool channel_shared, bool channel_shared,
float eps) float eps)
: FusedOp(check_single_output_args({data, scale})) : FusedOp({data, scale})
, m_across_spatial{across_spatial} , m_across_spatial{across_spatial}
, m_channel_shared{channel_shared} , m_channel_shared{channel_shared}
, m_eps{eps} , m_eps{eps}
...@@ -88,8 +88,8 @@ void op::Normalize::pre_validate_and_infer_types() ...@@ -88,8 +88,8 @@ void op::Normalize::pre_validate_and_infer_types()
NodeVector op::Normalize::decompose_op() const NodeVector op::Normalize::decompose_op() const
{ {
shared_ptr<Node> data{get_argument(0)}; Output<Node> data{input(0).get_source_output()};
const Shape input_shape{data->get_shape()}; const Shape input_shape{data.get_shape()};
// Reshape to 4D tensor. // Reshape to 4D tensor.
if (input_shape.size() != 4) if (input_shape.size() != 4)
...@@ -108,21 +108,21 @@ NodeVector op::Normalize::decompose_op() const ...@@ -108,21 +108,21 @@ NodeVector op::Normalize::decompose_op() const
} }
// Calculate l2 norm across channels. // Calculate l2 norm across channels.
shared_ptr<Node> norm = builder::l2_norm(data, reduction_axes, m_eps); Output<Node> norm = builder::l2_norm(data, reduction_axes, m_eps);
norm = make_broadcast_node(norm, data->get_shape(), 0); norm = make_broadcast_node(norm, data.get_shape(), 0);
shared_ptr<Node> scale_node{get_argument(1)}; Output<Node> scale_node{input(1).get_source_output()};
// Broadcast scale to data tensor shape. // Broadcast scale to data tensor shape.
if (m_channel_shared) if (m_channel_shared)
{ {
// Scale is a scalar. // Scale is a scalar.
scale_node = make_broadcast_node(scale_node, data->get_shape()); scale_node = make_broadcast_node(scale_node, data.get_shape());
} }
else else
{ {
// Scale is a vector of size equal to C axis. // Scale is a vector of size equal to C axis.
scale_node = make_broadcast_node(scale_node, data->get_shape(), 1); scale_node = make_broadcast_node(scale_node, data.get_shape(), 1);
} }
data = data / norm * scale_node; data = data / norm * scale_node;
...@@ -133,7 +133,7 @@ NodeVector op::Normalize::decompose_op() const ...@@ -133,7 +133,7 @@ NodeVector op::Normalize::decompose_op() const
data = builder::reshape(data, input_shape); data = builder::reshape(data, input_shape);
} }
return {data}; return as_node_vector({data});
} }
shared_ptr<Node> op::Normalize::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Normalize::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -33,6 +33,7 @@ namespace ngraph ...@@ -33,6 +33,7 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Normalize() = default;
/// ///
/// \brief Constructs a Normalize operation. /// \brief Constructs a Normalize operation.
/// ///
...@@ -42,8 +43,8 @@ namespace ngraph ...@@ -42,8 +43,8 @@ namespace ngraph
/// \param channel_shared - Whether scale is shared across channels. /// \param channel_shared - Whether scale is shared across channels.
/// \param eps - The epsilon added to L2 norm. /// \param eps - The epsilon added to L2 norm.
/// ///
Normalize(const std::shared_ptr<ngraph::Node>& data, Normalize(const Output<Node>& data,
const std::shared_ptr<ngraph::Node>& scale, const Output<Node>& scale,
bool across_spatial, bool across_spatial,
bool channel_shared, bool channel_shared,
float eps); float eps);
......
...@@ -31,8 +31,8 @@ using namespace ngraph; ...@@ -31,8 +31,8 @@ using namespace ngraph;
const string op::PRelu::type_name{"PRelu"}; const string op::PRelu::type_name{"PRelu"};
op::PRelu::PRelu(const shared_ptr<Node>& data, const shared_ptr<Node>& slope) op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
: FusedOp(check_single_output_args({data, slope})) : FusedOp({data, slope})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -34,12 +34,12 @@ namespace ngraph ...@@ -34,12 +34,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
PRelu() = default;
/// \brief Constructs a PRelu operation. /// \brief Constructs a PRelu operation.
/// ///
/// \param data Input tensor /// \param data Input tensor
/// \param slope Multipliers for negative values /// \param slope Multipliers for negative values
PRelu(const std::shared_ptr<ngraph::Node>& data, PRelu(const Output<Node>& data, const Output<Node>& slope);
const std::shared_ptr<ngraph::Node>& slope);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -33,26 +33,26 @@ using namespace ngraph; ...@@ -33,26 +33,26 @@ using namespace ngraph;
const string op::RNNCell::type_name{"RNNCell"}; const string op::RNNCell::type_name{"RNNCell"};
op::RNNCell::RNNCell(const shared_ptr<Node>& X, op::RNNCell::RNNCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
size_t hidden_size) size_t hidden_size)
: RNNCell( : RNNCell(
X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f) X, W, R, H_t, hidden_size, vector<string>{"tanh"}, vector<float>{}, vector<float>{}, 0.f)
{ {
} }
op::RNNCell::RNNCell(const shared_ptr<Node>& X, op::RNNCell::RNNCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
size_t hidden_size, size_t hidden_size,
const vector<string>& activations, const vector<string>& activations,
const vector<float>& activation_alpha, const vector<float>& activation_alpha,
const vector<float>& activation_beta, const vector<float>& activation_beta,
float clip) float clip)
: FusedOp(check_single_output_args({X, W, R, H_t})) : FusedOp({X, W, R, H_t})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta) , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)} , m_activation_f{get_activation_function(0)}
{ {
...@@ -60,17 +60,17 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X, ...@@ -60,17 +60,17 @@ op::RNNCell::RNNCell(const shared_ptr<Node>& X,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::RNNCell::RNNCell(const shared_ptr<Node>& X, op::RNNCell::RNNCell(const Output<Node>& X,
const shared_ptr<Node>& W, const Output<Node>& W,
const shared_ptr<Node>& R, const Output<Node>& R,
const shared_ptr<Node>& H_t, const Output<Node>& H_t,
size_t hidden_size, size_t hidden_size,
const shared_ptr<Node>& B, const Output<Node>& B,
const vector<string>& activations, const vector<string>& activations,
const vector<float>& activation_alpha, const vector<float>& activation_alpha,
const vector<float>& activation_beta, const vector<float>& activation_beta,
float clip) float clip)
: FusedOp(check_single_output_args({X, W, R, H_t, B})) : FusedOp({X, W, R, H_t, B})
, RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta) , RNNCellBase(hidden_size, clip, activations, activation_alpha, activation_beta)
, m_activation_f{get_activation_function(0)} , m_activation_f{get_activation_function(0)}
{ {
...@@ -169,11 +169,11 @@ NodeVector op::RNNCell::decompose_op() const ...@@ -169,11 +169,11 @@ NodeVector op::RNNCell::decompose_op() const
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// -------------------- // --------------------
std::shared_ptr<Node> X = get_argument(0); Output<Node> X = input(0).get_source_output();
std::shared_ptr<Node> W = get_argument(1); Output<Node> W = input(1).get_source_output();
std::shared_ptr<Node> R = get_argument(2); Output<Node> R = input(2).get_source_output();
std::shared_ptr<Node> H_t = get_argument(3); Output<Node> H_t = input(3).get_source_output();
std::shared_ptr<Node> bias = get_bias(); Output<Node> bias = get_bias();
// Xt*(W^T) // Xt*(W^T)
auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W)); auto Xt_W = std::make_shared<op::Dot>(X, builder::transpose(W));
...@@ -188,22 +188,22 @@ NodeVector op::RNNCell::decompose_op() const ...@@ -188,22 +188,22 @@ NodeVector op::RNNCell::decompose_op() const
return {i_t}; return {i_t};
} }
shared_ptr<Node> op::RNNCell::get_bias() const Output<Node> op::RNNCell::get_bias() const
{ {
shared_ptr<Node> bias; Output<Node> bias;
// Split B onto Wb an Rb and add them. // Split B onto Wb an Rb and add them.
NodeVector b_W_R = builder::split(get_argument(4), 2); NodeVector b_W_R = builder::split(input(4).get_source_output(), 2);
bias = b_W_R.at(0) + b_W_R.at(1); bias = b_W_R.at(0) + b_W_R.at(1);
return bias; return bias;
} }
void op::RNNCell::add_default_bias_input() void op::RNNCell::add_default_bias_input()
{ {
shared_ptr<Node> B = Output<Node> B =
op::Constant::create(input(0).get_element_type(), op::Constant::create(input(0).get_element_type(),
Shape{2 * s_gates_count * get_hidden_size()}, Shape{2 * s_gates_count * get_hidden_size()},
vector<float>(2 * s_gates_count * get_hidden_size(), 0.f)); vector<float>(2 * s_gates_count * get_hidden_size(), 0.f));
set_argument(4, B->output(0)); set_argument(4, B);
} }
shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::RNNCell::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -55,10 +55,10 @@ namespace ngraph ...@@ -55,10 +55,10 @@ namespace ngraph
/// [batch_size, hidden_size]. /// [batch_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell. /// \param[in] hidden_size The number of hidden units for recurrent cell.
/// ///
RNNCell(const std::shared_ptr<Node>& X, RNNCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
std::size_t hidden_size); std::size_t hidden_size);
/// ///
...@@ -80,10 +80,10 @@ namespace ngraph ...@@ -80,10 +80,10 @@ namespace ngraph
/// \param[in] clip The value defining clipping range [-clip, clip] on /// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions. /// input of activation functions.
/// ///
RNNCell(const std::shared_ptr<Node>& X, RNNCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::vector<std::string>& activations, const std::vector<std::string>& activations,
const std::vector<float>& activation_alpha, const std::vector<float>& activation_alpha,
...@@ -110,12 +110,12 @@ namespace ngraph ...@@ -110,12 +110,12 @@ namespace ngraph
/// \param[in] clip The value defining clipping range [-clip, clip] on /// \param[in] clip The value defining clipping range [-clip, clip] on
/// input of activation functions. /// input of activation functions.
/// ///
RNNCell(const std::shared_ptr<Node>& X, RNNCell(const Output<Node>& X,
const std::shared_ptr<Node>& W, const Output<Node>& W,
const std::shared_ptr<Node>& R, const Output<Node>& R,
const std::shared_ptr<Node>& H_t, const Output<Node>& H_t,
std::size_t hidden_size, std::size_t hidden_size,
const std::shared_ptr<Node>& B, const Output<Node>& B,
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"}, const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
const std::vector<float>& activation_alpha = {}, const std::vector<float>& activation_alpha = {},
const std::vector<float>& activation_beta = {}, const std::vector<float>& activation_beta = {},
...@@ -127,7 +127,7 @@ namespace ngraph ...@@ -127,7 +127,7 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
private: private:
std::shared_ptr<Node> get_bias() const; Output<Node> get_bias() const;
/// brief Add and initialize bias input to all zeros. /// brief Add and initialize bias input to all zeros.
void add_default_bias_input(); void add_default_bias_input();
......
...@@ -23,10 +23,10 @@ using namespace ngraph; ...@@ -23,10 +23,10 @@ using namespace ngraph;
const string op::ScaleShift::type_name{"ScaleShift"}; const string op::ScaleShift::type_name{"ScaleShift"};
op::ScaleShift::ScaleShift(const std::shared_ptr<ngraph::Node>& data, op::ScaleShift::ScaleShift(const Output<Node>& data,
const std::shared_ptr<ngraph::Node>& scale, const Output<Node>& scale,
const std::shared_ptr<ngraph::Node>& shift) const Output<Node>& shift)
: FusedOp(check_single_output_args({data, scale, shift})) : FusedOp({data, scale, shift})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -34,14 +34,15 @@ namespace ngraph ...@@ -34,14 +34,15 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
ScaleShift() = default;
/// \brief Constructs an ScaleShift operation. /// \brief Constructs an ScaleShift operation.
/// ///
/// \param data Input tensor /// \param data Input tensor
/// \param scale Input tensor that scale input data /// \param scale Input tensor that scale input data
/// \param shift Input tensor that shift input data /// \param shift Input tensor that shift input data
ScaleShift(const std::shared_ptr<ngraph::Node>& data, ScaleShift(const Output<Node>& data,
const std::shared_ptr<ngraph::Node>& scale, const Output<Node>& scale,
const std::shared_ptr<ngraph::Node>& shift); const Output<Node>& shift);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -23,10 +23,8 @@ using namespace ngraph; ...@@ -23,10 +23,8 @@ using namespace ngraph;
const string op::ShuffleChannels::type_name{"ShuffleChannels"}; const string op::ShuffleChannels::type_name{"ShuffleChannels"};
op::ShuffleChannels::ShuffleChannels(const shared_ptr<Node>& data, op::ShuffleChannels::ShuffleChannels(const Output<Node>& data, const int axis, const size_t groups)
const int axis, : FusedOp({data})
const size_t groups)
: FusedOp(check_single_output_args({data}))
, m_axis(axis) , m_axis(axis)
, m_groups{groups} , m_groups{groups}
{ {
...@@ -56,7 +54,7 @@ void op::ShuffleChannels::pre_validate_and_infer_types() ...@@ -56,7 +54,7 @@ void op::ShuffleChannels::pre_validate_and_infer_types()
{ {
if (get_input_partial_shape(0).is_static()) if (get_input_partial_shape(0).is_static())
{ {
const auto shape = get_argument(0)->get_shape(); const auto shape = input(0).get_shape();
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D."); this, shape.size() >= 1, "The input tensor's shape is expected to be at least 1D.");
...@@ -77,8 +75,8 @@ void op::ShuffleChannels::pre_validate_and_infer_types() ...@@ -77,8 +75,8 @@ void op::ShuffleChannels::pre_validate_and_infer_types()
NodeVector op::ShuffleChannels::decompose_op() const NodeVector op::ShuffleChannels::decompose_op() const
{ {
const auto data = get_argument(0); const auto data = input(0).get_source_output();
const auto& data_shape = data->get_shape(); const auto& data_shape = data.get_shape();
const auto reshaped = builder::reshape(data, get_pre_shuffle_shape(data_shape)); const auto reshaped = builder::reshape(data, get_pre_shuffle_shape(data_shape));
const auto shuffled = builder::reorder_axes(reshaped, {0, 2, 1, 3}); const auto shuffled = builder::reorder_axes(reshaped, {0, 2, 1, 3});
......
...@@ -32,12 +32,13 @@ namespace ngraph ...@@ -32,12 +32,13 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
ShuffleChannels() = default;
/// \brief Constructs a ShuffleChannels node. /// \brief Constructs a ShuffleChannels node.
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param axis - channel dimension index in the data tensor. A negative value means that the index should be calculated from the back of the input data shape. /// \param axis - channel dimension index in the data tensor. A negative value means that the index should be calculated from the back of the input data shape.
/// \param groups - number of groups the channel dimension specified by axis should be split into /// \param groups - number of groups the channel dimension specified by axis should be split into
ShuffleChannels(const std::shared_ptr<ngraph::Node>& data, ShuffleChannels(const Output<Node>& data,
const int axis = 1, const int axis = 1,
const size_t groups = 1UL); const size_t groups = 1UL);
......
...@@ -26,8 +26,8 @@ using namespace ngraph; ...@@ -26,8 +26,8 @@ using namespace ngraph;
const string op::SpaceToDepth::type_name{"SpaceToDepth"}; const string op::SpaceToDepth::type_name{"SpaceToDepth"};
op::SpaceToDepth::SpaceToDepth(const shared_ptr<Node>& data, const size_t block_size) op::SpaceToDepth::SpaceToDepth(const Output<Node>& data, const size_t block_size)
: FusedOp(check_single_output_args({data})) : FusedOp({data})
, m_blocksize(block_size) , m_blocksize(block_size)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -35,8 +35,8 @@ op::SpaceToDepth::SpaceToDepth(const shared_ptr<Node>& data, const size_t block_ ...@@ -35,8 +35,8 @@ op::SpaceToDepth::SpaceToDepth(const shared_ptr<Node>& data, const size_t block_
NodeVector op::SpaceToDepth::decompose_op() const NodeVector op::SpaceToDepth::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
const Shape& data_shape = data->get_shape(); const Shape& data_shape = data.get_shape();
// Set default values to each dimension to be able to work with both 3D or 4D data. // Set default values to each dimension to be able to work with both 3D or 4D data.
size_t n{1}, c{1}, h{1}, w{1}; size_t n{1}, c{1}, h{1}, w{1};
...@@ -74,7 +74,7 @@ NodeVector op::SpaceToDepth::decompose_op() const ...@@ -74,7 +74,7 @@ NodeVector op::SpaceToDepth::decompose_op() const
// First we have to disperse the data from height and width channels, then // First we have to disperse the data from height and width channels, then
// rearrange them so as appropriate chunks of data where close to their // rearrange them so as appropriate chunks of data where close to their
// destination place. Finally squeeze data from respective dimensions. // destination place. Finally squeeze data from respective dimensions.
shared_ptr<Node> flat_node = builder::reshape(data, Shape{n, c, h_flat, bs, w_flat, bs}); Output<Node> flat_node = builder::reshape(data, Shape{n, c, h_flat, bs, w_flat, bs});
flat_node = builder::reorder_axes(flat_node, {0, 3, 5, 1, 2, 4}); flat_node = builder::reorder_axes(flat_node, {0, 3, 5, 1, 2, 4});
return NodeVector{builder::reshape(flat_node, Shape{n, c_high, h_flat, w_flat})}; return NodeVector{builder::reshape(flat_node, Shape{n, c_high, h_flat, w_flat})};
} }
......
...@@ -35,11 +35,12 @@ namespace ngraph ...@@ -35,11 +35,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
SpaceToDepth() = default;
/// \brief Constructs a SpaceToDepth operation. /// \brief Constructs a SpaceToDepth operation.
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param block_size - the size of the block of values to be moved /// \param block_size - the size of the block of values to be moved
SpaceToDepth(const std::shared_ptr<ngraph::Node>& data, std::size_t block_size); SpaceToDepth(const Output<Node>& data, std::size_t block_size);
std::size_t get_block_size() const { return m_blocksize; } std::size_t get_block_size() const { return m_blocksize; }
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -23,8 +23,8 @@ using namespace ngraph; ...@@ -23,8 +23,8 @@ using namespace ngraph;
const string op::Split::type_name{"Split"}; const string op::Split::type_name{"Split"};
op::Split::Split(const shared_ptr<Node>& data, const int axis, const size_t num_split) op::Split::Split(const Output<Node>& data, const int axis, const size_t num_split)
: FusedOp(check_single_output_args({data})) : FusedOp({data})
, m_split_evenly{true} , m_split_evenly{true}
, m_axis{axis} , m_axis{axis}
, m_num_split{num_split} , m_num_split{num_split}
...@@ -32,10 +32,8 @@ op::Split::Split(const shared_ptr<Node>& data, const int axis, const size_t num_ ...@@ -32,10 +32,8 @@ op::Split::Split(const shared_ptr<Node>& data, const int axis, const size_t num_
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Split::Split(const std::shared_ptr<ngraph::Node>& data, op::Split::Split(const Output<Node>& data, const int axis, const std::vector<size_t>& splits)
const int axis, : FusedOp({data})
const std::vector<size_t>& splits)
: FusedOp(check_single_output_args({data}))
, m_split_evenly{false} , m_split_evenly{false}
, m_axis{axis} , m_axis{axis}
, m_splits{splits} , m_splits{splits}
...@@ -45,7 +43,7 @@ op::Split::Split(const std::shared_ptr<ngraph::Node>& data, ...@@ -45,7 +43,7 @@ op::Split::Split(const std::shared_ptr<ngraph::Node>& data,
void op::Split::pre_validate_and_infer_types() void op::Split::pre_validate_and_infer_types()
{ {
const auto shape = get_argument(0)->get_shape(); const auto shape = input(0).get_shape();
m_axis = adjust_axis_value(m_axis, shape.size()); m_axis = adjust_axis_value(m_axis, shape.size());
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -86,7 +84,7 @@ void op::Split::pre_validate_and_infer_types() ...@@ -86,7 +84,7 @@ void op::Split::pre_validate_and_infer_types()
NodeVector op::Split::decompose_op() const NodeVector op::Split::decompose_op() const
{ {
return builder::split(get_argument(0), m_splits, m_axis); return builder::split(input(0).get_source_output(), m_splits, m_axis);
} }
shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Split::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -33,23 +33,20 @@ namespace ngraph ...@@ -33,23 +33,20 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Split() = default;
/// \brief Constructs a Split op that evenly divides the input tensor. /// \brief Constructs a Split op that evenly divides the input tensor.
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param axis - indicates an axis along which the input tensor should be split. Negative values mean counting from the back of the input tensor's shape. /// \param axis - indicates an axis along which the input tensor should be split. Negative values mean counting from the back of the input tensor's shape.
/// \param num_split - a number of "pieces" the input tensor will be split to /// \param num_split - a number of "pieces" the input tensor will be split to
Split(const std::shared_ptr<ngraph::Node>& data, Split(const Output<Node>& data, const int axis, const size_t num_split);
const int axis,
const size_t num_split);
/// \brief Constructs a Split op that splits the input tensor into variable length "pieces" /// \brief Constructs a Split op that splits the input tensor into variable length "pieces"
/// ///
/// \param data - Node producing the input tensor /// \param data - Node producing the input tensor
/// \param axis - indicates an axis along which the input tensor should be split. Negative values mean counting from the back of the input tensor's shape. /// \param axis - indicates an axis along which the input tensor should be split. Negative values mean counting from the back of the input tensor's shape.
/// \param splits - a list of lengths that the input tensor should be split to. Use this constructor to split the input tensor to variable length chunks. /// \param splits - a list of lengths that the input tensor should be split to. Use this constructor to split the input tensor to variable length chunks.
Split(const std::shared_ptr<ngraph::Node>& data, Split(const Output<Node>& data, const int axis, const std::vector<size_t>& splits);
const int axis,
const std::vector<size_t>& splits);
void pre_validate_and_infer_types() override; void pre_validate_and_infer_types() override;
......
...@@ -26,8 +26,8 @@ using namespace ngraph; ...@@ -26,8 +26,8 @@ using namespace ngraph;
const string op::SquaredDifference::type_name{"SquaredDifference"}; const string op::SquaredDifference::type_name{"SquaredDifference"};
op::SquaredDifference::SquaredDifference(const shared_ptr<Node>& x1, const shared_ptr<Node>& x2) op::SquaredDifference::SquaredDifference(const Output<Node>& x1, const Output<Node>& x2)
: FusedOp(check_single_output_args({x1, x2})) : FusedOp({x1, x2})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -33,12 +33,12 @@ namespace ngraph ...@@ -33,12 +33,12 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
SquaredDifference() = default;
/// \brief Constructs the squared difference operation. /// \brief Constructs the squared difference operation.
/// ///
/// \param x1 First input tensor /// \param x1 First input tensor
/// \param x2 Second input tensor /// \param x2 Second input tensor
SquaredDifference(const std::shared_ptr<ngraph::Node>& x1, SquaredDifference(const Output<Node>& x1, const Output<Node>& x2);
const std::shared_ptr<ngraph::Node>& x2);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -28,16 +28,16 @@ using namespace ngraph; ...@@ -28,16 +28,16 @@ using namespace ngraph;
const string op::Squeeze::type_name{"Squeeze"}; const string op::Squeeze::type_name{"Squeeze"};
op::Squeeze::Squeeze(const shared_ptr<Node>& data, const shared_ptr<Node>& axes) op::Squeeze::Squeeze(const Output<Node>& data, const Output<Node>& axes)
: FusedOp(check_single_output_args({data, axes})) : FusedOp({data, axes})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
NodeVector op::Squeeze::decompose_op() const NodeVector op::Squeeze::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto axes_node = get_argument(1); auto axes_node = input(1).get_source_output().get_node_shared_ptr();
// Currently only support Constant node for axes. // Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -48,7 +48,7 @@ NodeVector op::Squeeze::decompose_op() const ...@@ -48,7 +48,7 @@ NodeVector op::Squeeze::decompose_op() const
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node); auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>(); auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data->get_shape(); auto data_shape = data.get_shape();
// Prepare set of unique axes marked to be removed from input data. // Prepare set of unique axes marked to be removed from input data.
if (axes.empty()) if (axes.empty())
......
...@@ -33,8 +33,8 @@ namespace ngraph ...@@ -33,8 +33,8 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Squeeze(const std::shared_ptr<ngraph::Node>& data, Squeeze() = default;
const std::shared_ptr<ngraph::Node>& axes); Squeeze(const Output<Node>& data, const Output<Node>& axes);
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -28,15 +28,15 @@ using namespace ngraph; ...@@ -28,15 +28,15 @@ using namespace ngraph;
const string op::Unsqueeze::type_name{"Unsqueeze"}; const string op::Unsqueeze::type_name{"Unsqueeze"};
op::Unsqueeze::Unsqueeze(const shared_ptr<Node>& data, const shared_ptr<Node>& axes) op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
: FusedOp(check_single_output_args({data, axes})) : FusedOp({data, axes})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void op::Unsqueeze::pre_validate_and_infer_types() void op::Unsqueeze::pre_validate_and_infer_types()
{ {
auto axes_node = get_argument(1); auto axes_node = input(1).get_source_output().get_node_shared_ptr();
// Currently only support Constant node for axes. // Currently only support Constant node for axes.
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
...@@ -46,14 +46,14 @@ void op::Unsqueeze::pre_validate_and_infer_types() ...@@ -46,14 +46,14 @@ void op::Unsqueeze::pre_validate_and_infer_types()
NodeVector op::Unsqueeze::decompose_op() const NodeVector op::Unsqueeze::decompose_op() const
{ {
auto data = get_argument(0); auto data = input(0).get_source_output();
auto axes_node = get_argument(1); auto axes_node = input(1).get_source_output().get_node_shared_ptr();
// Get value of axes from Constant // Get value of axes from Constant
auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node); auto axes_constant = dynamic_pointer_cast<op::Constant>(axes_node);
auto axes = axes_constant->get_vector<size_t>(); auto axes = axes_constant->get_vector<size_t>();
auto data_shape = data->get_shape(); auto data_shape = data.get_shape();
NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory."); NODE_VALIDATION_CHECK(this, !axes.empty(), "'axes' input is mandatory.");
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
......
...@@ -33,8 +33,8 @@ namespace ngraph ...@@ -33,8 +33,8 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
Unsqueeze(const std::shared_ptr<ngraph::Node>& data, Unsqueeze() = default;
const std::shared_ptr<ngraph::Node>& axes); Unsqueeze(const Output<Node>& data, const Output<Node>& axes);
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment